---
license: mit
tags:
- text-generation
- bidirectional
- diffusion
- speculative-decoding
- adapt-diff
---
# ADAPT-DIFF: Qwen 3.5 0.8B Bidirectional Latent Diffusion Model
This is the public release of **ADAPT-DIFF**, built on a bidirectional `Qwen/Qwen3.5-0.8B` backbone. It implements a two-stage hybrid generation framework:
1. **Parallel Latent Diffusion Initialization**: Generates a block of tokens in parallel via custom LDM heads.
2. **Logit Uncertainty Refinement**: Uses a dynamic entropy router and an Actor-Critic tree search to refine uncertain tokens in bfloat16 precision.
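The two stages above can be illustrated with a minimal, framework-free sketch. It is not the ADAPT-DIFF implementation. It assumes each position in the block already has a vector of vocabulary logits (as the LDM heads would produce), drafts every position at once by taking its argmax, and then flags positions whose softmax entropy exceeds a fixed threshold as candidates for refinement. The names `draft_and_route` and `threshold` are hypothetical, the model's dynamic router may choose its cutoff differently, and the Actor-Critic tree search is not shown.

```python
import math
from typing import List, Tuple


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over one position's vocabulary logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs: List[float]) -> float:
    """Shannon entropy in nats; zero-probability entries contribute nothing."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def draft_and_route(
    block_logits: List[List[float]], threshold: float
) -> Tuple[List[int], List[int]]:
    """Hypothetical two-stage helper (not the model's actual router).

    Stage 1: choose every position's argmax token independently, i.e. a
    parallel draft of the whole block (ties resolve to the lowest token id).
    Stage 2: flag positions whose predictive entropy exceeds `threshold`
    so a slower refinement step can revisit only those tokens.
    """
    draft: List[int] = []
    uncertain: List[int] = []
    for pos, logits in enumerate(block_logits):
        probs = softmax(logits)
        draft.append(max(range(len(probs)), key=probs.__getitem__))
        if entropy(probs) > threshold:
            uncertain.append(pos)
    return draft, uncertain


if __name__ == "__main__":
    # Toy 3-position block over a 3-token vocabulary.
    block = [
        [5.0, 0.0, 0.0],  # confident, low entropy
        [1.0, 1.0, 1.0],  # uniform, entropy = ln 3
        [0.0, 4.0, 0.0],  # confident, low entropy
    ]
    draft, uncertain = draft_and_route(block, threshold=0.5)
    print(draft, uncertain)  # [0, 0, 1] [1]
```

Only the middle, maximally uncertain position is routed onward; the confident positions keep their parallel draft, which is what lets the expensive refinement step run on a small subset of the block.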
### How to Load the Weights
Because this model uses custom architectures, you must define the `A2DQwenLMHeadModel` and `StackedLDMHeads` classes in your script before loading the weights as follows:
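```python
import torch
import transformers
from huggingface_hub import hf_hub_download

REPO_ID = "dataopsnick/adapt-diff-qwen-0.8b"

# 1. Initialize and load the bidirectional base LLM in bfloat16
base_model = transformers.AutoModel.from_pretrained(REPO_ID, torch_dtype=torch.bfloat16)

# 2. Instantiate the custom LDM heads. StackedLDMHeads must already be defined in
#    this script; pass whatever constructor arguments your definition requires.
ldm_heads = StackedLDMHeads()

# 3. Download the LDM projection head weights and load them into the heads
ldm_weights_path = hf_hub_download(repo_id=REPO_ID, filename="ldm_heads.pt")
ldm_heads.load_state_dict(torch.load(ldm_weights_path, map_location="cpu", weights_only=True))
```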
### Full Inference Benchmarks & SFT Calibration
To run the full benchmark comparison against the autoregressive baseline, or to perform Supervised Fine-Tuning (SFT) calibration on your own system, clone this repository and run the dedicated scripts it includes:
#### 1. Run Comparative Benchmarking (GSM8K & MBPP)
```bash
python infer.py
```
#### 2. Run Head Alignment & SFT Training
```bash
python train.py
```