---
license: apache-2.0
base_model:
- ruwwww/waifu_diffusion
pipeline_tag: text-to-image
---

## JiT-B


Read the [blog](https://ruwwww.github.io/al-folio/blog/2026/waifu-diffusion/) about the base model.


*Scales are normalized around the center of the image*


## What's new

- Flux2 RoPE calculation for both the image and text sequences
- [Text embeddings](https://huggingface.co/nebulette/booru-character-aware-tokenizer) as a condition
- [Klein-like](https://huggingface.co/nebulette/klein-x) bias-free modules with QK norm and shift-free AdaLN
- Initialized from the [base model](https://huggingface.co/ruwwww/waifu_diffusion), keeping the CIELAB output space
- Masked image training
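
Of these, the shift-free AdaLN change is easy to show in isolation. Below is a minimal sketch, assuming a standard AdaLN-Zero baseline; `ShiftFreeAdaLN` and its dimensions are hypothetical illustration names, not the model's actual code:

```python
import torch
from torch import nn


class ShiftFreeAdaLN(nn.Module):
    """Hypothetical sketch of shift-free AdaLN modulation.

    Standard AdaLN-Zero regresses (shift, scale, gate) from the conditioning
    vector; the shift-free variant drops the additive shift, so the block
    input is only rescaled and its output gated.
    """

    def __init__(self, dim: int, cond_dim: int):
        super().__init__()
        # Bias-free linears and a non-affine norm, matching the card's
        # "bias-free modules" bullet.
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.to_scale_gate = nn.Linear(cond_dim, 2 * dim, bias=False)

    def forward(self, x: torch.Tensor, cond: torch.Tensor):
        scale, gate = self.to_scale_gate(cond).chunk(2, dim=-1)
        # Scale only -- no shift term.
        h = self.norm(x) * (1 + scale.unsqueeze(1))
        return h, gate.unsqueeze(1)


mod = ShiftFreeAdaLN(dim=64, cond_dim=32)
x = torch.randn(2, 16, 64)     # (batch, tokens, dim)
cond = torch.randn(2, 32)      # pooled conditioning vector
h, gate = mod(x, cond)
```

The returned `gate` would multiply the block's residual branch, as in AdaLN-Zero.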


## Future

- Spatial separation between sequences built directly into the tokenizer
- A spatial tokenizer, leaving concept bleeding behind


## Training

Despite its size, the model took hours to bake instead of [days](https://huggingface.co/kaupane/DiT-Wikiart-Base).


```python
# Rectified flow: interpolate between noise (t = 0) and data (t = 1),
# then regress the constant velocity pointing from noise to data.
noisy = noise * (1 - t) + pixel_values * t
v_pred = model.forward(noisy, t, ctx)
v_target = pixel_values - noise
loss = torch.nn.functional.mse_loss(v_pred, v_target)
```
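
The loss snippet leaves out how `t` is sampled. A minimal self-contained sketch, assuming `t` is drawn uniformly per sample (the card does not state the actual schedule):

```python
import torch
import torch.nn.functional as F


def training_step(model, pixel_values: torch.Tensor, ctx=None) -> torch.Tensor:
    """One rectified-flow training step matching the loss above.

    Assumes t ~ U(0, 1) per sample -- an assumption, not the card's
    documented schedule.
    """
    b = pixel_values.shape[0]
    t = torch.rand(b, device=pixel_values.device)
    t_ = t.view(b, 1, 1, 1)  # broadcast over (C, H, W)
    noise = torch.randn_like(pixel_values)
    noisy = noise * (1 - t_) + pixel_values * t_
    v_target = pixel_values - noise
    v_pred = model(noisy, t, ctx)
    return F.mse_loss(v_pred, v_target)


# Toy usage with a stand-in model that always predicts zero velocity.
dummy_model = lambda noisy, t, ctx: torch.zeros_like(noisy)
loss = training_step(dummy_model, torch.randn(4, 3, 48, 48))
```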


## Inference

```python
import numpy as np
import torch
from PIL import Image
from skimage import color
from transformers import AutoTokenizer


@torch.no_grad()
def inference(model: DiT, device=None, steps=50):
    tokenizer = AutoTokenizer.from_pretrained('nebulette/booru-character-aware-tokenizer')
    ctx = torch.tensor(tokenizer.encode('portrait')).unsqueeze(0).to(device)
    xt = torch.randn((1, 3, 48, 48), device=device)

    # Time steps from 0 to 1; the final value is only the endpoint.
    time_steps = torch.linspace(0.0, 1.0, steps + 1, device=device)

    # Euler integration: `steps` updates of size 1 / steps.
    for t in time_steps[:-1]:
        t = t.unsqueeze(0)
        # Predict the velocity at point (x_t, t) using the model.
        v_pred = model.forward(xt, t, ctx)

        # Update the state along the predicted velocity.
        xt = xt + v_pred * (1 / steps)

    # Convert CIELAB → RGB: L in [0, 100], a/b in [-128, 128].
    lab = torch.clamp(xt[0], -1, 1).cpu().numpy()
    L = (lab[0] + 1) * 50
    a = lab[1] * 128
    b = lab[2] * 128
    rgb = color.lab2rgb(np.stack([L, a, b], axis=-1)) * 255.0

    return Image.fromarray(rgb.astype(np.uint8))
```
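
The sampling loop is plain Euler integration of the predicted velocity field. As a model-free sanity check, a constant velocity field `x1 - x0` (the exact rectified-flow velocity of a straight path) should transport `x0` to `x1`; the helper below is an illustration of the update rule, not part of the released code:

```python
import torch


def euler_integrate(v_fn, x0: torch.Tensor, steps: int = 50) -> torch.Tensor:
    """Apply the same Euler update rule as the inference loop."""
    xt = x0.clone()
    time_steps = torch.linspace(0.0, 1.0, steps + 1)
    for t in time_steps[:-1]:  # `steps` updates of size 1 / steps
        xt = xt + v_fn(xt, t) * (1 / steps)
    return xt


x0 = torch.randn(1, 3, 4, 4)  # "noise"
x1 = torch.randn(1, 3, 4, 4)  # "data"
xT = euler_integrate(lambda x, t: x1 - x0, x0)
# For a constant velocity field the Euler scheme is exact: xT ≈ x1.
```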


## References

- 3588028.3603685
- [arXiv:2503.16397](https://arxiv.org/abs/2503.16397)
- [arXiv:2508.02324](https://arxiv.org/abs/2508.02324) (Figure 8)