---
title: CDD - Constrained Discrete Diffusion
emoji: 🧬
colorFrom: green
colorTo: blue
sdk: docker
app_port: 7860
pinned: false
---
# Constrained Discrete Diffusion (CDD)
Implementation of **"Constrained Language Generation with Discrete Diffusion Models"** (Cardei et al., 2025; [arXiv:2503.09790v3](https://arxiv.org/abs/2503.09790v3)).
CDD integrates discrete diffusion models with differentiable optimization to enforce constraints on text generation **without retraining**. The paper reports **zero constraint violations** across toxicity mitigation, molecular generation, and instruction following.
## Key Idea
Discrete diffusion models (MDLM, UDLM) generate sequences by iteratively denoising from a noise distribution. At each reverse step, CDD inserts an **Augmented Lagrangian Method (ALM) projection** that modifies the denoiser's predicted token distributions to satisfy user-defined constraints:
```
z_T → x_θ(z_T) → [CDD Projection] → z_{T-1} → ... → z_0
```
The projection solves:
```
min_y D_KL(x_θ || y) s.t. g(argmax(y)) ∈ C
```
using Gumbel-Softmax relaxation for differentiability and ALM for constraint enforcement.
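The control flow of the pipeline above can be sketched as follows. This is a minimal illustration and not the repository's sampler: `denoise`, `project`, and `step` are hypothetical callables standing in for the denoiser x_θ, the CDD projection, and the reverse-step sampler.

```python
from typing import Callable, List


def constrained_reverse_process(
    z: List[int],
    denoise: Callable,
    project: Callable,
    step: Callable,
    num_steps: int,
) -> List[int]:
    """Run z_T -> ... -> z_0, projecting the denoiser output at every step."""
    for t in range(num_steps, 0, -1):
        x_pred = denoise(z, t)    # x_θ(z_t): predicted token distributions
        x_proj = project(x_pred)  # constraint projection (hypothetical hook)
        z = step(z, x_proj, t)    # draw z_{t-1} from the projected prediction
    return z


# Toy usage: stand-in callables that only record the order of calls.
trace = []
final = constrained_reverse_process(
    z=[0, 0, 0],
    denoise=lambda z, t: trace.append(("denoise", t)) or z,
    project=lambda x: trace.append(("project",)) or x,
    step=lambda z, x, t: trace.append(("step", t)) or x,
    num_steps=2,
)
```

The point of the sketch is the ordering: the projection sits between the denoiser and the sampling step, so the base model itself is left untouched.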
## Architecture
```
cdd/
├── samplers/
│ └── cdd_sampler.py # Core ALM projection + sampling loop (§4)
├── constraints/
│ ├── toxicity.py # Toxicity mitigation constraint (§5.1)
│ ├── molecular.py # SA + Novelty constraints (§5.2)
│ └── instruction.py # Counting + Lexical constraints (§5.3)
├── models/
│ └── load_models.py # MDLM/UDLM model loading utilities
├── experiments/
│ ├── toxicity_mitigation.py # Full toxicity experiment
│ ├── molecular_generation.py # Full molecular experiment
│ ├── instruction_following.py # Full instruction experiment
│ └── train_surrogates.py # Surrogate model training
└── utils/
├── noise_schedule.py # Diffusion math (schedules, posteriors)
└── evaluation.py # Metrics (PPL, entropy, violation rate)
```
## Installation
```bash
pip install -e .
# For GPU models (required for MDLM/UDLM inference):
pip install flash-attn --no-build-isolation
# For molecular evaluation:
pip install rdkit-pypi
```
## Quick Start
### 1. Unconstrained Generation with MDLM
```python
import torch
from cdd.models import load_mdlm
from cdd.samplers import CDDSampler, ALMConfig
# Load pretrained MDLM
model, tokenizer, info = load_mdlm(device="cuda")
# Create sampler (no constraints)
sampler = CDDSampler(
model=model,
tokenizer=tokenizer,
constraint_fn=lambda x: torch.tensor(0.0), # No constraint
diffusion_type="mdlm",
num_timesteps=1000,
seq_length=128,
device="cuda",
)
result = sampler.sample(batch_size=1)
print(result["text"][0])
```
### 2. Toxicity-Constrained Generation
```python
# Reuses `model` and `tokenizer` loaded in step 1
from cdd.constraints import create_toxicity_constraint
from cdd.samplers import CDDSampler, TOXICITY_ALM_CONFIG
# Set up the toxicity constraint (§5.1)
constraint = create_toxicity_constraint(threshold=0.5)
sampler = CDDSampler(
model=model,
tokenizer=tokenizer,
constraint_fn=constraint,
alm_config=TOXICITY_ALM_CONFIG, # λ=0, μ=1, K=1000, M=10, η=0.2
diffusion_type="mdlm",
device="cuda",
)
# Conditional generation from a prompt that may elicit toxic text
prefix = tokenizer.encode("The politician was accused of", return_tensors="pt").to("cuda")
result = sampler.sample(batch_size=1, prefix_ids=prefix)
print(result["text"][0]) # Non-toxic completion
```
### 3. Molecular Generation with SA Constraint
```python
from cdd.models import load_udlm
from cdd.samplers import CDDSampler, MOLECULAR_SA_ALM_CONFIG
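model, tokenizer, info = load_udlm("kuleshov-group/udlm-qm9", device="cuda")
# NOTE: `sa_constraint` (SA ≤ 3.5) is not defined in this snippet; build it from
# cdd/constraints/molecular.py before creating the sampler.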
sampler = CDDSampler(
model=model,
tokenizer=tokenizer,
constraint_fn=sa_constraint, # SA ≤ 3.5
alm_config=MOLECULAR_SA_ALM_CONFIG, # λ=0, μ=1, μ_max=1000, K=1000, M=100, η=1.0
diffusion_type="udlm",
num_timesteps=1000,
seq_length=32,
device="cuda",
)
result = sampler.sample(batch_size=32)
```
## Reproducing Paper Experiments
### Toxicity Mitigation (§5.1, Table 3)
```bash
python -m cdd.experiments.toxicity_mitigation \
--threshold 0.5 \
--num_samples 1000 \
--device cuda
```
Expected results (CDD τ=0.50): PPL ≈ 59.44, Violation = 0.0%
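For readers unfamiliar with the two reported numbers, here is a minimal sketch of how such metrics are commonly defined. It assumes perplexity is the exponential of the mean per-token negative log-likelihood and that a sample violates the constraint when its score exceeds τ. The function names are illustrative and not taken from `cdd/utils/evaluation.py`, which may compute these differently.

```python
import math
from typing import Sequence


def perplexity(token_log_probs: Sequence[float]) -> float:
    """exp of the mean negative log-likelihood (natural log) per token."""
    return math.exp(-sum(token_log_probs) / len(token_log_probs))


def violation_rate(scores: Sequence[float], threshold: float) -> float:
    """Percentage of samples whose constraint score exceeds the threshold."""
    violations = sum(1 for s in scores if s > threshold)
    return 100.0 * violations / len(scores)


# Four tokens, each assigned probability 0.5 by the scoring model.
ppl = perplexity([math.log(0.5)] * 4)
# Two of four toxicity scores exceed the threshold 0.5.
rate = violation_rate([0.1, 0.6, 0.4, 0.9], threshold=0.5)
```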
### Molecular Generation (§5.2, Table 4)
```bash
python -m cdd.experiments.molecular_generation \
--sa_threshold 3.5 \
--num_samples 1000 \
--device cuda
```
### Instruction Following (§5.3, Table 5)
```bash
# Counting task
python -m cdd.experiments.instruction_following \
--task counting --num_samples 100
# Lexical task
python -m cdd.experiments.instruction_following \
--task lexical --num_samples 100
```
### Train Surrogate Models
```bash
# Toxicity surrogate (GPT-Neo 1.3B on Jigsaw)
python -m cdd.experiments.train_surrogates --task toxicity
# SA surrogate (GPT-2 on QM9)
python -m cdd.experiments.train_surrogates --task sa
```
## ALM Hyperparameters (Appendix B & C)
| Task | λ₀ | μ₀ | μ_max | K (outer) | M (inner) | η |
|------|-----|-----|-------|-----------|-----------|------|
| Toxicity (NLP) | 0.0 | 1.0 | — | 1000 | 10 | 0.20 |
| Molecular (QM9) | 0.0 | 1.0 | 1000 | 1000 | 100 | 1.0 |
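To see how the Molecular row plays out, the sketch below applies the doubling rule `μ ← min(2μ, μ_max)` from the Method Details section to μ₀ = 1 and μ_max = 1000. The helper name `penalty_schedule` is illustrative and not part of the package.

```python
from typing import List


def penalty_schedule(mu0: float, mu_max: float, num_updates: int) -> List[float]:
    """Penalty weight after 0, 1, ..., num_updates outer ALM updates."""
    mus = [mu0]
    for _ in range(num_updates):
        mus.append(min(2.0 * mus[-1], mu_max))  # double, then clip at the cap
    return mus


schedule = penalty_schedule(mu0=1.0, mu_max=1000.0, num_updates=12)
# Index of the first update at which the cap is reached.
first_capped = schedule.index(1000.0)
```

Under this rule the cap is reached after 10 updates (2¹⁰ = 1024 > 1000), so for most of the K = 1000 outer iterations the penalty weight stays at μ_max.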
## Base Models
| Model | Type | Params | Vocab | Length | Trained On |
|-------|------|--------|-------|--------|------------|
| [kuleshov-group/mdlm-owt](https://huggingface.co/kuleshov-group/mdlm-owt) | MDLM | 110M | 50,258 | 1024 | OpenWebText |
| [kuleshov-group/udlm-qm9](https://huggingface.co/kuleshov-group/udlm-qm9) | UDLM | 92M | 40 | 32 | QM9 |
| [kuleshov-group/udlm-lm1b](https://huggingface.co/kuleshov-group/udlm-lm1b) | UDLM | 139M | 30,522 | 128 | LM1B |
## Method Details
### Gumbel-Softmax Relaxation (Eq. 6)
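Replaces the non-differentiable `argmax` with a softmax over Gumbel-perturbed log-probabilities, where ξ_v ~ Gumbel(0, 1) is i.i.d. noise and T > 0 is the softmax temperature (not the number of diffusion steps in `z_T`):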
```
φ̃(x)(v) = exp((log x(v) + ξ_v) / T) / Σ_{v'} exp((log x(v') + ξ_{v'}) / T)
```
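The formula above can be written out in a few lines. This is a standard-library sketch for a single token position, assuming strictly positive probabilities; the repository presumably works on batched tensors, and the function name `gumbel_softmax` here is illustrative.

```python
import math
import random
from typing import List, Sequence


def gumbel_softmax(
    probs: Sequence[float], temperature: float, rng: random.Random
) -> List[float]:
    """Relaxed one-hot sample from a categorical distribution (probs > 0)."""
    logits = []
    for p in probs:
        u = max(rng.random(), 1e-12)    # guard against log(0)
        xi = -math.log(-math.log(u))    # Gumbel(0, 1) noise
        logits.append((math.log(p) + xi) / temperature)
    m = max(logits)                     # subtract the max for stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


rng = random.Random(0)
relaxed = gumbel_softmax([0.7, 0.2, 0.1], temperature=0.5, rng=rng)
```

Lower temperatures push the output toward a one-hot vector, and the index of its largest entry is distributed according to `probs` (the Gumbel-max trick), which is why the relaxation can stand in for sampling followed by `argmax`.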
### Augmented Lagrangian (Eq. 7-8)
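At each diffusion step, CDD minimizes the following objective over y, where ỹ = φ̃(y) is the Gumbel-Softmax relaxation of y, λ is the Lagrange multiplier, and μ is the penalty weight: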
```
L(y, λ, μ) = D_KL(x_θ || y) + λ·g(ỹ) + (μ/2)·g(ỹ)²
```
Update rules, applied after each round of inner minimization steps:
```
λ ← λ + μ·g(ỹ)
μ ← min(2μ, μ_max)
```
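The interplay of the inner minimization and the two update rules can be seen on a toy problem. The sketch below is not the repository's projection: it handles a single 3-way distribution, omits the Gumbel noise, uses closed-form gradients instead of autograd, and assumes the constraint is the scalar g(y) = y[0] − τ, treated as an equality because it is active for the chosen input. The parameters `outer_iters`, `inner_iters`, and `lr` play the roles of K, M, and η.

```python
import math
from typing import List, Sequence, Tuple


def softmax(z: Sequence[float]) -> List[float]:
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def alm_project(
    p: Sequence[float],
    tau: float,
    lam: float = 0.0,
    mu: float = 1.0,
    mu_max: float = 100.0,
    outer_iters: int = 60,
    inner_iters: int = 50,
    lr: float = 0.05,
) -> Tuple[List[float], float]:
    """Toy ALM: minimize KL(p || y) subject to y[0] = tau, with y = softmax(z)."""
    z = [math.log(v) for v in p]  # start from the unconstrained prediction
    for _ in range(outer_iters):
        for _ in range(inner_iters):
            y = softmax(z)
            coeff = lam + mu * (y[0] - tau)  # dL/dg for L = KL + λ·g + (μ/2)·g²
            # d KL / d z_j = y_j - p_j ;  d g / d z_j = y_0 · (1[j == 0] - y_j)
            grad = [
                (y[j] - p[j]) + coeff * y[0] * ((1.0 if j == 0 else 0.0) - y[j])
                for j in range(len(z))
            ]
            z = [zj - lr * gj for zj, gj in zip(z, grad)]
        g = softmax(z)[0] - tau
        lam += mu * g                 # λ ← λ + μ·g
        mu = min(2.0 * mu, mu_max)    # μ ← min(2μ, μ_max)
    return softmax(z), lam


projected, multiplier = alm_project([0.7, 0.2, 0.1], tau=0.3)
```

For this input the analytic optimum caps the first entry at τ and rescales the remaining mass proportionally, giving (0.3, 0.4667, 0.2333), so the result of the sketch can be checked against it.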
### MDLM vs UDLM for CDD
| Property | MDLM | UDLM |
|----------|------|------|
| Noise type | Absorbing ([MASK]) | Uniform |
| Tokens re-sampled each step | Only masked | All |
| CDD paper usage | Text tasks | Molecular |
| Time conditioning | No | Yes |
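The difference between the two noise types in the table can be illustrated with a toy forward corruption. This sketch assumes a single corruption probability `rate` for every position, whereas the actual models derive it from the noise schedule at time t. The function names are illustrative.

```python
import random
from typing import List, Sequence


def corrupt_absorbing(
    tokens: Sequence[int], mask_id: int, rate: float, rng: random.Random
) -> List[int]:
    """MDLM-style noise: each token is replaced by [MASK] with probability rate."""
    return [mask_id if rng.random() < rate else tok for tok in tokens]


def corrupt_uniform(
    tokens: Sequence[int], vocab_size: int, rate: float, rng: random.Random
) -> List[int]:
    """UDLM-style noise: each token is resampled uniformly with probability rate."""
    return [
        rng.randrange(vocab_size) if rng.random() < rate else tok
        for tok in tokens
    ]


rng = random.Random(0)
tokens = [5, 6, 7, 8]
masked = corrupt_absorbing(tokens, mask_id=39, rate=0.5, rng=rng)
shuffled = corrupt_uniform(tokens, vocab_size=40, rate=0.5, rng=rng)
```

With absorbing noise a corrupted position is always identifiable by the mask token, which is why only masked positions need re-sampling. With uniform noise any position may hold a wrong token, so all positions stay open to revision.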
## Citation
```bibtex
@article{cardei2025constrained,
title={Constrained Language Generation with Discrete Diffusion Models},
author={Cardei, Michael and Christopher, Jacob K and Hartvigsen, Thomas
and Bartoldson, Brian R. and Kailkhura, Bhavya and Fioretto, Ferdinando},
journal={arXiv preprint arXiv:2503.09790},
year={2025}
}
```