File size: 4,419 Bytes
bfa9a3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
"""
Model loading script

Load Fairy2i-W2 model from Hugging Face repository.

Usage:
    from load_model import load_model
    model, tokenizer = load_model()
"""

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
import os
import sys

# Add current directory to path for importing qat_modules
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)

from qat_modules import replace_modules_for_qat


def _fetch_state_dict(repo_id):
    """
    Download a model checkpoint from the Hugging Face Hub, supporting both
    sharded and single-file safetensors layouts.

    A sharded checkpoint is detected by the presence of
    ``model.safetensors.index.json``; only a *missing* index file triggers
    the single-file fallback, so real failures while reading shards
    (corrupt file, out-of-memory, malformed index) propagate instead of
    being masked by a pointless re-download.

    Args:
        repo_id: Hugging Face repository id holding the weights.

    Returns:
        (state_dict, layout) where ``layout`` is ``"sharded"`` or
        ``"single file"`` (used only for logging by the caller).

    Raises:
        RuntimeError: if no checkpoint file can be fetched from the Hub.
    """
    try:
        index_path = hf_hub_download(
            repo_id=repo_id,
            filename="model.safetensors.index.json",
            local_dir=None
        )
    except Exception:
        # No index file on the Hub: assume a single-file checkpoint.
        index_path = None

    if index_path is not None:
        # Sharded checkpoint: pull every shard named in the weight map.
        from safetensors import safe_open
        import json

        with open(index_path, 'r') as f:
            weight_map = json.load(f)["weight_map"]

        state_dict = {}
        # set(): several tensors map to the same shard; download each once.
        for weight_file in set(weight_map.values()):
            file_path = hf_hub_download(
                repo_id=repo_id,
                filename=weight_file,
                local_dir=None
            )
            # Load on CPU; tensors are moved when applied to the model.
            with safe_open(file_path, framework="pt", device="cpu") as f:
                for key in f.keys():
                    state_dict[key] = f.get_tensor(key)
        return state_dict, "sharded"

    try:
        weights_path = hf_hub_download(
            repo_id=repo_id,
            filename="model.safetensors",
            local_dir=None
        )
    except Exception as e:
        # Chain the cause so the underlying Hub error stays visible.
        raise RuntimeError(f"Failed to load weights from Hugging Face: {e}") from e
    return load_file(weights_path), "single file"


def load_model(
    device=None,
    torch_dtype=torch.bfloat16
):
    """
    Load Fairy2i-W2 model: standard architecture + custom weights + QAT linear layer replacement
    
    Load weights and tokenizer from Hugging Face repository.
    
    Args:
        device: Target device string (e.g. "cuda", "cpu"). Defaults to
            None, which selects "cuda" when available, else "cpu" — the
            choice is made at call time, not at import time.
        torch_dtype: Data type, default torch.bfloat16
    
    Returns:
        model, tokenizer

    Raises:
        RuntimeError: if the weights cannot be fetched from the Hub.
    """
    # Resolve the device lazily so CUDA availability is checked when the
    # function runs, not once when the module is imported.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Configuration parameters
    base_model_id = "meta-llama/Llama-2-7b-hf"
    weights_repo_id = "PKU-DS-LAB/Fairy2i-W2"
    quant_method = "complex_phase_v2"
    skip_lm_head = False
    print("=" * 70)
    print("Loading Fairy2i-W2 Model")
    print("=" * 70)
    
    # Step 1: Load standard model architecture (weights get overwritten
    # in step 2).
    print(f"\n📥 Step 1/4: Loading standard model architecture: {base_model_id}")
    model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        torch_dtype=torch_dtype,
        device_map=device,
        trust_remote_code=False
    )
    print("✅ Standard model architecture loaded")
    
    # Step 2: Load custom weights
    print(f"\n💾 Step 2/4: Loading weights from Hugging Face repository: {weights_repo_id}")
    state_dict, layout = _fetch_state_dict(weights_repo_id)
    # strict=False: the QAT checkpoint may omit or add buffers relative to
    # the vanilla architecture.
    model.load_state_dict(state_dict, strict=False)
    print(f"✅ Weights loaded ({layout})")
    
    # Step 3: Apply QAT replacement (swaps Linear layers for quantized ones)
    print(f"\n🔧 Step 3/4: Applying QAT replacement ({quant_method})...")
    replace_modules_for_qat(model, quant_method, skip_lm_head=skip_lm_head)
    print("✅ QAT replacement completed")
    
    # Step 4: Load tokenizer from the weights repo (it may carry custom
    # tokenizer files rather than the base model's).
    print(f"\n📝 Step 4/4: Loading Tokenizer from Hugging Face repository: {weights_repo_id}")
    tokenizer = AutoTokenizer.from_pretrained(weights_repo_id)
    print("✅ Tokenizer loaded")
    
    print("\n" + "=" * 70)
    print("✅ Model loading completed!")
    print("=" * 70)
    
    return model, tokenizer


if __name__ == "__main__":
    # Demo entry point: load the model, then run a short sampling smoke test.
    model, tokenizer = load_model()

    print("\n🧪 Testing generation...")
    prompt = "Hello, how are you?"
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Sampling with temperature, so the continuation varies run to run.
    with torch.no_grad():
        generated = model.generate(
            **encoded,
            max_new_tokens=50,
            do_sample=True,
            temperature=0.7,
        )

    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    print(f"Prompt: {prompt}")
    print(f"Response: {decoded}")