# ModernBERT Discrete Diffusion FT (W.I.P.)
## Model Summary
This model is an implementation of Order-Agnostic Autoregressive Diffusion (OA-ARD) adapted for the ModernBERT architecture. Unlike standard autoregressive (AR) models, which generate tokens sequentially from left to right, it treats text generation as an iterative denoising process in a discrete state space. By fine-tuning an encoder-only architecture to reconstruct masked tokens across a probabilistic noise schedule, the model supports parallel decoding and bidirectional context awareness, and achieves structural compliance superior to that of causal LLMs of similar size.
## Intended Use
This model is designed for tasks requiring:
- Structured Content Extraction: Filling complex schemas (e.g., JSON) where the structure is known but the content must be extracted from context.
- Code Infilling and Refactoring: Generating code in the middle of a file while attending to both the preceding function signature and the subsequent return statements.
- Constrained Generation: Generating text that must strictly adhere to a specific format or length.
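Because the encoder attends bidirectionally, infilling amounts to placing the mask block between a prefix and a suffix rather than at the end of the prompt. A minimal sketch (the helper name is illustrative, not part of this repository):

```python
def infill_prompt(prefix, suffix, num_masks, mask_token="[MASK]"):
    """Build an infilling prompt: bidirectional attention lets the model
    condition on both the code before and after the masked gap."""
    return prefix + " ".join([mask_token] * num_masks) + suffix


prompt = infill_prompt(
    prefix="def area(radius):\n    ",
    suffix="\n    return result",
    num_masks=6,
)
print(prompt)
```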
## Model Details
- Architecture: ModernBERT-base (Encoder-only, 8k context window).
- Training Objective: Absorbing State Diffusion (Denoising Cross-Entropy).
- Noise Schedule: Cosine schedule.
- Inference Mechanism: Iterative "Mask-Predict" with Classifier-Free Guidance (CFG).
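The absorbing-state objective corrupts a clean sequence by independently replacing tokens with the mask (absorbing) state, at a rate set by the cosine schedule. A minimal illustration of the forward process, assuming the common `1 - cos²` schedule shape (the exact training code may differ):

```python
import math
import random


def mask_rate(t, T):
    """Cosine noise schedule: fraction of tokens masked at timestep t.
    Rises from 0 at t=0 to 1 at t=T (assumed 1 - cos^2 shape)."""
    return 1.0 - math.cos(math.pi / 2 * t / T) ** 2


def corrupt(tokens, t, T, mask_token="[MASK]"):
    """Absorbing-state forward process: each token independently
    transitions to the mask state with probability mask_rate(t, T)."""
    rate = mask_rate(t, T)
    return [mask_token if random.random() < rate else tok for tok in tokens]


T = 10
tokens = ["The", "cat", "sat", "on", "the", "mat"]
print(mask_rate(0, T))        # no corruption at t=0
print(mask_rate(T, T))        # fully masked at t=T
print(corrupt(tokens, 7, T))  # heavily (but randomly) masked sequence
```

Training then minimizes denoising cross-entropy on the masked positions only, which is what lets a standard `AutoModelForMaskedLM` head serve as the denoiser.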
## Training Data
The model was fine-tuned on a mixed corpus designed to balance reasoning capabilities with structural rigor:
- 70% FineWeb-Edu (Sample 10BT): High-quality educational web text to maintain linguistic coherence and reasoning.
- 30% The Stack (Dedup): Code (Python) and Data (JSON) subsets to enforce strict syntactic structure and dependency handling.
## Usage
This model requires a custom inference engine that manages the diffusion schedule. The engine below is a minimal sketch; see the repository for the full implementation.
```python
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForMaskedLM


class MaskedDiffusionEngine:
    def __init__(self, model_name="philipp-zettl/modernbert-diffusion-ft"):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForMaskedLM.from_pretrained(model_name).to(self.device)
        self.mask_token_id = self.tokenizer.mask_token_id

    def generate(self, prompt, num_new_tokens=20, steps=15, guidance_scale=1.5):
        # Prepare inputs: the prompt followed by a block of [MASK] tokens to fill
        text = prompt + " " + (self.tokenizer.mask_token * num_new_tokens)
        inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
        input_ids = inputs["input_ids"]

        # Iterative denoising loop
        for step in range(steps):
            # Re-identify remaining masks each step (they shrink as tokens commit)
            mask_positions = (input_ids == self.mask_token_id).nonzero(as_tuple=True)[1]
            if mask_positions.numel() == 0:
                break

            with torch.no_grad():
                # Conditional pass
                logits = self.model(input_ids).logits

                # Classifier-Free Guidance (simplified): the unconditional pass
                # masks all non-special tokens, removing the prompt context
                if guidance_scale > 1.0:
                    uncond_input = input_ids.clone()
                    special_mask = (uncond_input == self.tokenizer.cls_token_id) | \
                                   (uncond_input == self.tokenizer.sep_token_id)
                    uncond_input[~special_mask] = self.mask_token_id
                    uncond_logits = self.model(uncond_input).logits
                    logits = uncond_logits + guidance_scale * (logits - uncond_logits)

            # Mask-Predict update (minimal sketch; see the repository for the
            # full sampling logic): commit the highest-confidence predictions
            # and leave the rest masked for the next refinement step
            probs = F.softmax(logits[0, mask_positions], dim=-1)
            confidence, predictions = probs.max(dim=-1)
            num_to_unmask = max(1, mask_positions.numel() // (steps - step))
            top = confidence.topk(num_to_unmask).indices
            input_ids[0, mask_positions[top]] = predictions[top]

        return self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
```
## Performance Benchmarks
We compared this architecture against Qwen2.5-0.5B and distilgpt2 on structured extraction and infilling tasks.
### 1. Latency Scaling
Unlike AR models, whose latency scales linearly with output length (O(N)), discrete diffusion scales with the number of refinement steps (O(K)), which is fixed regardless of output length.
| Output Length | AR Time (s) | Diffusion Time (s) | Speedup |
|---|---|---|---|
| 16 Tokens | 0.43 | 0.20 | 2.1x |
| 64 Tokens | 1.60 | 0.22 | 7.3x |
| 128 Tokens | 3.24 | 0.25 | 12.8x |
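The scaling difference comes down to forward passes: an AR decoder runs one model call per generated token, while mask-predict runs one call per refinement step regardless of output length (two per step when classifier-free guidance adds an unconditional pass, as in the engine above). A back-of-the-envelope sketch using the default K=15 steps:

```python
def ar_forward_passes(num_tokens):
    # One model call per generated token: O(N)
    return num_tokens


def diffusion_forward_passes(num_tokens, steps=15, cfg=True):
    # One call per refinement step, independent of output length: O(K);
    # classifier-free guidance doubles this with an unconditional pass
    return steps * (2 if cfg else 1)


for n in (16, 64, 128):
    print(n, ar_forward_passes(n), diffusion_forward_passes(n))
```

The pass count explains why the measured diffusion latency is nearly flat across output lengths while AR latency grows linearly.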
### 2. Structural Validity (Complex JSON)
On a task requiring the extraction of 20+ fields into a rigid JSON schema:
- Autoregressive Baseline: 0.0% Validity (Frequently missed closing braces or hallucinated keys).
- ModernBERT Diffusion: 100.0% Validity (Structure is strictly enforced via the prompt template).
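One way to score structural validity is to check that the output parses and contains every required key (an assumption for illustration; the exact evaluation harness is not specified here):

```python
import json


def is_valid(output, required_keys):
    """Return True if `output` parses as a JSON object containing
    every key in `required_keys`."""
    try:
        obj = json.loads(output)
    except json.JSONDecodeError:
        return False
    return isinstance(obj, dict) and all(k in obj for k in required_keys)


print(is_valid('{"name": "ACME", "total": 250.0}', ["name", "total"]))  # True
print(is_valid('{"name": "ACME", "total":', ["name", "total"]))         # False
```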
## Limitations
- World Knowledge: As an encoder model trained primarily for understanding, it hallucinates factual knowledge (e.g., capitals of countries) more frequently than causal LLMs of equivalent size.
- Compute Density: While latency is lower for long sequences, total FLOPs may be higher due to bidirectional attention over the full sequence at every step.
- Prompt Dependency: The model relies heavily on the quality of the prompt structure. For extraction, providing a complete JSON template with [MASK] tokens is required for optimal performance.
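Such a template can be built programmatically so that braces, quotes, and keys are fixed and only the values are masked. A minimal sketch (helper and field names are illustrative):

```python
def build_template(fields, mask_token="[MASK]", slots_per_field=1):
    """Build a rigid JSON skeleton where only the values are [MASK] slots;
    the model fills the slots while the structure stays fixed."""
    slot = " ".join([mask_token] * slots_per_field)
    lines = [f'  "{name}": "{slot}"' for name in fields]
    return "{\n" + ",\n".join(lines) + "\n}"


doc = "Invoice #1042 issued to ACME Corp on 2024-03-01 for $250.00."
template = build_template(["invoice_id", "customer", "date", "total"],
                          slots_per_field=2)
prompt = f"Extract the following fields from the text.\nText: {doc}\nJSON: {template}"
print(prompt)
```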
## Citation
If you use this model or codebase, please cite the associated repository and the original ModernBERT paper.