---
license: apache-2.0
language:
- en
---

# LLaDA-346M: Large Language Diffusion with Masking

## Model Description

This is a **346 million parameter** Large Language Diffusion Model trained with a masked diffusion process. It demonstrates that diffusion-based approaches can be viable alternatives to autoregressive language models.

### Key Features

- **Architecture**: Masked Diffusion Model (MDM) with Transformer encoder
- **Parameters**: 346M
- **Sequence Length**: 512 tokens
- **Vocab Size**: 50,257 (GPT-2)
- **Training Data**: 50,000 WikiText-2 samples

## Model Architecture

```
Token Embeddings (50257 × 1024)
        ↓
Position Embeddings (512 × 1024)
        ↓
Time Embeddings (MLP)
        ↓
Transformer Encoder (12 layers, 16 heads)
  ├─ Self-Attention
  └─ Feed-Forward (4096 dim)
        ↓
Output Projection (1024 × 50257)
```
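
In PyTorch terms, this corresponds roughly to the sketch below. It is a minimal, hedged reconstruction: the class and argument names mirror the `MaskedDiffusionModel` signature used in the Usage section, but the internals (especially the time-embedding MLP) are assumptions, and the exact parameter count may differ from the released weights.

```python
import torch
import torch.nn as nn

class MaskedDiffusionModel(nn.Module):
    """Sketch of the MDM backbone described in the diagram above."""

    def __init__(self, vocab_size=50257, hidden_dim=1024, num_layers=12,
                 num_heads=16, ff_dim=4096, dropout=0.1,
                 max_seq_length=512, num_timesteps=100):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, hidden_dim)
        self.pos_emb = nn.Embedding(max_seq_length, hidden_dim)
        # Time-embedding MLP: one plausible design, not necessarily the original
        self.time_emb = nn.Sequential(
            nn.Embedding(num_timesteps, hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=num_heads, dim_feedforward=ff_dim,
            dropout=dropout, batch_first=True, norm_first=True,  # pre-LN
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.out_proj = nn.Linear(hidden_dim, vocab_size)

    def forward(self, input_ids, timesteps):
        positions = torch.arange(input_ids.size(1), device=input_ids.device)
        h = self.token_emb(input_ids) + self.pos_emb(positions)
        h = h + self.time_emb(timesteps).unsqueeze(1)  # broadcast over sequence
        h = self.encoder(h)   # no causal mask: fully bidirectional attention
        return self.out_proj(h)  # per-position logits over the vocabulary
```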

## Training Details

- **Algorithm**: Masked Diffusion Model (MDM)
- **Loss Function**: Cross-entropy on masked positions
- **Optimizer**: AdamW (lr=3e-5, betas=(0.9, 0.95))
- **Batch Size**: 16 (effective 32 with gradient accumulation)
- **Gradient Checkpointing**: Enabled
- **Mixed Precision**: AMP (FP32/FP16)
- **Epochs**: 4
- **Training Samples**: 50,000
- **GPU**: NVIDIA V100 (22 GB VRAM)
- **Training Time**: ~20 hours
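
Combined, these settings give an inner training step roughly like the following hedged sketch. The `mask_tokens` helper is sketched under "Training Process" below; the data loader and device handling are illustrative, not the exact training code.

```python
import torch
import torch.nn.functional as F

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5, betas=(0.9, 0.95))
scaler = torch.cuda.amp.GradScaler()  # mixed precision (AMP)
accum_steps = 2                       # batch 16 -> effective batch 32

for step, batch in enumerate(loader):
    input_ids = batch["input_ids"].cuda()            # (B, 512)
    t = torch.rand(input_ids.size(0), device=input_ids.device)  # t ~ U[0, 1)
    noisy_ids, mask = mask_tokens(input_ids, t)      # forward (masking) process

    with torch.cuda.amp.autocast():
        logits = model(noisy_ids, (t * 100).long())  # discretize t into 100 steps
        # Cross-entropy on masked positions only
        loss = F.cross_entropy(logits[mask], input_ids[mask]) / accum_steps

    scaler.scale(loss).backward()
    if (step + 1) % accum_steps == 0:                # gradient accumulation
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
```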

## Performance

| Metric | Value |
|--------|-------|
| Initial Loss | 5.96 |
| Final Loss | 4.94 |
| Loss Reduction | 17.1% |
| Total Parameters | 346M |
| Model Size (FP32) | 1.38 GB |
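
As a sanity check, the size figure follows directly from the parameter count: 346 × 10⁶ parameters × 4 bytes per FP32 weight ≈ 1.38 GB.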

## Usage

### Installation

```bash
pip install transformers torch
```

### Loading the Model

```python
import torch
from transformers import AutoTokenizer

# The architecture class is not part of `transformers`; import it from the
# module in this repository that defines it.
from your_module import MaskedDiffusionModel

# Instantiate the model with the training configuration
model = MaskedDiffusionModel(
    vocab_size=50257,
    hidden_dim=1024,
    num_layers=12,
    num_heads=16,
    ff_dim=4096,
    dropout=0.1,
    max_seq_length=512,
    num_timesteps=100
)

# Load weights (map_location avoids requiring a GPU at load time)
checkpoint = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(checkpoint)
model.eval()

# The model reuses the GPT-2 tokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")
```

### Text Generation

```python
from diffusion_sampler import DiffusionSampler

# `config` and `device` come from your training setup
sampler = DiffusionSampler(model, tokenizer, config, device)

# Generate text via iterative denoising
text = sampler.generate(
    prompt="The future of AI",
    num_steps=40,
    temperature=0.8,
    top_p=0.9
)
print(text)
```

## Model Characteristics

### Advantages

✅ **Bidirectional Context**: Sees the full context, unlike autoregressive models
✅ **Parallel Generation**: Can predict multiple tokens simultaneously
✅ **Reversal Invariance**: Comparable performance on forward and reversed tasks
✅ **Global Coherence**: Reduces error accumulation during generation

### Limitations

❌ Slower generation (iterative denoising process)
❌ Requires more compute for inference
❌ Not fine-tuned for specific tasks

## Training Process

### Forward Process

- Gradually masks tokens at random
- At timestep t ∈ [0,1], each token is masked independently with probability t
- Creates a noisy version of the input (see the sketch below)
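
A minimal sketch of this masking step. The `mask_token_id` is an assumption: GPT-2's vocabulary reserves no mask token, so the real code must append one (and size the embedding accordingly) or repurpose an existing id.

```python
import torch

def mask_tokens(input_ids, t, mask_token_id=50257):
    """Forward process: mask each token independently with probability t.

    input_ids: (B, L) token ids; t: (B,) masking probabilities in [0, 1).
    mask_token_id is assumed to be an id appended beyond GPT-2's 50,257.
    """
    # One Bernoulli(t) draw per token; t broadcasts across the sequence dim
    mask = torch.rand_like(input_ids, dtype=torch.float) < t.unsqueeze(1)
    noisy = input_ids.clone()
    noisy[mask] = mask_token_id
    return noisy, mask
```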

### Reverse Process

- Iteratively predicts and unmasks tokens (sketched below)
- Uses the transformer to predict masked positions
- Trained with cross-entropy loss on masked tokens only
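
At inference time this becomes an iterative unmasking loop. The sketch below uses confidence-based remasking (keep the most confident predictions each step); it assumes batch size 1 and is illustrative rather than the exact sampler shipped with the model.

```python
import torch

@torch.no_grad()
def reverse_sample(model, ids, mask_token_id, num_steps=40):
    """Start from masked ids and progressively fill them in."""
    for step in range(num_steps):
        masked = ids == mask_token_id
        if not masked.any():
            break
        # Timesteps run from high (mostly masked) down toward 0
        t = torch.full((ids.size(0),), num_steps - 1 - step, device=ids.device)
        logits = model(ids, t)
        conf, pred = logits.softmax(dim=-1).max(dim=-1)
        conf[~masked] = -1.0                     # only fill masked slots
        # Unmask an even share of the remaining masked tokens this step
        k = max(1, int(masked.sum().item() / (num_steps - step)))
        idx = conf.view(-1).topk(k).indices
        ids.view(-1)[idx] = pred.view(-1)[idx]
    return ids
```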

## Optimization Techniques

- **Gradient Checkpointing**: Saves memory during the backward pass
- **Mixed Precision (AMP)**: Uses FP16 where possible
- **Gradient Accumulation**: Simulates larger batches
- **Layer Norm First** (pre-LN): Improves training stability
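
The first and last of these are essentially one-line changes in PyTorch; a hedged sketch:

```python
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Pre-LN ("layer norm first") via the built-in norm_first flag
layer = nn.TransformerEncoderLayer(
    d_model=1024, nhead=16, dim_feedforward=4096,
    batch_first=True, norm_first=True,
)

# Gradient checkpointing: recompute each layer's activations during the
# backward pass instead of storing them, trading compute for memory
def run_layers(layers, h):
    for layer in layers:
        h = checkpoint(layer, h, use_reentrant=False)
    return h
```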

## Citation

If you use this model, please cite:

```bibtex
@article{nie2025llada,
  title={Large Language Diffusion Models},
  author={Nie, Shen and others},
  journal={arXiv preprint arXiv:2502.09992},
  year={2025}
}
```

## License

Apache 2.0 (matching the model card metadata) - free to use for research and commercial purposes.

## Acknowledgments

- Based on "Large Language Diffusion Models" (Nie et al., 2025)
- Built with PyTorch and Transformers
- Trained on the WikiText-2 dataset
- Inspired by diffusion models for vision (DiT, Genie)

## Contact & Support

For issues, questions, or suggestions, please open an issue on GitHub or contact the model author.