codi-gpt2-prontoqa-latent / load_model.py
simon-pltk's picture
Upload folder using huggingface_hub
17bde88 verified
"""
Entrypoint for loading the CODI model from this repository.

Usage:
    from huggingface_hub import snapshot_download
    local_dir = snapshot_download("YOUR_USERNAME/codi-gpt2-prontoqa-latent")

    import sys
    sys.path.insert(0, local_dir)  # makes load_model.py and model.py importable

    from load_model import load_codi_model
    model = load_codi_model(local_dir, device="cuda")
"""
import os
import torch
from huggingface_hub import snapshot_download
from model import CODI, ModelArguments, TrainingArguments
from peft import LoraConfig
def load_codi_model(repo_id_or_path, device="cuda", dtype=torch.float16):
    """
    Load a CODI model from a HuggingFace repo or local directory.

    Args:
        repo_id_or_path: HF repo id (e.g. "user/repo") or local directory path.
            An existing local directory is used as-is; any other value is
            treated as a repo id and fetched with ``snapshot_download``.
        device: Device to load the model on.
        dtype: Data type for the model weights.

    Returns:
        CODI model with loaded weights, in eval mode.

    Raises:
        FileNotFoundError: If ``pytorch_model.bin`` is not present in the
            resolved directory.
    """
    # Download if needed
    if os.path.isdir(repo_id_or_path):
        local_dir = repo_id_or_path
    else:
        print(f"Downloading from {repo_id_or_path}...")
        local_dir = snapshot_download(repo_id=repo_id_or_path)

    weights_path = os.path.join(local_dir, "pytorch_model.bin")
    # Fail fast: without this check a missing checkpoint is only discovered by
    # torch.load, i.e. after the whole model skeleton has already been built.
    if not os.path.isfile(weights_path):
        raise FileNotFoundError(
            f"No weights file found at {weights_path!r}; expected "
            f"'pytorch_model.bin' inside {local_dir!r}."
        )

    # Reconstruct the model with the same args used during training
    model_args = ModelArguments(
        model_name_or_path="gpt2",
        train=False,
        full_precision=True,
    )
    training_args = TrainingArguments(
        output_dir="./tmp",
        num_latent=5,
        use_lora=True,
        use_prj=False,
        bf16=False,
        fix_attn_mask=False,
        print_loss=False,
        distill_loss_type="smooth_l1",
        distill_loss_factor=1.0,
        ref_loss_factor=1.0,
    )
    lora_config = LoraConfig(
        r=128,
        lora_alpha=16,
        lora_dropout=0.05,
        # GPT-2 projection layers: c_attn/c_proj sit in the attention block,
        # c_fc/c_proj in the MLP block.
        target_modules=["c_attn", "c_proj", "c_fc"],
    )

    # Build the model skeleton, then load trained weights
    model = CODI(model_args, training_args, lora_config)
    print(f"Loading weights from {weights_path}...")
    # NOTE(security): torch.load deserialises with pickle. Depending on the
    # installed torch version (the weights_only default changed over time) a
    # malicious checkpoint can execute code, so only load repos you trust.
    # The state dict is staged on CPU and moved to the target device below.
    state_dict = torch.load(weights_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model = model.to(device=device, dtype=dtype)
    model.eval()
    print("Model loaded successfully.")
    return model
if __name__ == "__main__":
    import sys

    # Smoke test: load the checkpoint named by the first command-line argument
    # (falling back to the current directory) and report the model's class
    # name and total parameter count.
    if len(sys.argv) > 1:
        repo = sys.argv[1]
    else:
        repo = "."
    model = load_codi_model(repo)

    print(f"Model type: {type(model).__name__}")
    total_params = sum(param.numel() for param in model.parameters())
    print(f"Parameters: {total_params:,}")