---
license: apache-2.0
language:
- en
library_name: transformers
tags:
- modernbert
- diffusion
- masked-language-model
- text-generation
base_model: answerdotai/ModernBERT-large
datasets:
- Ayushnangia/dolci-diffusion-sft-0.9-passrate
pipeline_tag: fill-mask
---
# mb-diff-500step
A ModernBERT-large model fine-tuned as a **diffusion language model** (LLaDA-style) for instruction following.
## Model Description
This model uses iterative unmasking for text generation:
1. Start with the user prompt followed by a fixed number of fully masked response slots
2. The model predicts every masked token simultaneously in one forward pass
3. Commit only the most confident prediction, then repeat step 2 until no masks remain
Unlike autoregressive models, this allows parallel token prediction and flexible generation order.
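The loop above can be illustrated without loading the model. The following is a minimal sketch, not the repository's sampler: `toy_predict`, `TARGET`, and `CONFIDENCE` are hypothetical stand-ins for a real forward pass, and the `per_step` parameter is an assumption added to show how several tokens could be committed per round.

```python
MASK = "<mask>"

# Hypothetical data: a fixed "model output" and made-up confidences per position.
TARGET = ["the", "answer", "is", "four"]
CONFIDENCE = [0.6, 0.9, 0.5, 0.8]

def toy_predict(tokens):
    """Stand-in for the model: propose (token, confidence) for every masked slot."""
    return {i: (TARGET[i], CONFIDENCE[i]) for i, t in enumerate(tokens) if t == MASK}

def unmask(tokens, predict, per_step=1):
    """Commit the `per_step` most confident proposals per round until no masks remain."""
    tokens = list(tokens)
    order, rounds = [], 0
    while MASK in tokens:
        proposals = predict(tokens)
        ranked = sorted(proposals, key=lambda i: proposals[i][1], reverse=True)
        for i in ranked[:per_step]:
            tokens[i] = proposals[i][0]
            order.append(i)
        rounds += 1
    return tokens, order, rounds

filled, order, rounds = unmask([MASK] * 4, toy_predict)
print(filled, order, rounds)
```

With these made-up confidences the slots are filled in the order 1, 3, 0, 2 rather than left to right. Calling `unmask([MASK] * 4, toy_predict, per_step=2)` fills the same slots in two rounds instead of four, which would mean fewer forward passes with a real model, possibly at some cost in output quality.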
## Training Details
- **Base model**: [answerdotai/ModernBERT-large](https://huggingface.co/answerdotai/ModernBERT-large)
- **Training data**: [Ayushnangia/dolci-diffusion-sft-0.9-passrate](https://huggingface.co/datasets/Ayushnangia/dolci-diffusion-sft-0.9-passrate) (117k high-quality examples)
- **Training steps**: 500
- **Hardware**: H100 80GB
- **Variable masking**: 15-99% of assistant tokens masked per sample
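As a rough illustration of variable masking, the sketch below draws a mask ratio for each sample and masks that fraction of the assistant tokens only. It is not the training code: the uniform draw over the 15-99% range, the helper name `mask_assistant_tokens`, and the example token ids are all assumptions. The `-100` label is the ignore index conventionally used by PyTorch cross-entropy, so that only masked positions contribute to the loss.

```python
import random

def mask_assistant_tokens(ids, assistant_start, mask_id, rng, lo=0.15, hi=0.99):
    """Mask a random fraction (drawn from [lo, hi]) of the tokens from `assistant_start` on.

    Assumption: the ratio is sampled uniformly; the model card only states the range.
    """
    ratio = rng.uniform(lo, hi)
    candidates = list(range(assistant_start, len(ids)))
    # Mask at least one token when any assistant token exists.
    n_mask = min(len(candidates), max(1, round(ratio * len(candidates))))
    chosen = set(rng.sample(candidates, n_mask))
    masked = [mask_id if i in chosen else t for i, t in enumerate(ids)]
    # Labels keep the original token only where it was masked; -100 is ignored by the loss.
    labels = [t if i in chosen else -100 for i, t in enumerate(ids)]
    return masked, labels, ratio

# Hypothetical sample: 8 prompt tokens followed by 12 assistant tokens.
example_ids = list(range(100, 120))
example_mask_id = 4  # placeholder value, not the real tokenizer's mask id
masked, labels, ratio = mask_assistant_tokens(
    example_ids, assistant_start=8, mask_id=example_mask_id, rng=random.Random(0)
)
print(f"ratio={ratio:.2f}", masked)
```

The prompt tokens are never masked, so the model always conditions on the full user turn while learning to recover assistant tokens at many different corruption levels.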
## Usage
```python
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
model_id = "Ayushnangia/mb-diff-500step"
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForMaskedLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device)
# Build prompt with masked response
query = "What is 2+2?"
messages = [{"role": "user", "content": query}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# Add masked tokens for response
num_tokens = 64
prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
mask_id = tokenizer.mask_token_id
im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
ids = [tokenizer.cls_token_id] + prompt_ids + [mask_id] * num_tokens + [im_end_id]
# Iterative unmasking
for step in range(num_tokens):
    with torch.no_grad():
        input_tensor = torch.tensor([ids], device=device)
        logits = model(input_ids=input_tensor).logits[0]
        probs = torch.softmax(logits, dim=-1)
    # Find mask positions
    mask_positions = [i for i, tok in enumerate(ids) if tok == mask_id]
    if not mask_positions:
        break
    # Get confidence for each mask position
    mask_probs = torch.zeros_like(probs)
    for pos in mask_positions:
        mask_probs[pos] = probs[pos]
    # Fill most confident prediction
    max_probs, max_tokens = mask_probs.max(dim=-1)
    best_pos = max_probs.argmax().item()
    ids[best_pos] = max_tokens[best_pos].item()
# Decode response
response_start = len(prompt_ids) + 1
response_ids = [t for t in ids[response_start:] if t not in (mask_id, im_end_id)]
response = tokenizer.decode(response_ids, skip_special_tokens=True)
print(response)
```
## Inference Script
For easier inference, use the sampling script from the training repo:
```bash
git clone https://github.com/agokrani/diffu-convert
cd diffu-convert
pip install -e .
python scripts/mb_dllm_sample.py --model Ayushnangia/mb-diff-500step --query "What is 2+2?"
```
## Limitations
- Early checkpoint (500 steps) - not fully converged
- Best for short responses (64-256 tokens)
- Math/reasoning tasks may have lower accuracy than autoregressive models
## Citation
```bibtex
@misc{mb-diffusion,
author = {Ayush Nangia},
title = {ModernBERT Diffusion Language Model},
year = {2025},
publisher = {HuggingFace},
url = {https://huggingface.co/Ayushnangia/mb-diff-500step}
}
```