# CODI – GPT-2 ProntoQA (Latent Reasoning)
A CODI (Chain-of-thought Distillation) model trained on ProntoQA for latent chain-of-thought reasoning. The model wraps GPT-2 with LoRA adapters and a distillation objective that compresses explicit chain-of-thought steps into latent embeddings.
## Quick Start

```bash
pip install transformers peft huggingface_hub torch
```

```python
from huggingface_hub import snapshot_download
import sys

# Download the model snapshot
local_dir = snapshot_download("simon-pltk/codi-gpt2-prontoqa-latent")

# Make the bundled helper importable, then load the model
sys.path.insert(0, local_dir)
from load_model import load_codi_model

model = load_codi_model(local_dir, device="cuda")
```
## Architecture

CODI is a custom `torch.nn.Module` that:

- Wraps a base GPT-2 model loaded via `AutoModelForCausalLM`
- Applies LoRA adapters (rank=128, alpha=16) for parameter-efficient tuning
- Generates latent embeddings that replace explicit chain-of-thought tokens
- Uses a layer-wise distillation loss (SmoothL1) to align the student (latent) representations with the teacher (explicit CoT) representations at every layer
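The layer-wise distillation objective above can be sketched in a few lines. This is a minimal illustration, not the card's actual implementation (which lives in `model.py`); the function name and the choice to detach the teacher are assumptions.

```python
import torch
import torch.nn.functional as F

def layerwise_distill_loss(student_hidden, teacher_hidden):
    """SmoothL1 alignment between student (latent) and teacher (explicit
    CoT) hidden states, averaged over all layers. Hypothetical sketch of
    the CODI distillation objective; see model.py for the real version."""
    losses = [
        F.smooth_l1_loss(s, t.detach())  # teacher gradients are blocked
        for s, t in zip(student_hidden, teacher_hidden)
    ]
    return torch.stack(losses).mean()

# Toy check: three "layers" of (batch=2, hidden=8) states
student = [torch.randn(2, 8) for _ in range(3)]
teacher = [torch.randn(2, 8) for _ in range(3)]
loss = layerwise_distill_loss(student, teacher)
```

Detaching the teacher states keeps the gradient flowing only into the student (latent) branch, which is the usual choice in self-distillation setups.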
## Training Details
| Parameter | Value |
|---|---|
| Base model | GPT-2 (124M) |
| Dataset | ProntoQA |
| Epochs | 50 |
| Learning rate | 0.003 |
| Seed | 11 |
| Num latent tokens | 5 |
| LoRA rank | 128 |
| Distill loss | SmoothL1 |
## Final Metrics
| Metric | Start | End |
|---|---|---|
| CE Loss | 6.6610 | 0.1202 |
| Distill Loss | 0.2742 | 0.0759 |
| Ref CE Loss | 1.5432 | 0.0113 |
| Total Loss | 8.4784 | 0.1931 |
## Files

| File | Description |
|---|---|
| `pytorch_model.bin` | Full CODI state dict (base model + LoRA + projection) |
| `model.py` | `CODI` class definition and dataclass configs |
| `load_model.py` | Entrypoint: helper to reconstruct and load the model |
| `codi_config.json` | Model metadata and training hyperparameters |
| `training_args.bin` | Original HuggingFace `TrainingArguments` |
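`codi_config.json` is plain JSON, so it can be inspected without loading the model. The key names below mirror the training table and are assumptions about the schema, not the file's guaranteed contents:

```python
import json

# Hypothetical sketch of what codi_config.json might contain; key names
# mirror the training table above and are assumptions, not the actual
# schema shipped with the snapshot.
example = {
    "base_model": "gpt2",
    "dataset": "prontoqa",
    "num_latent_tokens": 5,
    "lora_rank": 128,
    "lora_alpha": 16,
    "learning_rate": 3e-3,
    "epochs": 50,
    "seed": 11,
    "distill_loss": "smooth_l1",
}

# Round-trip through JSON, as a loader would when reconstructing
# the model from the snapshot directory.
cfg = json.loads(json.dumps(example))
```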