"""
Model loading script
Load Fairy2i-W2 model from Hugging Face repository.
Usage:
from load_model import load_model
model, tokenizer = load_model()
"""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
import os
import sys
# Ensure sibling modules (qat_modules) are importable when this file is run
# as a standalone script rather than as part of an installed package.
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
if _THIS_DIR not in sys.path:
    sys.path.insert(0, _THIS_DIR)

from qat_modules import replace_modules_for_qat
def _fetch_hub_state_dict(repo_id):
    """Download and assemble a weights state dict from a Hugging Face repo.

    Supports both checkpoint layouts:
      * sharded  -- "model.safetensors.index.json" plus shard files
      * single   -- one "model.safetensors" file

    Only a failure to fetch the shard *index* triggers the single-file
    fallback; an error while downloading or reading an actual shard
    propagates, instead of being masked by an irrelevant retry.

    Args:
        repo_id: Hugging Face repository id holding the weights.

    Returns:
        Tuple (state_dict, layout) where layout is "sharded" or "single file".

    Raises:
        RuntimeError: if the single-file checkpoint cannot be downloaded.
    """
    import json
    from safetensors import safe_open

    try:
        index_path = hf_hub_download(
            repo_id=repo_id,
            filename="model.safetensors.index.json",
            local_dir=None,
        )
    except Exception:
        # No shard index available: assume the single-file layout below.
        index_path = None

    if index_path is not None:
        with open(index_path, "r") as f:
            weight_map = json.load(f)["weight_map"]
        state_dict = {}
        # sorted() makes the download order deterministic across runs.
        for weight_file in sorted(set(weight_map.values())):
            file_path = hf_hub_download(
                repo_id=repo_id,
                filename=weight_file,
                local_dir=None,
            )
            with safe_open(file_path, framework="pt", device="cpu") as shard:
                for key in shard.keys():
                    state_dict[key] = shard.get_tensor(key)
        return state_dict, "sharded"

    try:
        weights_path = hf_hub_download(
            repo_id=repo_id,
            filename="model.safetensors",
            local_dir=None,
        )
        return load_file(weights_path), "single file"
    except Exception as e:
        raise RuntimeError(f"Failed to load weights from Hugging Face: {e}")


def load_model(device=None, torch_dtype=torch.bfloat16):
    """Load the Fairy2i-W2 model: standard architecture + custom weights +
    QAT linear-layer replacement.

    Weights and tokenizer are downloaded from the Hugging Face repository.

    Args:
        device: Target device. None (the default) auto-selects "cuda" when
            available, else "cpu". Resolved at call time — the original
            default argument evaluated torch.cuda.is_available() once at
            import time, freezing the choice for the life of the process.
        torch_dtype: Parameter dtype, default torch.bfloat16.

    Returns:
        Tuple (model, tokenizer).

    Raises:
        RuntimeError: if the custom weights cannot be downloaded.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Configuration parameters
    base_model_id = "meta-llama/Llama-2-7b-hf"
    weights_repo_id = "PKU-DS-LAB/Fairy2i-W2"
    quant_method = "complex_phase_v2"
    skip_lm_head = False

    print("=" * 70)
    print("Loading Fairy2i-W2 Model")
    print("=" * 70)

    # Step 1: Load standard model architecture
    print(f"\n📥 Step 1/4: Loading standard model architecture: {base_model_id}")
    model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        torch_dtype=torch_dtype,
        device_map=device,
        trust_remote_code=False,
    )
    print("✅ Standard model architecture loaded")

    # Step 2: Load custom weights
    print(f"\n💾 Step 2/4: Loading weights from Hugging Face repository: {weights_repo_id}")
    state_dict, layout = _fetch_hub_state_dict(weights_repo_id)
    # NOTE(review): strict=False mirrors the original — presumably the
    # checkpoint also carries tensors for the QAT modules installed in
    # Step 3, so unmatched keys are expected here. Verify against qat_modules.
    model.load_state_dict(state_dict, strict=False)
    if layout == "sharded":
        print("✅ Weights loaded (sharded)")
    else:
        print("✅ Weights loaded (single file)")

    # Step 3: Apply QAT replacement
    print(f"\n🔧 Step 3/4: Applying QAT replacement ({quant_method})...")
    replace_modules_for_qat(model, quant_method, skip_lm_head=skip_lm_head)
    print("✅ QAT replacement completed")

    # Step 4: Load tokenizer
    print(f"\n📝 Step 4/4: Loading Tokenizer from Hugging Face repository: {weights_repo_id}")
    tokenizer = AutoTokenizer.from_pretrained(weights_repo_id)
    print("✅ Tokenizer loaded")

    print("\n" + "=" * 70)
    print("✅ Model loading completed!")
    print("=" * 70)

    return model, tokenizer
if __name__ == "__main__":
    # Example: load the model, then run a short sampling smoke test.
    demo_model, demo_tokenizer = load_model()

    print("\n🧪 Testing generation...")
    demo_prompt = "Hello, how are you?"
    encoded = demo_tokenizer(demo_prompt, return_tensors="pt").to(demo_model.device)

    # Inference only — no gradients needed for generation.
    with torch.no_grad():
        generated = demo_model.generate(
            **encoded,
            max_new_tokens=50,
            do_sample=True,
            temperature=0.7,
        )

    decoded = demo_tokenizer.decode(generated[0], skip_special_tokens=True)
    print(f"Prompt: {demo_prompt}")
    print(f"Response: {decoded}")