JiT-B

Read the blog about the base model.

Scales are normalized around the center of the image
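A minimal sketch of what center-normalized positions could look like; the exact grid construction and the `[-1, 1]` range are assumptions, not the model's actual code:

```python
import torch

def centered_positions(h: int, w: int) -> torch.Tensor:
    """Build (h*w, 2) axial positions normalized so the image center sits
    at (0, 0); images of different sizes then share the same coordinate
    range, and only the sampling density changes."""
    ys = torch.arange(h, dtype=torch.float32) - (h - 1) / 2
    xs = torch.arange(w, dtype=torch.float32) - (w - 1) / 2
    # Divide by half the extent so coordinates span [-1, 1].
    ys = ys / max((h - 1) / 2, 1)
    xs = xs / max((w - 1) / 2, 1)
    grid = torch.stack(torch.meshgrid(ys, xs, indexing="ij"), dim=-1)
    return grid.reshape(-1, 2)

pos = centered_positions(48, 48)  # (2304, 2), zero-mean, spanning [-1, 1]
```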

What's new

  • Flux2 RoPE calculation for both the image and text sequences
  • Text embeddings as a condition
  • Klein-like, bias-free modules and QK norm, shift-free AdaLN
  • Initialized from the base model, keeping the CIELAB output
  • Masked image training

Future

  • Sequence-level spatial separation built into the tokenizer
  • A spatial tokenizer, leaving concept bleeding behind

Training

Despite its size, the model took hours to bake instead of days.

# t = 0 is pure noise, t = 1 the clean image; t broadcasts over (B, C, H, W).
noisy = noise * (1 - t) + pixel_values * t
v_pred = model.forward(noisy, t, ctx)
# Along the straight noise-to-image path the velocity is constant.
v_target = pixel_values - noise
loss = torch.nn.functional.mse_loss(v_pred, v_target)
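The snippet above, expanded into a self-contained step. The uniform sampling of `t` and the `model(noisy, t, ctx)` call signature are assumptions for illustration:

```python
import torch
import torch.nn.functional as F

def training_step(model, pixel_values, ctx):
    """One rectified-flow step: interpolate between noise (t = 0) and the
    image (t = 1), then regress the constant velocity pixel - noise."""
    b = pixel_values.shape[0]
    t = torch.rand(b, device=pixel_values.device)  # one t per sample
    t_ = t.view(b, 1, 1, 1)                        # broadcast over C, H, W
    noise = torch.randn_like(pixel_values)
    noisy = noise * (1 - t_) + pixel_values * t_
    v_pred = model(noisy, t, ctx)
    v_target = pixel_values - noise
    return F.mse_loss(v_pred, v_target)
```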

Inference

import numpy as np
import torch
from PIL import Image
from skimage import color
from transformers import AutoTokenizer

@torch.no_grad()
def inference(model: DiT, device=None, steps=50):
    tokenizer = AutoTokenizer.from_pretrained('nebulette/booru-character-aware-tokenizer')
    ctx = torch.tensor(tokenizer.encode('portrait')).unsqueeze(0).to(device)
    xt = torch.randn((1, 3, 48, 48), device=device)

    # Generate time steps from 0 to 1.
    time_steps = torch.linspace(0.0, 1.0, steps + 1, device=device)

    # Iterate through time steps, skipping the final t = 1 so that
    # exactly `steps` Euler updates are taken.
    for t in time_steps[:-1]:
        t = t.unsqueeze(0)
        # Predict the velocity at point (x_t, t) using the model.
        v_pred = model.forward(xt, t, ctx)

        # Update the state based on the predicted velocity.
        xt = xt + v_pred * (1 / steps)

    # Convert CIELAB → RGB: L ∈ [0, 100], a/b ∈ [-128, 128].
    lab = torch.clamp(xt[0], -1, 1).cpu().numpy()
    L = (lab[0] + 1) * 50
    a = lab[1] * 128
    b = lab[2] * 128
    rgb = color.lab2rgb(np.stack([L, a, b], axis=-1)) * 255.0

    return Image.fromarray(rgb.astype(np.uint8))

References

  • 3588028.3603685
  • 2503.16397
  • 2508.02324 (Figure 8)