File size: 2,550 Bytes
17bde88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
"""
Entrypoint for loading the CODI model from this repository.

Usage:
    from huggingface_hub import snapshot_download
    local_dir = snapshot_download("YOUR_USERNAME/codi-gpt2-prontoqa-latent")

    import sys
    sys.path.insert(0, local_dir)
    from load_model import load_codi_model

    model = load_codi_model(local_dir, device="cuda")
"""
import os
import torch
from huggingface_hub import snapshot_download
from model import CODI, ModelArguments, TrainingArguments
from peft import LoraConfig


def load_codi_model(repo_id_or_path, device="cuda", dtype=torch.float16):
    """
    Load a CODI model from a HuggingFace repo or local directory.

    Args:
        repo_id_or_path: HF repo id (e.g. "user/repo") or local directory path.
        device: Device to load the model on (e.g. "cuda", "cpu").
        dtype: Data type to cast the model weights to after loading.

    Returns:
        CODI model with loaded weights, in eval mode.

    Raises:
        FileNotFoundError: If the weights file is missing from the snapshot.
    """
    # Resolve to a local directory, downloading the snapshot if needed.
    if os.path.isdir(repo_id_or_path):
        local_dir = repo_id_or_path
    else:
        print(f"Downloading from {repo_id_or_path}...")
        local_dir = snapshot_download(repo_id=repo_id_or_path)

    weights_path = os.path.join(local_dir, "pytorch_model.bin")
    # Fail early with a clear message instead of an opaque torch.load error.
    if not os.path.isfile(weights_path):
        raise FileNotFoundError(f"Weights file not found: {weights_path}")

    # Reconstruct the model with the same args used during training.
    # NOTE(review): these values must match the training run exactly —
    # confirm against the training config when loading other checkpoints.
    model_args = ModelArguments(
        model_name_or_path="gpt2",
        train=False,
        full_precision=True,
    )

    training_args = TrainingArguments(
        output_dir="./tmp",
        num_latent=5,
        use_lora=True,
        use_prj=False,
        bf16=False,
        fix_attn_mask=False,
        print_loss=False,
        distill_loss_type="smooth_l1",
        distill_loss_factor=1.0,
        ref_loss_factor=1.0,
    )

    lora_config = LoraConfig(
        r=128,
        lora_alpha=16,
        lora_dropout=0.05,
        target_modules=["c_attn", "c_proj", "c_fc"],  # GPT-2 attention modules
    )

    # Build the model skeleton, then load trained weights
    model = CODI(model_args, training_args, lora_config)

    print(f"Loading weights from {weights_path}...")
    # weights_only=True blocks arbitrary code execution via pickle when the
    # checkpoint came from a remote repo (supported since torch 1.13; it is
    # the default only from torch 2.6). Plain state dicts load fine with it.
    state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)

    model = model.to(device=device, dtype=dtype)
    model.eval()
    print("Model loaded successfully.")
    return model


if __name__ == "__main__":
    import sys

    # Default to the current directory when no repo/path argument is given.
    target = sys.argv[1] if len(sys.argv) > 1 else "."
    codi = load_codi_model(target)
    total_params = sum(p.numel() for p in codi.parameters())
    print(f"Model type: {type(codi).__name__}")
    print(f"Parameters: {total_params:,}")