---
title: CDD - Constrained Discrete Diffusion
emoji: 🧬
colorFrom: green
colorTo: blue
sdk: docker
app_port: 7860
pinned: false
---

# Constrained Discrete Diffusion (CDD)

Implementation of **"Constrained Language Generation with Discrete Diffusion Models"** (Cardei et al., 2025; [arXiv:2503.09790v3](https://arxiv.org/abs/2503.09790v3)).

CDD integrates discrete diffusion models with differentiable optimization to enforce constraints on text generation **without retraining**. The paper reports **zero constraint violations** across toxicity mitigation, molecular generation, and instruction following.

## Key Idea

Discrete diffusion models (MDLM, UDLM) generate sequences by iteratively denoising from a noise distribution. At each reverse step, CDD inserts an **Augmented Lagrangian Method (ALM) projection** that modifies the denoiser's predicted token distributions to satisfy user-defined constraints:

```
z_T → x_θ(z_T) → [CDD Projection] → z_{T-1} → ... → z_0
```

The projection solves:

```
min_y D_KL(x_θ || y)   s.t.   g(argmax(y)) ∈ C
```

where `x_θ` is the denoiser's prediction, `y` is the projected distribution, `g` is the constraint function, and `C` is the set of allowed constraint values. The problem is solved using Gumbel-Softmax relaxation for differentiability and ALM for constraint enforcement.
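
To make the placement of the projection concrete, here is a toy sketch of the reverse loop in pure Python. Everything in it is an assumption made for illustration: `toy_denoiser` stands in for `x_θ`, `toy_project` replaces the ALM projection with a simple hard mask that removes one banned token, and the `1 / t` unmasking probability assumes an absorbing-state (MDLM-style) process with a linear schedule. It is not the repository's sampler.

```python
import random

MASK = -1  # hypothetical id for the absorbing [MASK] state


def toy_denoiser(z, vocab_size):
    """Stand-in for x_θ: a uniform distribution over the vocabulary per position."""
    return [[1.0 / vocab_size] * vocab_size for _ in z]


def toy_project(probs, banned):
    """Stand-in for the CDD projection: zero out one banned token, renormalise."""
    projected = []
    for p in probs:
        q = [0.0 if v == banned else pv for v, pv in enumerate(p)]
        total = sum(q)
        projected.append([qv / total for qv in q])
    return projected


def reverse_sample(seq_len, vocab_size, num_steps, banned, seed=0):
    """z_T (all masked) -> ... -> z_0, projecting the prediction at every step."""
    rng = random.Random(seed)
    z = [MASK] * seq_len
    for t in range(num_steps, 0, -1):
        probs = toy_project(toy_denoiser(z, vocab_size), banned)
        unmask_prob = 1.0 / t  # reaches 1.0 at t = 1, so no mask survives
        for i in range(seq_len):
            if z[i] == MASK and rng.random() < unmask_prob:
                z[i] = rng.choices(range(vocab_size), weights=probs[i])[0]
    return z


print(reverse_sample(seq_len=8, vocab_size=5, num_steps=10, banned=3))
```

The call returns eight token ids, none of them the banned id. In CDD the hard mask would be replaced by the ALM projection described under Method Details.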

## Architecture

```
cdd/
├── samplers/
│   └── cdd_sampler.py            # Core ALM projection + sampling loop (§4)
├── constraints/
│   ├── toxicity.py               # Toxicity mitigation constraint (§5.1)
│   ├── molecular.py              # SA + Novelty constraints (§5.2)
│   └── instruction.py            # Counting + Lexical constraints (§5.3)
├── models/
│   └── load_models.py            # MDLM/UDLM model loading utilities
├── experiments/
│   ├── toxicity_mitigation.py    # Full toxicity experiment
│   ├── molecular_generation.py   # Full molecular experiment
│   ├── instruction_following.py  # Full instruction experiment
│   └── train_surrogates.py       # Surrogate model training
└── utils/
    ├── noise_schedule.py         # Diffusion math (schedules, posteriors)
    └── evaluation.py             # Metrics (PPL, entropy, violation rate)
```

## Installation

```bash
pip install -e .

# For GPU models (required for MDLM/UDLM inference):
pip install flash-attn --no-build-isolation

# For molecular evaluation (the older `rdkit-pypi` package name is deprecated):
pip install rdkit
```

## Quick Start

### 1. Unconstrained Generation with MDLM

```python
import torch
from cdd.models import load_mdlm
from cdd.samplers import CDDSampler

# Load pretrained MDLM
model, tokenizer, info = load_mdlm(device="cuda")

# Create sampler; the constant-zero constraint means nothing is enforced
sampler = CDDSampler(
    model=model,
    tokenizer=tokenizer,
    constraint_fn=lambda x: torch.tensor(0.0),  # No constraint
    diffusion_type="mdlm",
    num_timesteps=1000,
    seq_length=128,
    device="cuda",
)

result = sampler.sample(batch_size=1)
print(result["text"][0])
```

### 2. Toxicity-Constrained Generation

```python
from cdd.constraints import create_toxicity_constraint
from cdd.samplers import CDDSampler, TOXICITY_ALM_CONFIG

# Reuses `model` and `tokenizer` from example 1.

# Set up the toxicity constraint (§5.1)
constraint = create_toxicity_constraint(threshold=0.5)

sampler = CDDSampler(
    model=model,
    tokenizer=tokenizer,
    constraint_fn=constraint,
    alm_config=TOXICITY_ALM_CONFIG,  # λ=0, μ=1, K=1000, M=10, η=0.2
    diffusion_type="mdlm",
    device="cuda",
)

# Conditional generation from a potentially toxic prompt
prefix = tokenizer.encode("The politician was accused of", return_tensors="pt").to("cuda")
result = sampler.sample(batch_size=1, prefix_ids=prefix)
print(result["text"][0])  # Completion constrained to stay below the toxicity threshold
```

### 3. Molecular Generation with SA Constraint

```python
from cdd.models import load_udlm
from cdd.samplers import CDDSampler, MOLECULAR_SA_ALM_CONFIG

model, tokenizer, info = load_udlm("kuleshov-group/udlm-qm9", device="cuda")

# NOTE: `sa_constraint` is not defined in this snippet. Build it first from the
# synthetic accessibility (SA) constraint code in cdd/constraints/molecular.py.
sampler = CDDSampler(
    model=model,
    tokenizer=tokenizer,
    constraint_fn=sa_constraint,         # SA ≤ 3.5
    alm_config=MOLECULAR_SA_ALM_CONFIG,  # λ=0, μ=1, μ_max=1000, K=1000, M=100, η=1.0
    diffusion_type="udlm",
    num_timesteps=1000,
    seq_length=32,
    device="cuda",
)

result = sampler.sample(batch_size=32)
```

## Reproducing Paper Experiments

### Toxicity Mitigation (§5.1, Table 3)

```bash
python -m cdd.experiments.toxicity_mitigation \
    --threshold 0.5 \
    --num_samples 1000 \
    --device cuda
```

Expected results (CDD, τ = 0.50): PPL ≈ 59.44, violation rate = 0.0%.
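
The two headline numbers can be computed along these lines. This is a minimal sketch, not the code in `cdd/utils/evaluation.py`: the function names are hypothetical, a sample is assumed to violate the constraint when its score is strictly above the threshold, and perplexity is assumed to be the exponential of the mean per-token negative log-likelihood under whichever scoring model is used.

```python
import math


def violation_rate(scores, threshold):
    """Fraction of samples whose constraint score exceeds the threshold."""
    return sum(score > threshold for score in scores) / len(scores)


def perplexity(token_nlls):
    """exp(mean negative log-likelihood) over all scored tokens."""
    return math.exp(sum(token_nlls) / len(token_nlls))


# Made-up toxicity scores for four samples, checked against τ = 0.5.
print(violation_rate([0.1, 0.6, 0.4, 0.9], threshold=0.5))
# Three tokens, each assigned probability 1/4 by the scoring model.
print(perplexity([math.log(4.0)] * 3))
```

A violation rate of 0.0% therefore means that no generated sample scored above τ.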

### Molecular Generation (§5.2, Table 4)

```bash
python -m cdd.experiments.molecular_generation \
    --sa_threshold 3.5 \
    --num_samples 1000 \
    --device cuda
```

### Instruction Following (§5.3, Table 5)

```bash
# Counting task
python -m cdd.experiments.instruction_following \
    --task counting --num_samples 100

# Lexical task
python -m cdd.experiments.instruction_following \
    --task lexical --num_samples 100
```

### Train Surrogate Models

```bash
# Toxicity surrogate (GPT-Neo 1.3B on Jigsaw)
python -m cdd.experiments.train_surrogates --task toxicity

# SA surrogate (GPT-2 on QM9)
python -m cdd.experiments.train_surrogates --task sa
```

## ALM Hyperparameters (Appendix B & C)

| Task | λ₀ | μ₀ | μ_max | K (outer) | M (inner) | η |
|------|-----|-----|-------|-----------|-----------|------|
| Toxicity (NLP) | 0.0 | 1.0 | n/a | 1000 | 10 | 0.20 |
| Molecular (QM9) | 0.0 | 1.0 | 1000 | 1000 | 100 | 1.0 |

## Base Models

| Model | Type | Params | Vocab | Length | Trained On |
|-------|------|--------|-------|--------|------------|
| [kuleshov-group/mdlm-owt](https://huggingface.co/kuleshov-group/mdlm-owt) | MDLM | 110M | 50,258 | 1024 | OpenWebText |
| [kuleshov-group/udlm-qm9](https://huggingface.co/kuleshov-group/udlm-qm9) | UDLM | 92M | 40 | 32 | QM9 |
| [kuleshov-group/udlm-lm1b](https://huggingface.co/kuleshov-group/udlm-lm1b) | UDLM | 139M | 30,522 | 128 | LM1B |

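## Method Details

### Gumbel-Softmax Relaxation (Eq. 6)

Replaces the non-differentiable `argmax` with a differentiable relaxation, where `ξ_v` is Gumbel noise and `T` is the relaxation temperature (not the number of diffusion steps):

```
φ̃(x)(v) = exp((log x(v) + ξ_v) / T) / Σ_{v'} exp((log x(v') + ξ_{v'}) / T)
```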
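
The relaxation can be sketched in a few lines of pure Python. This is an illustration under stated assumptions rather than the sampler's code: the function name, the optional `noise` argument (handy for deterministic checks), and the `eps` clamp that guards `log(0)` are all choices made here.

```python
import math
import random


def gumbel_softmax(probs, temperature, noise=None, rng=None, eps=1e-12):
    """Relax a categorical distribution `probs` into a soft one-hot vector."""
    rng = rng or random.Random()
    if noise is None:
        # ξ = -log(-log(U)) with U ~ Uniform(0, 1) is a standard Gumbel sample.
        noise = [-math.log(-math.log(max(rng.random(), eps))) for _ in probs]
    logits = [(math.log(max(p, eps)) + xi) / temperature
              for p, xi in zip(probs, noise)]
    peak = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(logit - peak) for logit in logits]
    total = sum(exps)
    return [e / total for e in exps]


probs = [0.7, 0.2, 0.1]
print(gumbel_softmax(probs, temperature=1.0, rng=random.Random(0)))
print(gumbel_softmax(probs, temperature=0.1, rng=random.Random(0)))
```

With the same noise, lowering the temperature pushes the output towards a one-hot vector at the `argmax` of the perturbed log-probabilities, while higher temperatures give smoother outputs and better-behaved gradients.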

### Augmented Lagrangian (Eq. 7-8)

Objective minimized at each diffusion step, where `ỹ = φ̃(y)` is the Gumbel-Softmax relaxation of `y`:

```
L(y, λ, μ) = D_KL(x_θ || y) + λ·g(ỹ) + (μ/2)·g(ỹ)²
```

Update rules for the multiplier λ and the penalty weight μ:

```
λ ← λ + μ·g(ỹ)
μ ← min(2μ, μ_max)
```
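
The interplay of the inner minimization and the λ/μ updates can be seen on a toy problem. The sketch below is a simplification under explicit assumptions, not the paper's procedure: it works on a single token distribution, takes `g` to be the amount by which the expected score `Σ_v scores[v]·y[v]` exceeds a threshold `tau` (so no Gumbel-Softmax step is needed), parameterises `y` as a softmax over logits, and uses hand-derived gradients. The names and default values, including the small `mu_max`, are chosen for this toy only.

```python
import math


def softmax(w):
    peak = max(w)
    exps = [math.exp(v - peak) for v in w]
    total = sum(exps)
    return [e / total for e in exps]


def alm_project(x, scores, tau, lam=0.0, mu=1.0, mu_max=16.0,
                outer_steps=50, inner_steps=100, lr=0.5, tol=1e-3):
    """Toy projection: min_y KL(x || y) s.t. sum_v scores[v] * y[v] <= tau."""
    w = [math.log(p) for p in x]  # y = softmax(w), initialised at y = x
    for _ in range(outer_steps):
        for _ in range(inner_steps):  # gradient descent on L(y, λ, μ)
            y = softmax(w)
            c = sum(s * p for s, p in zip(scores, y))
            g = max(0.0, c - tau)  # constraint violation
            coeff = lam + mu * g if g > 0.0 else 0.0
            # d KL/dw_j = y_j - x_j and d c/dw_j = y_j * (scores_j - c)
            grad = [(yj - xj) + coeff * yj * (sj - c)
                    for yj, xj, sj in zip(y, x, scores)]
            w = [wj - lr * gj for wj, gj in zip(w, grad)]
        y = softmax(w)
        g = max(0.0, sum(s * p for s, p in zip(scores, y)) - tau)
        if g <= tol:
            break
        lam += mu * g               # λ ← λ + μ·g
        mu = min(2.0 * mu, mu_max)  # μ ← min(2μ, μ_max)
    return y, g


x = [0.7, 0.2, 0.1]       # made-up denoiser prediction over three tokens
scores = [0.9, 0.1, 0.2]  # made-up per-token constraint scores
y, violation = alm_project(x, scores, tau=0.5)
print([round(p, 3) for p in y], violation)
```

Here the unconstrained prediction has an expected score of 0.67, so the projection shifts probability mass away from the first token until the expected score is within `tol` of 0.5, while the KL term keeps `y` as close to `x` as the constraint allows.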

### MDLM vs UDLM for CDD

| Property | MDLM | UDLM |
|----------|------|------|
| Noise type | Absorbing ([MASK]) | Uniform |
| Tokens re-sampled each step | Only masked | All |
| CDD paper usage | Text tasks | Molecular |
| Time conditioning | No | Yes |
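
The "Noise type" row can be illustrated with two small corruption functions. This is a sketch with hypothetical names, where `p` is assumed to be the per-token corruption probability at some noise level; it does not reproduce the schedules in `cdd/utils/noise_schedule.py`.

```python
import random


def corrupt_absorbing(tokens, p, mask_id, rng):
    """MDLM-style noise: each token becomes `mask_id` with probability p."""
    return [mask_id if rng.random() < p else tok for tok in tokens]


def corrupt_uniform(tokens, p, vocab_size, rng):
    """UDLM-style noise: each token is redrawn uniformly with probability p."""
    return [rng.randrange(vocab_size) if rng.random() < p else tok
            for tok in tokens]


rng = random.Random(0)
tokens = [3, 1, 4, 1, 5]
print(corrupt_absorbing(tokens, p=0.5, mask_id=-1, rng=rng))
print(corrupt_uniform(tokens, p=0.5, vocab_size=10, rng=rng))
```

Under absorbing noise every unmasked token is known to be clean, which is why only masked positions need to be re-sampled; under uniform noise a corrupted token looks like any other, so all positions remain candidates at every step.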

## Citation

```bibtex
@article{cardei2025constrained,
  title={Constrained Language Generation with Discrete Diffusion Models},
  author={Cardei, Michael and Christopher, Jacob K. and Hartvigsen, Thomas
          and Bartoldson, Brian R. and Kailkhura, Bhavya and Fioretto, Ferdinando},
  journal={arXiv preprint arXiv:2503.09790},
  year={2025}
}
```