---
license: apache-2.0
language:
- en
library_name: transformers
tags:
- modernbert
- diffusion
- masked-language-model
- text-generation
base_model: answerdotai/ModernBERT-large
datasets:
- Ayushnangia/dolci-diffusion-sft-0.9-passrate
pipeline_tag: fill-mask
---

# mb-diff-500step

A ModernBERT-large model fine-tuned as a **diffusion language model** (LLaDA-style) for instruction-following.

## Model Description

This model uses iterative unmasking for text generation:
1. Start with a user prompt + fully masked response slots
2. Model predicts all masked tokens simultaneously
3. Keep the most confident prediction, repeat until done

Unlike autoregressive models, this allows parallel token prediction and flexible generation order.
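As a toy illustration of the confidence-based unmasking loop (all token IDs and probabilities below are invented for the example, not real model output; the Usage section shows the real loop with model logits):

```python
# Toy walk-through of steps 1-3: start with masked slots, then repeatedly
# commit the single most confident prediction until no masks remain.
MASK = -1
ids = [101, 2054, MASK, MASK, MASK]  # prompt tokens + fully masked response slots

# Stand-in for the model's per-position best (token, probability) prediction.
# In the real loop these come from softmaxed MLM logits at each step.
fake_predictions = {
    2: (2023, 0.41),
    3: (2003, 0.87),  # most confident, so it is unmasked first
    4: (1012, 0.35),
}

while any(t == MASK for t in ids):
    mask_positions = [i for i, t in enumerate(ids) if t == MASK]
    # Pick the masked position the "model" is most confident about
    best = max(mask_positions, key=lambda i: fake_predictions[i][1])
    ids[best] = fake_predictions[best][0]

print(ids)  # all slots filled, most confident first
```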

## Training Details

- **Base model**: [answerdotai/ModernBERT-large](https://huggingface.co/answerdotai/ModernBERT-large)
- **Training data**: [Ayushnangia/dolci-diffusion-sft-0.9-passrate](https://huggingface.co/datasets/Ayushnangia/dolci-diffusion-sft-0.9-passrate) (117k high-quality examples)
- **Training steps**: 500
- **Hardware**: H100 80GB
- **Variable masking**: 15-99% of assistant tokens masked per sample
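The variable-masking scheme can be sketched roughly as follows. This is an illustrative reconstruction, not the actual training code; `MASK_ID`, the per-sample uniform ratio, and the `-100` loss-ignore convention are assumptions:

```python
import random

MASK_ID = 50284  # placeholder; the real ID comes from tokenizer.mask_token_id

def mask_assistant_tokens(assistant_ids, rng=random):
    """Mask a randomly chosen 15-99% of the assistant tokens in one sample."""
    ratio = rng.uniform(0.15, 0.99)
    n_mask = max(1, int(len(assistant_ids) * ratio))
    positions = rng.sample(range(len(assistant_ids)), n_mask)
    masked = list(assistant_ids)
    labels = [-100] * len(assistant_ids)  # -100 = ignored by the MLM loss
    for p in positions:
        labels[p] = masked[p]  # the loss is computed only on masked positions
        masked[p] = MASK_ID
    return masked, labels
```

Because the ratio varies per sample, the model sees everything from nearly complete responses (15% masked, late denoising steps) to almost fully masked ones (99%, early steps).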

## Usage

```python
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

model_id = "Ayushnangia/mb-diff-500step"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForMaskedLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device)

# Build prompt with masked response
query = "What is 2+2?"
messages = [{"role": "user", "content": query}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

# Add masked tokens for response
num_tokens = 64
prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
mask_id = tokenizer.mask_token_id
im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")

ids = [tokenizer.cls_token_id] + prompt_ids + [mask_id] * num_tokens + [im_end_id]

# Iterative unmasking
for step in range(num_tokens):
    with torch.no_grad():
        input_tensor = torch.tensor([ids], device=device)
        logits = model(input_ids=input_tensor).logits[0]

    probs = torch.softmax(logits, dim=-1)

    # Find mask positions
    mask_positions = [i for i, tok in enumerate(ids) if tok == mask_id]
    if not mask_positions:
        break

    # Get confidence for each mask position (non-mask positions stay at 0)
    mask_probs = torch.zeros_like(probs)
    mask_probs[mask_positions] = probs[mask_positions]

    # Fill most confident prediction
    max_probs, max_tokens = mask_probs.max(dim=-1)
    best_pos = max_probs.argmax().item()
    ids[best_pos] = max_tokens[best_pos].item()

# Decode response
response_start = len(prompt_ids) + 1
response_ids = [t for t in ids[response_start:] if t not in (mask_id, im_end_id)]
response = tokenizer.decode(response_ids, skip_special_tokens=True)
print(response)
```

## Inference Script

For easier inference, use the sampling script from the training repo:

```bash
git clone https://github.com/agokrani/diffu-convert
cd diffu-convert
pip install -e .

python scripts/mb_dllm_sample.py --model Ayushnangia/mb-diff-500step --query "What is 2+2?"
```

## Limitations

- Early checkpoint (500 steps) - not fully converged
- Best for short responses (64-256 tokens)
- Math/reasoning tasks may have lower accuracy than autoregressive models

## Citation

```bibtex
@misc{mb-diffusion,
  author = {Ayush Nangia},
  title = {ModernBERT Diffusion Language Model},
  year = {2025},
  publisher = {HuggingFace},
  url = {https://huggingface.co/Ayushnangia/mb-diff-500step}
}
```