# mb-diff-500step

A ModernBERT-large model fine-tuned as a diffusion language model (LLaDA-style) for instruction following.

## Model Description
This model uses iterative unmasking for text generation:
- Start with a user prompt + fully masked response slots
- Model predicts all masked tokens simultaneously
- Keep the most confident prediction, repeat until done
Unlike autoregressive models, this allows parallel token prediction and flexible generation order.
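The loop above can be sketched independently of any model. In this toy version (all names are illustrative, not from the actual inference code), a stand-in `toy_confidences` function plays the role of the forward pass, returning a confidence and a candidate token for every masked slot; the loop commits only the single most confident slot per pass:

```python
import random

MASK = -1  # sentinel for a masked slot (a real run uses tokenizer.mask_token_id)

def toy_confidences(ids):
    """Stand-in for a model forward pass: (confidence, token) per masked position."""
    random.seed(0)  # fixed seed so the toy run is reproducible
    return {i: (random.random(), 100 + i) for i, t in enumerate(ids) if t == MASK}

def unmask_greedy(ids):
    ids = list(ids)
    while MASK in ids:
        preds = toy_confidences(ids)
        pos = max(preds, key=lambda i: preds[i][0])  # most confident slot
        ids[pos] = preds[pos][1]                     # commit that token only
    return ids

print(unmask_greedy([1, 2, MASK, MASK, MASK]))  # → [1, 2, 102, 103, 104]
```

Each pass re-queries the "model" on the partially filled sequence, so earlier commitments can inform later ones, which is what distinguishes this from filling all masks in a single pass.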
## Training Details
- Base model: answerdotai/ModernBERT-large
- Training data: Ayushnangia/dolci-diffusion-sft-0.9-passrate (117k high-quality examples)
- Training steps: 500
- Hardware: H100 80GB
- Variable masking: 15-99% of assistant tokens masked per sample
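The variable masking ratio can be sketched as a collation step. The function below (a minimal sketch with illustrative names, not the actual training code) masks a uniformly sampled 15-99% of the assistant-token positions in one sample; the `MASK_ID` value is shown only as a placeholder:

```python
import random

MASK_ID = 50284  # placeholder; a real run uses tokenizer.mask_token_id

def mask_assistant_tokens(ids, assistant_positions, seed=None):
    """Mask a uniformly sampled 15-99% of the assistant's token positions."""
    rng = random.Random(seed)
    ratio = rng.uniform(0.15, 0.99)
    n_mask = max(1, int(len(assistant_positions) * ratio))
    chosen = set(rng.sample(assistant_positions, n_mask))
    return [MASK_ID if i in chosen else t for i, t in enumerate(ids)]

# Positions 0-3 are the (untouched) prompt; 4-9 are assistant tokens.
masked = mask_assistant_tokens(list(range(10)), assistant_positions=[4, 5, 6, 7, 8, 9], seed=0)
```

Sampling the ratio per example exposes the model to both nearly complete and nearly empty responses, matching the range of states it sees at inference time.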
## Usage
```python
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

model_id = "Ayushnangia/mb-diff-500step"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForMaskedLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device)
model.eval()

# Build the prompt with a fully masked response
query = "What is 2+2?"
messages = [{"role": "user", "content": query}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

# Append masked slots for the response
num_tokens = 64
prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
mask_id = tokenizer.mask_token_id
im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
ids = [tokenizer.cls_token_id] + prompt_ids + [mask_id] * num_tokens + [im_end_id]

# Iterative unmasking: commit one token per forward pass
for step in range(num_tokens):
    mask_positions = [i for i, tok in enumerate(ids) if tok == mask_id]
    if not mask_positions:
        break

    with torch.no_grad():
        input_tensor = torch.tensor([ids], device=device)
        logits = model(input_ids=input_tensor).logits[0]
    probs = torch.softmax(logits.float(), dim=-1)

    # Zero out everything except the still-masked positions
    mask_probs = torch.zeros_like(probs)
    for pos in mask_positions:
        mask_probs[pos] = probs[pos]

    # Commit the single most confident prediction
    max_probs, max_tokens = mask_probs.max(dim=-1)
    best_pos = max_probs.argmax().item()
    ids[best_pos] = max_tokens[best_pos].item()

# Decode the response (skip the leading [CLS] token and the prompt)
response_start = len(prompt_ids) + 1
response_ids = [t for t in ids[response_start:] if t not in (mask_id, im_end_id)]
response = tokenizer.decode(response_ids, skip_special_tokens=True)
print(response)
```
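The single-token loop above costs one forward pass per generated token. A common trade-off (an assumption on my part, not something this card's recipe describes) is to commit the k most confident predictions per pass. A model-free sketch, with a toy predictor standing in for the forward pass:

```python
MASK = -1  # sentinel for a masked slot

def unmask_topk(ids, predict, k=4):
    """Commit the k most confident predictions per 'forward pass'."""
    ids = list(ids)
    while MASK in ids:
        preds = predict(ids)  # {pos: (confidence, token)} for masked slots
        best = sorted(preds, key=lambda i: preds[i][0], reverse=True)[:k]
        for pos in best:
            ids[pos] = preds[pos][1]
    return ids

# Toy predictor: confidence equals the position index, token = 100 + position.
predict = lambda ids: {i: (i, 100 + i) for i, t in enumerate(ids) if t == MASK}
print(unmask_topk([1, MASK, MASK, MASK, MASK, MASK], predict, k=2))
```

Larger k means fewer forward passes but less opportunity for each committed token to condition the next prediction, so quality typically degrades as k grows.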
## Inference Script

For easier inference, use the sampling script from the training repo:

```bash
git clone https://github.com/agokrani/diffu-convert
cd diffu-convert
pip install -e .
python scripts/mb_dllm_sample.py --model Ayushnangia/mb-diff-500step --query "What is 2+2?"
```
## Limitations
- Early checkpoint (500 steps); training has not fully converged
- Works best for short responses (64-256 tokens)
- Math and reasoning tasks may be less accurate than with comparable autoregressive models
## Citation

```bibtex
@misc{mb-diffusion,
  author    = {Ayush Nangia},
  title     = {ModernBERT Diffusion Language Model},
  year      = {2025},
  publisher = {HuggingFace},
  url       = {https://huggingface.co/Ayushnangia/mb-diff-500step}
}
```