
ModernBERT Discrete Diffusion FT (W.I.P)

Model Summary

This model is an implementation of Order-Agnostic Autoregressive Diffusion (OA-ARD), adapted to the ModernBERT architecture. Standard autoregressive (AR) models generate tokens one at a time from left to right. This model instead treats text generation as iterative denoising in a discrete state space. The encoder-only backbone is fine-tuned to reconstruct masked tokens at noise levels drawn from a probabilistic noise schedule. This enables parallel decoding and bidirectional context, and in the benchmarks below it gives better structural compliance than causal LLMs of similar size.
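For reference, the formulas below show two things: the absorbing-state forward process behind this family of models, and the denoising cross-entropy that training minimises. This is the generic form, not taken from the repository. The keep-probability α_t and its cosine shape are assumptions, and some formulations also weight each term by a t-dependent factor.

```latex
% Forward process: each token x_0^i is independently kept or absorbed into [MASK]
q\left(x_t^i \mid x_0^i\right) =
  \begin{cases}
    \alpha_t     & \text{if } x_t^i = x_0^i \\
    1 - \alpha_t & \text{if } x_t^i = \texttt{[MASK]}
  \end{cases}
\qquad \alpha_t = \cos\left(\frac{\pi t}{2}\right), \quad t \in [0, 1]

% Training objective: cross-entropy on the masked positions only
\mathcal{L}(\theta) = -\,\mathbb{E}_{t \sim \mathcal{U}(0, 1)}\,
  \mathbb{E}_{x_t \sim q(\cdot \mid x_0)}
  \left[ \sum_{i \,:\, x_t^i = \texttt{[MASK]}} \log p_\theta\left(x_0^i \mid x_t\right) \right]
```

At t = 0 nothing is masked, and at t = 1 everything is. Generation runs this in reverse: it starts from a fully masked continuation and unmasks it over K refinement steps.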

Intended Use

This model is designed for tasks requiring:

  • Structured Content Extraction: Filling complex schemas (e.g., JSON) where the structure is known but the content must be extracted from context.
  • Code Infilling and Refactoring: Generating code in the middle of a file while attending to both the preceding function signature and the subsequent return statements.
  • Constrained Generation: Generating text that must strictly adhere to a specific format or length.
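Infilling and length-constrained generation both depend on how the input is laid out. The fixed parts of the text surround a run of mask tokens whose length is chosen up front. The helper below is a hypothetical sketch of that layout. The name build_infill_text and the literal "[MASK]" string are assumptions; with the real tokenizer, use tokenizer.mask_token.

```python
MASK = "[MASK]"  # assumption: literal mask string; prefer tokenizer.mask_token


def build_infill_text(prefix: str, suffix: str, num_masks: int,
                      mask_token: str = MASK) -> str:
    """Hypothetical helper: prefix, then num_masks mask tokens, then suffix.

    The model attends to both sides of the gap, so the filled span can depend
    on the signature before it and the return statement after it. The gap
    length is fixed here, which is also how a length constraint is expressed."""
    if num_masks < 1:
        raise ValueError("num_masks must be at least 1")
    return prefix + mask_token * num_masks + suffix


# Example: ask the model to fill the body of a small function
text = build_infill_text(
    "def area(w, h):\n    result = ",
    "\n    return result\n",
    num_masks=4,
)
```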

Model Details

  • Architecture: ModernBERT-base (Encoder-only, 8k context window).
  • Training Objective: Absorbing State Diffusion (Denoising Cross-Entropy).
  • Noise Schedule: Cosine schedule.
  • Inference Mechanism: Iterative "Mask-Predict" with Classifier-Free Guidance (CFG).
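The training objective and noise schedule listed above can be sketched in plain Python. This illustrates the idea under assumptions and is not the repository's training code. The cosine form of mask_ratio, the placeholder MASK_ID, and the dict-based probabilities are all hypothetical. A real training loop would work on batched tensors and the model's logits.

```python
import math
import random

MASK_ID = -1  # hypothetical [MASK] id; use tokenizer.mask_token_id with the real model


def mask_ratio(t: float) -> float:
    """Assumed cosine schedule: fraction of tokens masked at time t in [0, 1]."""
    return 1.0 - math.cos(math.pi * t / 2)


def corrupt(tokens: list[int], t: float, rng: random.Random) -> list[int]:
    """Absorbing-state forward process: each token independently becomes
    [MASK] with probability mask_ratio(t)."""
    p = mask_ratio(t)
    return [MASK_ID if rng.random() < p else tok for tok in tokens]


def denoising_ce(clean: list[int], noisy: list[int],
                 probs: list[dict[int, float]]) -> float:
    """Mean cross-entropy over masked positions only.

    probs[i] maps token id -> predicted probability at position i
    (with the real model: softmax of the logits at that position)."""
    losses = [-math.log(probs[i][clean[i]])
              for i in range(len(clean)) if noisy[i] == MASK_ID]
    return sum(losses) / len(losses) if losses else 0.0


# One training view of a toy sequence: sample t uniformly, then corrupt
rng = random.Random(0)
clean = [4, 7, 8, 9, 5]
noisy = corrupt(clean, t=rng.random(), rng=rng)
```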

Training Data

The model was fine-tuned on a mixed corpus designed to balance reasoning capabilities with structural rigor:

  • 70% FineWeb-Edu (Sample 10BT): High-quality educational web text to maintain linguistic coherence and reasoning.
  • 30% The Stack (Dedup): Code (Python) and Data (JSON) subsets to enforce strict syntactic structure and dependency handling.
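A simple way to realise a 70/30 mixture is to draw the source of each training example at random with those weights, as sketched below. The card does not say how the mix was implemented, and the source labels here are hypothetical. With Hugging Face datasets, the same effect is usually obtained with interleave_datasets(..., probabilities=[0.7, 0.3]).

```python
import random
from collections import Counter

# Hypothetical source labels; the card specifies only the 70/30 proportions.
MIXTURE = {"fineweb-edu": 0.7, "the-stack": 0.3}


def sample_sources(n: int, seed: int = 0) -> list[str]:
    """Draw n source labels according to the mixture weights."""
    rng = random.Random(seed)
    names = list(MIXTURE)
    weights = [MIXTURE[name] for name in names]
    return rng.choices(names, weights=weights, k=n)


# Over many draws the observed shares approach 70% and 30%
counts = Counter(sample_sources(10_000))
```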

Usage

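This model requires a custom inference engine that manages the diffusion schedule. The simplified engine below covers input preparation and the guided forward passes, but omits the sampling/update step.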

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForMaskedLM

class MaskedDiffusionEngine:
    def __init__(self, model_name="philipp-zettl/modernbert-diffusion-ft"):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForMaskedLM.from_pretrained(model_name).to(self.device)
        self.mask_token_id = self.tokenizer.mask_token_id

    def generate(self, prompt, num_new_tokens=20, steps=15, guidance_scale=1.5):
        # Conditional input: the prompt followed by num_new_tokens [MASK] placeholders
        text = prompt + " " + (self.tokenizer.mask_token * num_new_tokens)
        inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
        input_ids = inputs["input_ids"]
        
        # Identify masks
        mask_indices = (input_ids == self.mask_token_id).nonzero(as_tuple=True)[1]
        
        # Iterative Denoising Loop
        for step in range(steps):
            with torch.no_grad():
                # Conditional Pass
                logits = self.model(input_ids).logits
                
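                # Classifier-Free Guidance (simplified implementation)
                if guidance_scale > 1.0:
                    uncond_input = input_ids.clone()
                    # Unconditional input: replace every token except [CLS]/[SEP]
                    # (prompt and placeholders alike) with [MASK]
                    special_mask = (uncond_input == self.tokenizer.cls_token_id) | (
                        uncond_input == self.tokenizer.sep_token_id
                    )
                    uncond_input[~special_mask] = self.mask_token_id
                    uncond_logits = self.model(uncond_input).logits
                    # Extrapolate from the unconditional logits toward the conditional ones
                    logits = uncond_logits + guidance_scale * (logits - uncond_logits)

            # Sampling and update logic (omitted for brevity, see repo).
            # Without this step input_ids is never updated, so the placeholders
            # are never filled.
            # ...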
            
        return self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
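The engine above leaves out the sampling/update step. One common way to fill it, in the spirit of Mask-Predict, is greedy confidence-based unmasking on a linear schedule. At each step, the most confident predictions among the still-masked positions are committed, so every mask is filled by the last step. This sketch is an assumption, not the repository's logic. It uses plain lists and dicts so the selection rule is easy to follow. In the engine, probs[i] would come from F.softmax(logits[0, i], dim=-1). The [MASK] id itself should be excluded from the candidates, for example by setting its logit to -inf.

```python
import math


def num_to_reveal(num_masked: int, step: int, steps: int) -> int:
    """Linear schedule: spread the remaining masks evenly over the
    remaining steps, so the final step reveals everything left."""
    remaining = steps - step
    return math.ceil(num_masked / remaining)


def unmask_step(token_ids: list[int], probs: list[dict[int, float]],
                mask_id: int, step: int, steps: int) -> list[int]:
    """Commit the most confident greedy predictions among masked positions.

    probs[i] maps token id -> probability at position i; step runs from
    0 to steps - 1. Returns a new list and leaves token_ids unchanged."""
    masked = [i for i, tok in enumerate(token_ids) if tok == mask_id]
    if not masked:
        return list(token_ids)
    # Best (token, probability) pair for each masked position
    best = {i: max(probs[i].items(), key=lambda kv: kv[1]) for i in masked}
    k = num_to_reveal(len(masked), step, steps)
    chosen = sorted(masked, key=lambda i: best[i][1], reverse=True)[:k]
    updated = list(token_ids)
    for i in chosen:
        updated[i] = best[i][0]
    return updated
```

In this variant a committed token is never revised. The original Mask-Predict instead re-masks the lowest-confidence tokens at each iteration, which lets early mistakes be corrected.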

Performance Benchmarks

We compared this architecture against Qwen2.5-0.5B and distilgpt2 on structured extraction and infilling tasks.

1. Latency Scaling

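In AR models, latency grows linearly with output length N (O(N)), because every new token needs its own sequential forward pass. Discrete diffusion fills many positions per pass and needs only a fixed number of refinement steps K (O(K)), so its latency stays nearly flat as the output grows.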

| Output Length | AR Time (s) | Diffusion Time (s) | Speedup |
|---|---|---|---|
| 16 Tokens | 0.43 | 0.20 | 2.1x |
| 64 Tokens | 1.60 | 0.22 | 7.3x |
| 128 Tokens | 3.24 | 0.25 | 12.8x |
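A rough way to see the shape of this table, and the compute trade-off noted under Limitations, is to count sequential forward passes and processed token positions. The sketch makes three assumptions. It uses the engine's defaults from Usage (steps=15, guidance_scale=1.5, so two passes per step). It assumes a KV-cached AR baseline. And it treats position counts as a proxy for FLOPs, ignoring the quadratic attention term. The benchmark's actual settings are not stated.

```python
def ar_passes(new_tokens: int) -> int:
    """Autoregressive decoding: one sequential forward pass per generated token."""
    return new_tokens


def diffusion_passes(steps: int, guidance_scale: float) -> int:
    """Mask-predict decoding: one pass per refinement step, plus an
    unconditional pass per step when CFG is active (guidance_scale > 1)."""
    return steps * (2 if guidance_scale > 1.0 else 1)


def ar_positions_cached(prompt_len: int, new_tokens: int) -> int:
    """KV-cached AR (new_tokens >= 1): the first pass encodes the prompt,
    and each later pass encodes only the previously generated token."""
    return prompt_len + new_tokens - 1


def diffusion_positions(prompt_len: int, new_tokens: int,
                        steps: int, guidance_scale: float) -> int:
    """No KV cache in the engine: every pass re-encodes the full sequence
    (prompt plus placeholders)."""
    return diffusion_passes(steps, guidance_scale) * (prompt_len + new_tokens)


for n in (16, 64, 128):
    print(n, ar_passes(n), diffusion_passes(15, 1.5))
```

For example, with a 32-token prompt and 128 new tokens, diffusion needs 30 sequential passes versus 128 for AR. However, it processes 4,800 token positions versus 159.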

2. Structural Validity (Complex JSON)

On a task requiring the extraction of 20+ fields into a rigid JSON schema:

  • Autoregressive Baseline: 0.0% Validity (Frequently missed closing braces or hallucinated keys).
  • ModernBERT Diffusion: 100.0% Validity (the prompt template enforces the structure: braces, keys, and punctuation are given, and only the masked values are generated).
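The card does not define "validity" precisely. The checker below is a minimal sketch. It assumes an output counts as valid when it parses as a JSON object whose top-level keys exactly match the schema. Nested structure and value types are not checked.

```python
import json


def is_valid_extraction(output: str, expected_keys: set[str]) -> bool:
    """True if output parses as a JSON object whose top-level keys are
    exactly expected_keys. Unbalanced braces fail the parse; missing or
    hallucinated keys fail the key comparison."""
    try:
        obj = json.loads(output)
    except json.JSONDecodeError:
        return False
    return isinstance(obj, dict) and set(obj) == expected_keys


def validity_rate(outputs: list[str], expected_keys: set[str]) -> float:
    """Percentage of outputs that pass is_valid_extraction."""
    if not outputs:
        return 0.0
    ok = sum(is_valid_extraction(o, expected_keys) for o in outputs)
    return 100.0 * ok / len(outputs)
```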

Limitations

  • World Knowledge: As an encoder model trained primarily for understanding, it hallucinates facts (e.g., country capitals) more often than causal LLMs of equivalent size.
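  • Compute Density: While latency is lower for long sequences, total FLOPs may be higher. Every refinement step runs bidirectional attention over the full sequence (twice per step when CFG is enabled), and, unlike causal decoding, the engine does not reuse a KV cache across steps.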
  • Prompt Dependency: The model relies heavily on the quality of the prompt structure. For extraction, providing a complete JSON template with [MASK] tokens is required for optimal performance.
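A template of the kind described above can be generated from a list of field names. In this sketch, three things are assumptions: the function name, the fixed masks_per_field budget, and the choice to quote every value as a string. The card does not say how many masks each field should get.

```python
import json

MASK = "[MASK]"  # assumption: literal mask string; prefer tokenizer.mask_token


def build_json_template(fields: list[str], masks_per_field: int = 4,
                        mask_token: str = MASK) -> str:
    """Hypothetical helper: a JSON skeleton with fixed keys and punctuation,
    where each value is a run of mask tokens for the model to fill."""
    placeholder = mask_token * masks_per_field
    # json.dumps quotes and escapes each key correctly
    body = ", ".join(f'{json.dumps(key)}: "{placeholder}"' for key in fields)
    return "{" + body + "}"


template = build_json_template(["name", "age"], masks_per_field=2)
```

The generate() method in Usage only appends masks after the prompt. A template like this would therefore be tokenized directly as prompt plus template, rather than passed through generate() unchanged.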

Citation

If you use this model or codebase, please cite the associated repository and the original ModernBERT paper.

Model size: 0.1B params · Tensor type: F32 · Format: Safetensors