"""
Entrypoint for loading the CODI model from this repository.

Usage:
    from huggingface_hub import snapshot_download
    local_dir = snapshot_download("YOUR_USERNAME/codi-gpt2-prontoqa-latent")

    import sys
    sys.path.insert(0, local_dir)
    from load_model import load_codi_model

    model = load_codi_model(local_dir, device="cuda")
"""
import os

import torch
from huggingface_hub import snapshot_download
from peft import LoraConfig

from model import CODI, ModelArguments, TrainingArguments
|
|
|
|
def load_codi_model(repo_id_or_path, device="cuda", dtype=torch.float16):
    """
    Load a CODI model from a HuggingFace Hub repo or a local directory.

    Args:
        repo_id_or_path: HF repo id (e.g. "user/repo") or local directory path.
        device: Device to move the model to (e.g. "cuda", "cpu").
        dtype: Data type to cast the model weights to.

    Returns:
        CODI model with loaded weights, in eval mode.

    Raises:
        FileNotFoundError: If "pytorch_model.bin" is missing from the
            resolved snapshot directory.
    """
    # Resolve a local snapshot directory: use the path directly if it is
    # already a directory on disk, otherwise treat it as a Hub repo id.
    if os.path.isdir(repo_id_or_path):
        local_dir = repo_id_or_path
    else:
        print(f"Downloading from {repo_id_or_path}...")
        local_dir = snapshot_download(repo_id=repo_id_or_path)

    weights_path = os.path.join(local_dir, "pytorch_model.bin")
    # Fail early with a clear message instead of letting torch.load raise
    # an opaque error below.
    if not os.path.isfile(weights_path):
        raise FileNotFoundError(f"Expected weights file not found: {weights_path}")

    # These hyperparameters must match the configuration the checkpoint was
    # trained with — presumably GPT-2 with 5 latent tokens and LoRA on the
    # attention/MLP projections; load_state_dict will fail on a mismatch.
    model_args = ModelArguments(
        model_name_or_path="gpt2",
        train=False,
        full_precision=True,
    )

    training_args = TrainingArguments(
        output_dir="./tmp",
        num_latent=5,
        use_lora=True,
        use_prj=False,
        bf16=False,
        fix_attn_mask=False,
        print_loss=False,
        distill_loss_type="smooth_l1",
        distill_loss_factor=1.0,
        ref_loss_factor=1.0,
    )

    lora_config = LoraConfig(
        r=128,
        lora_alpha=16,
        lora_dropout=0.05,
        target_modules=["c_attn", "c_proj", "c_fc"],
    )

    # Build the (randomly initialized) architecture, then overwrite it with
    # the checkpoint weights.
    model = CODI(model_args, training_args, lora_config)

    print(f"Loading weights from {weights_path}...")
    # NOTE(security): torch.load deserializes via pickle — only load
    # checkpoints from sources you trust.
    state_dict = torch.load(weights_path, map_location="cpu")
    model.load_state_dict(state_dict)

    model = model.to(device=device, dtype=dtype)
    model.eval()
    print("Model loaded successfully.")
    return model
|
|
|
|
if __name__ == "__main__":
    import sys

    # Default to the current directory when no repo id/path is given.
    target = sys.argv[1] if len(sys.argv) > 1 else "."
    model = load_codi_model(target)
    print(f"Model type: {type(model).__name__}")
    print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
|
|