"""
Entrypoint for loading the CODI model from this repository.

Usage:
    from huggingface_hub import snapshot_download
    local_dir = snapshot_download("YOUR_USERNAME/codi-gpt2-prontoqa-latent")

    import sys
    sys.path.insert(0, local_dir)

    from load_model import load_codi_model
    model = load_codi_model(local_dir, device="cuda")
"""

import os
import sys

import torch
from huggingface_hub import snapshot_download
from peft import LoraConfig

from model import CODI, ModelArguments, TrainingArguments

# Name of the checkpoint file expected inside the downloaded/local directory.
WEIGHTS_FILENAME = "pytorch_model.bin"


def load_codi_model(repo_id_or_path, device="cuda", dtype=torch.float16):
    """
    Load a CODI model from a HuggingFace repo or local directory.

    Args:
        repo_id_or_path: HF repo id (e.g. "user/repo") or local directory path.
            A value that names an existing directory is used as-is; anything
            else is treated as a repo id and fetched with ``snapshot_download``.
        device: Device to load the model on.
        dtype: Data type for the model weights.

    Returns:
        CODI model with loaded weights, in eval mode.

    Raises:
        FileNotFoundError: If the resolved directory does not contain
            ``pytorch_model.bin``.
    """
    # Download if needed
    if os.path.isdir(repo_id_or_path):
        local_dir = repo_id_or_path
    else:
        print(f"Downloading from {repo_id_or_path}...")
        local_dir = snapshot_download(repo_id=repo_id_or_path)

    weights_path = os.path.join(local_dir, WEIGHTS_FILENAME)
    if not os.path.isfile(weights_path):
        # Fail early with the resolved directory in the message rather than
        # surfacing a bare error from deep inside torch.load.
        raise FileNotFoundError(f"No {WEIGHTS_FILENAME} found in {local_dir}")

    # Reconstruct the model with the same args used during training. Values
    # that shape the parameter set (e.g. LoRA rank / target modules) must match
    # the checkpoint, because load_state_dict below runs in strict mode.
    model_args = ModelArguments(
        model_name_or_path="gpt2",
        train=False,
        full_precision=True,
    )
    training_args = TrainingArguments(
        output_dir="./tmp",
        num_latent=5,
        use_lora=True,
        use_prj=False,
        bf16=False,
        fix_attn_mask=False,
        print_loss=False,
        distill_loss_type="smooth_l1",
        distill_loss_factor=1.0,
        ref_loss_factor=1.0,
    )
    lora_config = LoraConfig(
        r=128,
        lora_alpha=16,
        lora_dropout=0.05,
        # GPT-2 Conv1D layers: fused attention QKV (c_attn), attention and MLP
        # output projections (c_proj), and the MLP input projection (c_fc).
        target_modules=["c_attn", "c_proj", "c_fc"],
    )

    # Build the model skeleton, then load trained weights
    model = CODI(model_args, training_args, lora_config)
    print(f"Loading weights from {weights_path}...")
    # The checkpoint may come from a remote repo. weights_only=True restricts
    # unpickling to tensors and plain containers, which is all a state dict
    # needs, instead of allowing arbitrary pickled objects to be executed.
    state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model = model.to(device=device, dtype=dtype)
    model.eval()
    print("Model loaded successfully.")
    return model


if __name__ == "__main__":
    repo = sys.argv[1] if len(sys.argv) > 1 else "."
    # The library default targets CUDA + float16; on a machine without a GPU,
    # fall back to CPU in float32 (float16 op coverage on CPU is limited)
    # instead of crashing.
    if torch.cuda.is_available():
        model = load_codi_model(repo)
    else:
        model = load_codi_model(repo, device="cpu", dtype=torch.float32)
    print(f"Model type: {type(model).__name__}")
    print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")