---
title: CDD - Constrained Discrete Diffusion
emoji: 🧬
colorFrom: green
colorTo: blue
sdk: docker
app_port: 7860
pinned: false
---

# Constrained Discrete Diffusion (CDD)

Implementation of **"Constrained Language Generation with Discrete Diffusion Models"** (Cardei et al., 2025; [arXiv:2503.09790v3](https://arxiv.org/abs/2503.09790v3)).

CDD couples discrete diffusion models with differentiable optimization to enforce constraints on generated sequences **without retraining** the model. The paper reports **zero constraint violations** on toxicity mitigation, molecular generation, and instruction-following tasks.

## Key Idea

Discrete diffusion models (MDLM, UDLM) generate sequences by iteratively denoising a sample drawn from a noise distribution. At each reverse step, CDD inserts an **Augmented Lagrangian Method (ALM) projection** that adjusts the denoiser's predicted token distributions so that the decoded sequence satisfies user-defined constraints:

```
z_T → x_θ(z_T) → [CDD Projection] → z_{T-1} → ... → z_0
```

The projection solves:

```
min_y  D_KL(x_θ || y)   s.t.   g(argmax(y)) ∈ C
```

It uses a Gumbel-Softmax relaxation to make `argmax` differentiable and ALM to enforce the constraint.
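To make the projection concrete, here is a minimal NumPy sketch of a single projection for one diffusion step. It is not the repository's `CDDSampler` code. The function `cdd_project`, the per-token `cost` vector, and the toy constraint "mean cost of the decoded tokens ≤ `tau`" are hypothetical stand-ins for a real constraint such as a toxicity surrogate. The inequality is handled with the common choice g = max(0, h), and the hyperparameter defaults are illustrative rather than the paper's values. The sketch parameterizes y by logits, relaxes `argmax(y)` with Gumbel-Softmax, minimizes the augmented Lagrangian by gradient descent, and applies the λ/μ update rules.

```python
import numpy as np


def softmax(u: np.ndarray) -> np.ndarray:
    """Row-wise softmax, shifted by the row max for numerical stability."""
    e = np.exp(u - u.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def cdd_project(x_theta, cost, tau, *, mu=1.0, mu_max=1000.0, outer_iters=50,
                inner_iters=20, lr=1.0, temp=0.5, seed=0):
    """Hypothetical toy projection: min_y KL(x_theta || y) s.t. mean cost of argmax(y) <= tau.

    x_theta: (L, V) denoiser probabilities, one row per sequence position.
    cost:    (V,) per-token cost (an assumed stand-in for a real constraint model).
    """
    rng = np.random.default_rng(seed)
    n_pos = x_theta.shape[0]
    w = np.log(x_theta + 1e-12)  # logits of y, initialised at x_theta
    lam = 0.0                    # Lagrange multiplier, lambda_0 = 0

    def relaxed(logits):
        # Gumbel-Softmax relaxation of argmax(y). log softmax(w) differs from w
        # by a per-row constant, which the softmax cancels, so w can be used directly.
        xi = rng.gumbel(size=logits.shape)
        return softmax((logits + xi) / temp)

    for _ in range(outer_iters):                     # outer ALM iterations
        if cost[w.argmax(axis=-1)].mean() <= tau:    # hard constraint on argmax(y) holds
            break
        for _ in range(inner_iters):                 # inner gradient steps on L(y, lam, mu)
            y, y_rel = softmax(w), relaxed(w)
            g = max(float((y_rel @ cost).mean()) - tau, 0.0)  # violation g(y~) >= 0
            grad = y - x_theta                       # d KL(x_theta || softmax(w)) / dw
            if g > 0.0:
                # dL/dy~ for lam*g + (mu/2)*g^2, then chain rule through the softmax
                a = (lam + mu * g) * cost / n_pos
                grad = grad + y_rel * (a - (y_rel @ a)[:, None]) / temp
            w = w - lr * grad
        g = max(float((relaxed(w) @ cost).mean()) - tau, 0.0)
        lam += mu * g                                # lambda <- lambda + mu * g(y~)
        mu = min(2.0 * mu, mu_max)                   # mu <- min(2 mu, mu_max)
    return softmax(w)


if __name__ == "__main__":
    # Toy example: 2 positions, 3-token vocabulary; token 0 is "toxic" (cost 1).
    x = np.array([[0.7, 0.2, 0.1], [0.7, 0.2, 0.1]])
    cost = np.array([1.0, 0.0, 0.0])
    print("before:", x.argmax(axis=-1))
    print("after: ", cdd_project(x, cost, tau=0.0).argmax(axis=-1))
```

The gradients are written out by hand to keep the sketch dependency-free. A practical implementation would more naturally rely on autograd (for example PyTorch) and a learned, differentiable constraint model. If `argmax(x_θ)` already satisfies the constraint, the sketch returns `x_θ` unchanged, which is the zero-KL solution.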
## Architecture

```
cdd/
├── samplers/
│   └── cdd_sampler.py            # Core ALM projection + sampling loop (§4)
├── constraints/
│   ├── toxicity.py               # Toxicity mitigation constraint (§5.1)
│   ├── molecular.py              # SA + novelty constraints (§5.2)
│   └── instruction.py            # Counting + lexical constraints (§5.3)
├── models/
│   └── load_models.py            # MDLM/UDLM model loading utilities
├── experiments/
│   ├── toxicity_mitigation.py    # Full toxicity experiment
│   ├── molecular_generation.py   # Full molecular experiment
│   ├── instruction_following.py  # Full instruction-following experiment
│   └── train_surrogates.py       # Surrogate model training
└── utils/
    ├── noise_schedule.py         # Diffusion math (schedules, posteriors)
    └── evaluation.py             # Metrics (PPL, entropy, violation rate)
```

## Installation

```bash
pip install -e .

# For GPU models (required for MDLM/UDLM inference):
pip install flash-attn --no-build-isolation

# For molecular evaluation (RDKit; formerly published on PyPI as rdkit-pypi):
pip install rdkit
```

## Quick Start

### 1. Unconstrained Generation with MDLM

```python
import torch

from cdd.models import load_mdlm
from cdd.samplers import CDDSampler

# Load the pretrained MDLM
model, tokenizer, info = load_mdlm(device="cuda")

# Create a sampler whose constraint is always satisfied (unconstrained generation)
sampler = CDDSampler(
    model=model,
    tokenizer=tokenizer,
    constraint_fn=lambda x: torch.tensor(0.0),  # no constraint
    diffusion_type="mdlm",
    num_timesteps=1000,
    seq_length=128,
    device="cuda",
)

result = sampler.sample(batch_size=1)
print(result["text"][0])
```

### 2. Toxicity-Constrained Generation

```python
from cdd.constraints import create_toxicity_constraint
from cdd.samplers import CDDSampler, TOXICITY_ALM_CONFIG

# Reuses `model` and `tokenizer` from example 1.

# Set up the toxicity constraint (§5.1)
constraint = create_toxicity_constraint(threshold=0.5)

sampler = CDDSampler(
    model=model,
    tokenizer=tokenizer,
    constraint_fn=constraint,
    alm_config=TOXICITY_ALM_CONFIG,  # λ=0, μ=1, K=1000, M=10, η=0.2
    diffusion_type="mdlm",
    device="cuda",
)

# Conditional generation from a toxic prompt
prefix = tokenizer.encode("The politician was accused of", return_tensors="pt").to("cuda")
result = sampler.sample(batch_size=1, prefix_ids=prefix)
print(result["text"][0])  # non-toxic completion
```

### 3. Molecular Generation with SA Constraint

```python
from cdd.models import load_udlm
from cdd.samplers import CDDSampler, MOLECULAR_SA_ALM_CONFIG

model, tokenizer, info = load_udlm("kuleshov-group/udlm-qm9", device="cuda")

# `sa_constraint` must be built beforehand from cdd/constraints/molecular.py
# (synthetic accessibility score SA ≤ 3.5, §5.2).
sampler = CDDSampler(
    model=model,
    tokenizer=tokenizer,
    constraint_fn=sa_constraint,         # SA ≤ 3.5
    alm_config=MOLECULAR_SA_ALM_CONFIG,  # λ=0, μ=1, μ_max=1000, K=1000, M=100, η=1.0
    diffusion_type="udlm",
    num_timesteps=1000,
    seq_length=32,
    device="cuda",
)

result = sampler.sample(batch_size=32)
```

## Reproducing Paper Experiments

### Toxicity Mitigation (§5.1, Table 3)

```bash
python -m cdd.experiments.toxicity_mitigation \
    --threshold 0.5 \
    --num_samples 1000 \
    --device cuda
```

Expected results (CDD, τ = 0.50): PPL ≈ 59.44, violation rate 0.0%.

### Molecular Generation (§5.2, Table 4)

```bash
python -m cdd.experiments.molecular_generation \
    --sa_threshold 3.5 \
    --num_samples 1000 \
    --device cuda
```

### Instruction Following (§5.3, Table 5)

```bash
# Counting task
python -m cdd.experiments.instruction_following \
    --task counting --num_samples 100

# Lexical task
python -m cdd.experiments.instruction_following \
    --task lexical --num_samples 100
```

### Train Surrogate Models

```bash
# Toxicity surrogate (GPT-Neo 1.3B on Jigsaw)
python -m cdd.experiments.train_surrogates --task toxicity

# SA surrogate (GPT-2 on QM9)
python -m cdd.experiments.train_surrogates --task sa
```

## ALM Hyperparameters (Appendix B & C)

| Task | λ₀ | μ₀ | μ_max | K (outer) | M (inner) | η |
|------|-----|-----|-------|-----------|-----------|------|
| Toxicity (NLP) | 0.0 | 1.0 | n/a | 1000 | 10 | 0.20 |
| Molecular (QM9) | 0.0 | 1.0 | 1000 | 1000 | 100 | 1.0 |

## Base Models

| Model | Type | Params | Vocab | Length | Trained On |
|-------|------|--------|-------|--------|------------|
| [kuleshov-group/mdlm-owt](https://huggingface.co/kuleshov-group/mdlm-owt) | MDLM | 110M | 50,258 | 1024 | OpenWebText |
| [kuleshov-group/udlm-qm9](https://huggingface.co/kuleshov-group/udlm-qm9) | UDLM | 92M | 40 | 32 | QM9 |
| [kuleshov-group/udlm-lm1b](https://huggingface.co/kuleshov-group/udlm-lm1b) | UDLM | 139M | 30,522 | 128 | LM1B |

## Method Details

### Gumbel-Softmax Relaxation (Eq. 6)

Makes `argmax` differentiable by replacing it with a temperature-controlled softmax over Gumbel-perturbed log-probabilities:

```
φ̃(x)(v) = exp((log x(v) + ξ_v) / T) / Σ_{v'} exp((log x(v') + ξ_{v'}) / T)
```

Here ξ_v ~ Gumbel(0, 1) are i.i.d. noise samples and T > 0 is the temperature; as T → 0, φ̃(x) approaches a one-hot vector at argmax_v (log x(v) + ξ_v).

### Augmented Lagrangian (Eq. 7-8)

Objective optimized at each diffusion step, with ỹ = φ̃(y):

```
L(y, λ, μ) = D_KL(x_θ || y) + λ·g(ỹ) + (μ/2)·g(ỹ)²
```

Update rules, applied after each inner minimization of L over y:

```
λ ← λ + μ·g(ỹ)
μ ← min(2μ, μ_max)
```

### MDLM vs UDLM for CDD

| Property | MDLM | UDLM |
|----------|------|------|
| Noise type | Absorbing ([MASK]) | Uniform |
| Tokens re-sampled each step | Only masked | All |
| CDD paper usage | Text tasks | Molecular |
| Time conditioning | No | Yes |

## Citation

```bibtex
@article{cardei2025constrained,
  title={Constrained Language Generation with Discrete Diffusion Models},
  author={Cardei, Michael and Christopher, Jacob K and Hartvigsen, Thomas and Bartoldson, Brian R. and Kailkhura, Bhavya and Fioretto, Ferdinando},
  journal={arXiv preprint arXiv:2503.09790},
  year={2025}
}
```