---
license: apache-2.0
base_model:
- ruwwww/waifu_diffusion
pipeline_tag: text-to-image
---
## JiT-B
Read the [blog](https://ruwwww.github.io/al-folio/blog/2026/waifu-diffusion/) about the base model.
![](images/scale_wise.png)
*Scales are normalized around the center of the image*
## What's new
- Flux2 RoPE calculation for both the image and text sequences
- [Text embeddings](https://huggingface.co/nebulette/booru-character-aware-tokenizer) as a condition
- [Klein-like](https://huggingface.co/nebulette/klein-x), bias-free modules and QK norm, shift-free AdaLN
- Initialized from the [base model](https://huggingface.co/ruwwww/waifu_diffusion), kept the CIELAB output
- Masked image training
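As a rough illustration of the Klein-like block design above (bias-free linear layers, QK normalization, and AdaLN that emits only a scale with no shift), here is a minimal self-attention sketch. The module name, shapes, and conditioning interface are illustrative assumptions, not the model's actual code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def rms_norm(x, eps=1e-6):
    # Scale-free RMS normalization over the last dimension.
    return x / x.pow(2).mean(-1, keepdim=True).add(eps).sqrt()

class ShiftFreeAdaLNAttention(nn.Module):
    """Illustrative block: bias-free linears, QK norm, scale-only AdaLN."""
    def __init__(self, dim, heads):
        super().__init__()
        self.heads = heads
        self.head_dim = dim // heads
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.proj = nn.Linear(dim, dim, bias=False)
        # Conditioning maps to a per-channel scale only; no shift term.
        self.ada_scale = nn.Linear(dim, dim, bias=False)

    def forward(self, x, cond):
        B, N, D = x.shape
        # Shift-free AdaLN: modulate by (1 + scale), omit the shift.
        h = self.norm(x) * (1 + self.ada_scale(cond)).unsqueeze(1)
        q, k, v = self.qkv(h).chunk(3, dim=-1)
        q = q.view(B, N, self.heads, self.head_dim).transpose(1, 2)
        k = k.view(B, N, self.heads, self.head_dim).transpose(1, 2)
        v = v.view(B, N, self.heads, self.head_dim).transpose(1, 2)
        # QK normalization keeps attention logits bounded.
        q, k = rms_norm(q), rms_norm(k)
        out = F.scaled_dot_product_attention(q, k, v)
        out = out.transpose(1, 2).reshape(B, N, D)
        return x + self.proj(out)

block = ShiftFreeAdaLNAttention(dim=64, heads=4)
y = block(torch.randn(2, 16, 64), torch.randn(2, 64))
```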
## Future
- Sequence-separated spatial structure built into the tokenizer
- A spatial tokenizer, leaving concept bleeding behind
## Training
Despite its size, the model took hours to train rather than [days](https://huggingface.co/kaupane/DiT-Wikiart-Base).
```python
# Linear interpolation between noise (t = 0) and data (t = 1).
noisy = noise * (1 - t) + pixel_values * t
# Predict the constant velocity of the straight-line path.
v_pred = model.forward(noisy, t, ctx)
v_target = pixel_values - noise
loss = torch.nn.functional.mse_loss(v_pred, v_target)
```
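This is a rectified-flow objective: the velocity target is the constant slope of the straight line from noise to data. A quick sanity check that the Euler update used at inference inverts this interpolation exactly when given the true velocity:

```python
import torch

torch.manual_seed(0)
pixel_values = torch.randn(1, 3, 8, 8)
noise = torch.randn(1, 3, 8, 8)

# True velocity of the straight-line path x_t = (1 - t) * noise + t * x.
v_true = pixel_values - noise

# Euler integration from t = 0 (pure noise) to t = 1 (clean image).
steps = 50
xt = noise.clone()
for _ in range(steps):
    xt = xt + v_true / steps

print(torch.allclose(xt, pixel_values, atol=1e-5))  # True
```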
## Inference
```python
import numpy as np
import torch
from PIL import Image
from skimage import color
from transformers import AutoTokenizer


@torch.no_grad()
def inference(model: DiT, device=None, steps=50):
    tokenizer = AutoTokenizer.from_pretrained('nebulette/booru-character-aware-tokenizer')
    ctx = torch.tensor(tokenizer.encode('portrait')).unsqueeze(0).to(device)
    xt = torch.randn((1, 3, 48, 48), device=device)
    # Generate time steps from 0 to 1.
    time_steps = torch.linspace(0.0, 1.0, steps + 1, device=device)
    # Iterate through time steps; the final t = 1 needs no further update.
    for t in time_steps[:-1]:
        t = t.unsqueeze(0)
        # Predict the velocity at point (x_t, t) using the model.
        v_pred = model.forward(xt, t, ctx)
        # Update the state based on the predicted velocity.
        xt = xt + v_pred * (1 / steps)
    # Convert CIELAB → RGB.
    lab = torch.clamp(xt[0], -1, 1).cpu().numpy()
    L = (lab[0] + 1) * 50
    a = lab[1] * 128
    b = lab[2] * 128
    rgb = color.lab2rgb(np.stack([L, a, b], axis=-1)) * 255.0
    return Image.fromarray(rgb.astype(np.uint8))
```
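The decoding step above assumes model outputs in [-1, 1] map to L in [0, 100] and a/b in [-128, 128]. A standalone sketch of just that denormalization and conversion, using scikit-image as in the snippet above (the random input stands in for a model output):

```python
import numpy as np
from skimage import color

# Stand-in for a model output in [-1, 1], channels-first (L, a, b).
lab_norm = np.random.default_rng(0).uniform(-0.5, 0.5, (3, 48, 48))

# Denormalize: L in [0, 100], a/b in [-128, 128].
L = (lab_norm[0] + 1) * 50
a = lab_norm[1] * 128
b = lab_norm[2] * 128

# lab2rgb expects channels-last and returns floats in [0, 1].
rgb = color.lab2rgb(np.stack([L, a, b], axis=-1))
print(rgb.shape)  # (48, 48, 3)
```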
## References
- 3588028.3603685
- arXiv:2503.16397
- arXiv:2508.02324 (Figure 8)