# CODI – GPT-2 ProntoQA (Latent Reasoning)

A CODI (Chain-of-thought Distillation) model trained on ProntoQA for latent chain-of-thought reasoning. The model wraps GPT-2 with LoRA adapters and a distillation objective that compresses explicit chain-of-thought steps into latent embeddings.

## Quick Start

Install the dependencies:

```bash
pip install transformers peft huggingface_hub torch
```

```python
import sys

from huggingface_hub import snapshot_download

# Download the model snapshot from the Hugging Face Hub
local_dir = snapshot_download("simon-pltk/codi-gpt2-prontoqa-latent")

# Make the bundled loader importable, then load the model
sys.path.insert(0, local_dir)
from load_model import load_codi_model

model = load_codi_model(local_dir, device="cuda")
```

## Architecture

CODI is a custom torch.nn.Module that:

  1. Wraps a base GPT-2 model loaded via AutoModelForCausalLM
  2. Applies LoRA adapters (rank=128, alpha=16) for parameter-efficient tuning
  3. Generates latent embeddings that replace explicit chain-of-thought tokens
  4. Uses a layer-wise distillation loss (SmoothL1) to align the student's latent representations with those of a teacher that sees the explicit CoT

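The layer-wise objective in step 4 can be sketched in plain PyTorch. This is a minimal illustration, not the released training code: the tensor shapes, the per-layer averaging, and detaching the teacher states are assumptions.

```python
import torch
import torch.nn.functional as F

def layerwise_distill_loss(student_hidden, teacher_hidden):
    """SmoothL1 alignment between per-layer hidden states.

    student_hidden / teacher_hidden: lists of [batch, seq, dim] tensors,
    one per transformer layer. Teacher states are detached so gradients
    flow only into the student (latent) branch.
    """
    assert len(student_hidden) == len(teacher_hidden)
    losses = [
        F.smooth_l1_loss(s, t.detach())
        for s, t in zip(student_hidden, teacher_hidden)
    ]
    return torch.stack(losses).mean()

# Toy check with GPT-2-sized hidden states (12 layers, dim 768)
student = [torch.randn(2, 5, 768, requires_grad=True) for _ in range(12)]
teacher = [torch.randn(2, 5, 768) for _ in range(12)]
loss = layerwise_distill_loss(student, teacher)
loss.backward()  # gradients reach the student states only
```

In the actual model this loss is added to the cross-entropy terms reported under Final Metrics below.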
## Training Details

| Parameter | Value |
|---|---|
| Base model | GPT-2 (124M) |
| Dataset | ProntoQA |
| Epochs | 50 |
| Learning rate | 0.003 |
| Seed | 11 |
| Num latent tokens | 5 |
| LoRA rank | 128 |
| Distill loss | SmoothL1 |
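The rank=128, alpha=16 configuration above corresponds to the standard LoRA formulation: a frozen base weight plus a trainable low-rank update scaled by alpha / r. A self-contained sketch of that formulation (the released model applies the equivalent update through the `peft` library rather than this hand-rolled module):

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Frozen base linear layer plus a trainable rank-r update."""

    def __init__(self, base: nn.Linear, r: int = 128, alpha: int = 16):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False  # only the adapter is trained
        self.A = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.B = nn.Parameter(torch.zeros(base.out_features, r))  # zero-init
        self.scale = alpha / r

    def forward(self, x):
        # Base projection plus the scaled low-rank correction B @ A @ x
        return self.base(x) + self.scale * (x @ self.A.T @ self.B.T)

layer = LoRALinear(nn.Linear(768, 768))
x = torch.randn(2, 768)
# With B zero-initialized, the adapted layer starts out identical to the base
assert torch.allclose(layer(x), layer.base(x))
```

Which GPT-2 sub-modules the adapters target is not stated on this card; `peft` defaults would place them in the attention projections.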

## Final Metrics

| Metric | Start | End |
|---|---|---|
| CE Loss | 6.6610 | 0.1202 |
| Distill Loss | 0.2742 | 0.0759 |
| Ref CE Loss | 1.5432 | 0.0113 |
| Total Loss | 8.4784 | 0.1931 |
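A quick sanity check on the table: at the start of training, the total is the plain sum of the three components. (The final-step components sum to 0.2074 rather than 0.1931, so the reported total is presumably logged at a different step or with per-term weights; the card does not say.)

```python
# Start-of-training components from the table above
ce, distill, ref_ce = 6.6610, 0.2742, 1.5432
total = ce + distill + ref_ce
assert abs(total - 8.4784) < 1e-6  # matches the reported starting total
```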

## Files

| File | Description |
|---|---|
| `pytorch_model.bin` | Full CODI state dict (base model + LoRA + projection) |
| `model.py` | CODI class definition and dataclass configs |
| `load_model.py` | Entrypoint: helper to reconstruct and load the model |
| `codi_config.json` | Model metadata and training hyperparameters |
| `training_args.bin` | Original HuggingFace TrainingArguments |