nebulette commited on
Commit
3e29f25
·
verified ·
1 Parent(s): c5cd51e

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +43 -3
README.md CHANGED
@@ -1,3 +1,43 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ ---
4
+
5
+ ## Training
6
+
7
+ ```python
8
+ noisy = noise * (1 - t) + pixel_values * t
9
+ v_pred = model.forward(noisy, t, ctx)
10
+ v_target = pixel_values - noise
11
+ loss = torch.nn.functional.mse_loss(v_pred, v_target)
12
+ ```
13
+
14
+ ## Inference
15
+
16
+ ```python
17
+ @torch.no_grad()
18
+ def inference(model: DiT, device=None, steps=50):
19
+ tokenizer = AutoTokenizer.from_pretrained('nebulette/booru-character-aware-tokenizer')
20
+ ctx = torch.tensor(tokenizer.encode('portrait')).unsqueeze(0).to(device)
21
+ xt = torch.randn((1, 3, 48, 48), device=device)
22
+
23
+ # Generate time steps from 0 to 1.
24
+ time_steps = torch.linspace(0.0, 1.0, steps + 1, device=device)
25
+
26
+ # Iterate through time steps.
27
+ for t in time_steps:
28
+ t = t.unsqueeze(0)
29
+ # Predict the velocity at point (x_t, t) using the model.
30
+ v_pred = model.forward(xt, t, ctx)
31
+
32
+ # Update the state based on the predicted velocity.
33
+ xt = xt + v_pred * (1 / steps)
34
+
35
+ # Convert CIELAB → RGB
36
+ lab = torch.clamp(xt[0], -1, 1).cpu().numpy()
37
+ L = (lab[0] + 1) * 50
38
+ a = lab[1] * 128
39
+ b = lab[2] * 128
40
+ rgb = color.lab2rgb(np.stack([L, a, b], axis=-1)) * 255.0
41
+
42
+ return Image.fromarray(rgb.astype(np.uint8))
43
+ ```