---
license: apache-2.0
language:
  - en
library_name: transformers
tags:
  - modernbert
  - diffusion
  - masked-language-model
  - text-generation
base_model: answerdotai/ModernBERT-large
datasets:
  - Ayushnangia/dolci-diffusion-sft-0.9-passrate
pipeline_tag: fill-mask
---

# mb-diff-1000step

A ModernBERT-large model fine-tuned as a diffusion language model (LLaDA-style) for instruction following.

## Model Description

This model uses iterative unmasking for text generation:

  1. Start with a user prompt + fully masked response slots
  2. Model predicts all masked tokens simultaneously
  3. Keep the most confident prediction, repeat until done

Unlike autoregressive models, this allows parallel token prediction and flexible generation order.
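
Concretely, each step reduces to picking the most confident masked slot. Below is a minimal sketch of a single step, assuming `ids` is the current token-id list and `logits` is the model output for it with shape `[seq_len, vocab_size]` (the complete loop is shown under Usage):

```python
import torch

def unmask_one_step(ids: list[int], logits: torch.Tensor, mask_id: int) -> list[int]:
    """Fill the one masked position where the model is most confident."""
    mask_positions = [i for i, t in enumerate(ids) if t == mask_id]
    if not mask_positions:
        return ids  # nothing left to unmask
    probs = logits.softmax(dim=-1)
    # Best candidate token and its probability at each masked slot
    best_probs, best_tokens = probs[mask_positions].max(dim=-1)
    # Commit only the prediction the model is most confident about
    winner = best_probs.argmax().item()
    ids[mask_positions[winner]] = best_tokens[winner].item()
    return ids
```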

## Training Details

## Usage

```python
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

model_id = "Ayushnangia/mb-diff-1000step"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForMaskedLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device)

# Build prompt with masked response
query = "What is 2+2?"
messages = [{"role": "user", "content": query}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

# Add masked tokens for the response
num_tokens = 64
prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
mask_id = tokenizer.mask_token_id
im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")

ids = [tokenizer.cls_token_id] + prompt_ids + [mask_id] * num_tokens + [im_end_id]

# Iterative unmasking: commit one token per forward pass
for step in range(num_tokens):
    with torch.no_grad():
        input_tensor = torch.tensor([ids], device=device)
        logits = model(input_ids=input_tensor).logits[0]

    probs = torch.softmax(logits, dim=-1)

    # Find remaining mask positions
    mask_positions = [i for i, tok in enumerate(ids) if tok == mask_id]
    if not mask_positions:
        break

    # Keep confidences only at masked positions (zero elsewhere)
    mask_probs = torch.zeros_like(probs)
    for pos in mask_positions:
        mask_probs[pos] = probs[pos]

    # Fill the single most confident prediction
    max_probs, max_tokens = mask_probs.max(dim=-1)
    best_pos = max_probs.argmax().item()
    ids[best_pos] = max_tokens[best_pos].item()

# Decode the response (skip the leading [CLS] token and the prompt)
response_start = len(prompt_ids) + 1
response_ids = [t for t in ids[response_start:] if t not in (mask_id, im_end_id)]
response = tokenizer.decode(response_ids, skip_special_tokens=True)
print(response)
```
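
The loop above commits exactly one token per forward pass, so a 64-token response costs 64 passes. Since the model scores every masked slot at once, you can trade some quality for speed by committing several positions per step. A sketch of that variant, reusing `probs` and `mask_positions` from inside the loop (an illustration only, with a hypothetical `k`, not the sampler used in training):

```python
# Commit the k most confident masked positions in a single step.
# k is a hypothetical speed/quality knob, not a trained setting.
k = 4
best_probs, best_tokens = probs[mask_positions].max(dim=-1)
for i in best_probs.topk(min(k, len(mask_positions))).indices.tolist():
    ids[mask_positions[i]] = best_tokens[i].item()
```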

## Inference Script

For easier inference, use the sampling script from the training repo:

```bash
git clone https://github.com/agokrani/diffu-convert
cd diffu-convert
pip install -e .

python scripts/mb_dllm_sample.py --model Ayushnangia/mb-diff-1000step --query "What is 2+2?"
```

## Limitations

- Early checkpoint (1000 training steps), not fully converged
- Best for short responses (64-256 tokens)
- Math/reasoning tasks may have lower accuracy than autoregressive models

## Citation

```bibtex
@misc{mb-diffusion,
  author = {Ayush Nangia},
  title = {ModernBERT Diffusion Language Model},
  year = {2025},
  publisher = {HuggingFace},
  url = {https://huggingface.co/Ayushnangia/mb-diff-1000step}
}
```