mb-diff-500step

A ModernBERT-large model fine-tuned as a LLaDA-style diffusion language model for instruction following.

Model Description

This model uses iterative unmasking for text generation:

  1. Start with a user prompt + fully masked response slots
  2. Model predicts all masked tokens simultaneously
  3. Commit the single most confident prediction, then repeat until no masked slots remain

Unlike an autoregressive model, which generates strictly left to right, this model scores every masked position in parallel at each step, and tokens can be committed in any order.
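The three steps above can be sketched without a real model. In this minimal pure-Python version, `predict` is a hypothetical stand-in for the network: it takes the current token ids and returns one probability distribution per position. `tokens_per_step` is also an illustrative parameter that does not come from this card. With the default of 1, it commits one token per forward pass as described. Larger values commit several of the most confident predictions at once, which means fewer forward passes, possibly at some cost in quality.

```python
from typing import Callable, List, Sequence

# Hypothetical model interface: current ids -> one probability row per position
Predictor = Callable[[Sequence[int]], List[List[float]]]


def iterative_unmask(ids: Sequence[int], mask_id: int, predict: Predictor,
                     tokens_per_step: int = 1) -> List[int]:
    """Fill masked slots in order of model confidence (sketch)."""
    if tokens_per_step < 1:
        raise ValueError("tokens_per_step must be >= 1")
    ids = list(ids)  # work on a copy; the caller's list is not modified
    while True:
        masked = [i for i, t in enumerate(ids) if t == mask_id]
        if not masked:
            return ids
        probs = predict(ids)  # one forward pass scores every masked slot at once
        candidates = []
        for pos in masked:
            row = probs[pos]
            # Most likely token at this slot, never the mask token itself
            tok = max((v for v in range(len(row)) if v != mask_id),
                      key=lambda v: row[v])
            candidates.append((row[tok], pos, tok))
        # Commit the most confident predictions; the rest stay masked for now
        candidates.sort(key=lambda c: c[0], reverse=True)
        for _, pos, tok in candidates[:tokens_per_step]:
            ids[pos] = tok


# Toy predictor over a 4-token vocabulary where id 0 is the mask token
TABLE = {
    1: [0.0, 0.9, 0.05, 0.05],
    2: [0.0, 0.2, 0.7, 0.1],
    3: [0.0, 0.3, 0.3, 0.4],
}


def toy_predict(ids: Sequence[int]) -> List[List[float]]:
    return [TABLE.get(i, [0.25] * 4) for i in range(len(ids))]


print(iterative_unmask([3, 0, 0, 0], mask_id=0, predict=toy_predict))  # [3, 1, 2, 3]
```

With `tokens_per_step=1`, the toy run needs three forward passes and fills slot 1, then slot 2, then slot 3 (highest confidence first). With `tokens_per_step=3`, it fills all three slots in a single pass.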

Training Details

Usage

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

model_id = "Ayushnangia/mb-diff-500step"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForMaskedLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device)

# Format the user prompt with the chat template
query = "What is 2+2?"
messages = [{"role": "user", "content": query}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

# Response slots: num_tokens [MASK] tokens, closed by <|im_end|>
num_tokens = 64
prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
mask_id = tokenizer.mask_token_id
im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")

ids = [tokenizer.cls_token_id] + prompt_ids + [mask_id] * num_tokens + [im_end_id]

# Iterative unmasking: each forward pass commits the single most confident masked token
for step in range(num_tokens):
    # Find the remaining mask positions; stop once every slot is filled
    mask_positions = [i for i, tok in enumerate(ids) if tok == mask_id]
    if not mask_positions:
        break

    with torch.no_grad():
        input_tensor = torch.tensor([ids], device=device)
        logits = model(input_ids=input_tensor).logits[0]

    # Score only the masked positions; softmax in float32 for stable confidences
    mask_logits = logits[mask_positions].float()
    mask_logits[:, mask_id] = float("-inf")  # never "predict" [MASK] itself
    probs = torch.softmax(mask_logits, dim=-1)

    # Best token per masked position, then the most confident position overall
    conf, tokens = probs.max(dim=-1)
    best = conf.argmax().item()
    ids[mask_positions[best]] = tokens[best].item()

# Decode response
response_start = len(prompt_ids) + 1
response_ids = [t for t in ids[response_start:] if t not in (mask_id, im_end_id)]
response = tokenizer.decode(response_ids, skip_special_tokens=True)
print(response)
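The decoding step depends on the sequence layout built earlier: one [CLS] token, the prompt, the masked response slots, and a closing <|im_end|>. The sketch below factors that layout into two hypothetical helpers, `build_masked_input` and `extract_response`. The usage code drops every <|im_end|> and keeps whatever follows it. `extract_response` instead stops at the first <|im_end|>. This assumes the model uses that token to end its answer, which is a design choice rather than documented behavior of this checkpoint. The token ids in the example are made up.

```python
from typing import List, Sequence


def build_masked_input(cls_id: int, prompt_ids: Sequence[int], mask_id: int,
                       end_id: int, num_tokens: int) -> List[int]:
    # Same layout as the usage code: [CLS] + prompt + num_tokens masks + <|im_end|>
    return [cls_id] + list(prompt_ids) + [mask_id] * num_tokens + [end_id]


def extract_response(ids: Sequence[int], num_prompt_tokens: int, mask_id: int,
                     end_id: int) -> List[int]:
    # Skip [CLS] (+1) and the prompt, keep tokens up to the first <|im_end|>,
    # and drop any slots that are still masked
    out = []
    for t in ids[num_prompt_tokens + 1:]:
        if t == end_id:
            break
        if t != mask_id:
            out.append(t)
    return out


# Hypothetical ids: CLS=1, <|im_end|>=2, MASK=4
seq = build_masked_input(1, [10, 11], 4, 2, 3)
print(seq)  # [1, 10, 11, 4, 4, 4, 2]
seq[3:6] = [20, 2, 21]  # pretend the model filled the three slots
print(extract_response(seq, 2, 4, 2))  # [20]
```

On the same filled sequence, the filter in the usage code would return [20, 21], because it keeps the token after the early <|im_end|>.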

Inference Script

For easier inference, use the sampling script from the training repo:

git clone https://github.com/agokrani/diffu-convert
cd diffu-convert
pip install -e .

python scripts/mb_dllm_sample.py --model Ayushnangia/mb-diff-500step --query "What is 2+2?"

Limitations

  • Early checkpoint (500 training steps); not fully converged
  • Works best for short responses (64 to 256 tokens)
  • Accuracy on math and reasoning tasks may be lower than that of autoregressive models

Citation

@misc{mb-diffusion,
  author = {Ayush Nangia},
  title = {ModernBERT Diffusion Language Model},
  year = {2025},
  publisher = {HuggingFace},
  url = {https://huggingface.co/Ayushnangia/mb-diff-500step}
}
Model size: 0.4B params (Safetensors, tensor type F32)