---
license: apache-2.0
base_model:
- ruwwww/waifu_diffusion
pipeline_tag: text-to-image
---

## JiT-B

Read the [blog](https://ruwwww.github.io/al-folio/blog/2026/waifu-diffusion/) about the base model.

![](images/scale_wise.png)

*Scales are normalized around the center of the image*

## What's new

- Flux2 RoPE calculation for both the image and text sequences
- [Text embeddings](https://huggingface.co/nebulette/booru-character-aware-tokenizer) as a condition
- [Klein-like](https://huggingface.co/nebulette/klein-x), bias-free modules and QK norm, shift-free AdaLN
- Initialized from the [base model](https://huggingface.co/ruwwww/waifu_diffusion), kept the CIELAB output
- Masked image training
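
The Klein-style block changes above (bias-free modules, shift-free AdaLN) might be sketched as follows. This is a minimal illustration, not the model's actual module: the class name, conditioning dimension, and the scale-only modulation form are assumptions based on the list above.

```python
import torch
import torch.nn as nn

class ShiftFreeAdaLN(nn.Module):
    """Hypothetical sketch: AdaLN that predicts only a per-channel scale
    (no shift), applied to a bias-free LayerNorm."""

    def __init__(self, dim: int, cond_dim: int):
        super().__init__()
        # elementwise_affine=False: the norm carries no learned scale or bias.
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        # Conditioning predicts a scale only; there is no shift term.
        self.to_scale = nn.Linear(cond_dim, dim, bias=False)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        scale = self.to_scale(cond).unsqueeze(1)  # (B, 1, dim)
        return self.norm(x) * (1 + scale)

x = torch.randn(2, 16, 64)   # (batch, tokens, dim)
cond = torch.randn(2, 128)   # conditioning vector
out = ShiftFreeAdaLN(64, 128)(x, cond)
print(out.shape)  # torch.Size([2, 16, 64])
```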

## Future

- Spatial separation between sequences built into the tokenizer
- A spatial tokenizer that leaves concept bleeding behind

## Training

Despite its size, the model took hours to bake instead of [days](https://huggingface.co/kaupane/DiT-Wikiart-Base).

```python
# Linear interpolation between pure noise (t = 0) and data (t = 1)
noisy = noise * (1 - t) + pixel_values * t
v_pred = model.forward(noisy, t, ctx)
# The velocity target points from noise toward the data
v_target = pixel_values - noise
loss = torch.nn.functional.mse_loss(v_pred, v_target)
```
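
Putting the loss together with timestep sampling and the masked-image objective, one possible training step looks like this. The uniform `t` schedule, the mask ratio, and the per-pixel mask granularity are assumptions; the card does not specify them, and `v_pred` here is a random stand-in for the model call.

```python
import torch

# Stand-in batch: CIELAB images normalized to [-1, 1] (shapes assumed).
pixel_values = torch.randn(4, 3, 48, 48).clamp(-1, 1)
noise = torch.randn_like(pixel_values)

# One timestep per sample, broadcast over channels and spatial dims
# (uniform sampling is an assumption).
t = torch.rand(pixel_values.shape[0]).view(-1, 1, 1, 1)

noisy = noise * (1 - t) + pixel_values * t
v_target = pixel_values - noise

# Stand-in for model.forward(noisy, t, ctx).
v_pred = torch.randn_like(v_target)

# Masked image training (sketch): exclude random regions from the loss.
mask = (torch.rand(4, 1, 48, 48) > 0.25).float()
loss = ((v_pred - v_target) ** 2 * mask).sum() / (mask.sum() * 3).clamp(min=1)
```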

## Inference

```python
import numpy as np
import torch
from PIL import Image
from skimage import color
from transformers import AutoTokenizer


@torch.no_grad()
def inference(model: DiT, device=None, steps=50):
    tokenizer = AutoTokenizer.from_pretrained('nebulette/booru-character-aware-tokenizer')
    ctx = torch.tensor(tokenizer.encode('portrait')).unsqueeze(0).to(device)
    xt = torch.randn((1, 3, 48, 48), device=device)

    # Generate time steps from 0 to 1.
    time_steps = torch.linspace(0.0, 1.0, steps + 1, device=device)

    # Euler update from each t to t + 1/steps; the final grid point t = 1 is
    # only an endpoint, so iterate over the first `steps` values.
    for t in time_steps[:-1]:
        t = t.unsqueeze(0)
        # Predict the velocity at point (x_t, t) using the model.
        v_pred = model.forward(xt, t, ctx)

        # Update the state based on the predicted velocity.
        xt = xt + v_pred * (1 / steps)

    # Convert CIELAB → RGB.
    lab = torch.clamp(xt[0], -1, 1).cpu().numpy()
    L = (lab[0] + 1) * 50
    a = lab[1] * 128
    b = lab[2] * 128
    rgb = color.lab2rgb(np.stack([L, a, b], axis=-1)) * 255.0

    return Image.fromarray(rgb.astype(np.uint8))
```
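
The inference code above maps the model's normalized CIELAB output back to RGB; a preprocessing sketch for the opposite direction, mirroring that normalization (L in [-1, 1] via /50 − 1, a and b via /128), might look like this. The function name and input layout are assumptions.

```python
import numpy as np
import torch
from skimage import color

def rgb_to_normalized_lab(image: np.ndarray) -> torch.Tensor:
    """Convert an HxWx3 uint8 RGB image to the [-1, 1] CIELAB layout the
    inference code expects (sketch; inverts the lab2rgb normalization)."""
    lab = color.rgb2lab(image / 255.0)  # L in [0, 100]; a, b roughly [-128, 128]
    L = lab[..., 0] / 50.0 - 1.0
    a = lab[..., 1] / 128.0
    b = lab[..., 2] / 128.0
    return torch.from_numpy(np.stack([L, a, b], axis=0)).float()  # (3, H, W)

x = rgb_to_normalized_lab(np.full((48, 48, 3), 128, dtype=np.uint8))
print(x.shape)  # torch.Size([3, 48, 48])
```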

## References

- doi:10.1145/3588028.3603685
- arXiv:2503.16397
- arXiv:2508.02324 (Figure 8)