---
license: apache-2.0
language:
- en
library_name: transformers
tags:
- modernbert
- diffusion
- masked-language-model
- text-generation
base_model: answerdotai/ModernBERT-large
datasets:
- Ayushnangia/dolci-diffusion-sft-0.9-passrate
pipeline_tag: fill-mask
---

# mb-diff-1000step

A ModernBERT-large model fine-tuned as a **diffusion language model** (LLADA-style) for instruction following.

## Model Description

This model uses iterative unmasking for text generation:

1. Start with a user prompt plus fully masked response slots.
2. The model predicts all masked tokens simultaneously.
3. Keep the most confident prediction and repeat until no masks remain.

Unlike autoregressive models, this allows parallel token prediction and a flexible generation order.

## Training Details

- **Base model**: [answerdotai/ModernBERT-large](https://huggingface.co/answerdotai/ModernBERT-large)
- **Training data**: [Ayushnangia/dolci-diffusion-sft-0.9-passrate](https://huggingface.co/datasets/Ayushnangia/dolci-diffusion-sft-0.9-passrate) (117k high-quality examples)
- **Training steps**: 1000
- **Hardware**: H100 80GB
- **Variable masking**: 15-99% of assistant tokens masked per sample

## Usage

```python
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

model_id = "Ayushnangia/mb-diff-1000step"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForMaskedLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device)

# Build prompt with masked response
query = "What is 2+2?"
messages = [{"role": "user", "content": query}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

# Add masked tokens for response
num_tokens = 64
prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
mask_id = tokenizer.mask_token_id
im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
ids = [tokenizer.cls_token_id] + prompt_ids + [mask_id] * num_tokens + [im_end_id]

# Iterative unmasking: fill one token per step, most confident first
for step in range(num_tokens):
    with torch.no_grad():
        input_tensor = torch.tensor([ids], device=device)
        logits = model(input_ids=input_tensor).logits[0]
        probs = torch.softmax(logits, dim=-1)

    # Find remaining mask positions
    mask_positions = [i for i, tok in enumerate(ids) if tok == mask_id]
    if not mask_positions:
        break

    # Get confidence for each mask position (zero elsewhere)
    mask_probs = torch.zeros_like(probs)
    for pos in mask_positions:
        mask_probs[pos] = probs[pos]

    # Fill the single most confident prediction
    max_probs, max_tokens = mask_probs.max(dim=-1)
    best_pos = max_probs.argmax().item()
    ids[best_pos] = max_tokens[best_pos].item()

# Decode the response region (skip the leading [CLS] and the prompt)
response_start = len(prompt_ids) + 1
response_ids = [t for t in ids[response_start:] if t not in (mask_id, im_end_id)]
response = tokenizer.decode(response_ids, skip_special_tokens=True)
print(response)
```

## Inference Script

For easier inference, use the sampling script from the training repo:

```bash
git clone https://github.com/agokrani/diffu-convert
cd diffu-convert
pip install -e .
python scripts/mb_dllm_sample.py --model Ayushnangia/mb-diff-1000step --query "What is 2+2?"
```

## Limitations

- Early checkpoint (1000 steps); not fully converged
- Best for short responses (64-256 tokens)
- Math/reasoning tasks may have lower accuracy than autoregressive models

## Citation

```bibtex
@misc{mb-diffusion,
  author    = {Ayush Nangia},
  title     = {ModernBERT Diffusion Language Model},
  year      = {2025},
  publisher = {HuggingFace},
  url       = {https://huggingface.co/Ayushnangia/mb-diff-1000step}
}
```
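
## Variable Masking Sketch

Training uses a variable masking ratio: 15-99% of assistant tokens are masked per sample. Below is a minimal sketch of how such per-sample masking could look; the function name `mask_assistant_tokens` and the uniform ratio sampling are illustrative assumptions, not the repo's actual implementation.

```python
import random

def mask_assistant_tokens(assistant_ids, mask_id, low=0.15, high=0.99, rng=None):
    """Mask a random fraction of assistant tokens (illustrative sketch).

    Returns (masked_ids, labels): labels keep the original token at masked
    positions and -100 elsewhere, so the loss is computed only on masks.
    """
    rng = rng or random.Random()
    ratio = rng.uniform(low, high)       # per-sample masking ratio in [low, high]
    masked_ids, labels = [], []
    for tok in assistant_ids:
        if rng.random() < ratio:
            masked_ids.append(mask_id)   # hidden: the model must predict it
            labels.append(tok)
        else:
            masked_ids.append(tok)       # left visible as context
            labels.append(-100)          # ignored by cross-entropy
    return masked_ids, labels
```

Prompt tokens are left unmasked, so the model always conditions on the full user turn while reconstructing the response.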