# Fairy2i-W2 / load_model.py
# Lab1806's picture
# Upload folder using huggingface_hub
# bfa9a3d verified
"""
Model loading script
Load Fairy2i-W2 model from Hugging Face repository.
Usage:
from load_model import load_model
model, tokenizer = load_model()
"""
# Standard library
import os
import sys

# Third-party
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import AutoModelForCausalLM, AutoTokenizer

# Make the sibling qat_modules importable even when this script is invoked
# from a different working directory.
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)

from qat_modules import replace_modules_for_qat
def _fetch_state_dict(repo_id):
    """Download the safetensors weights from *repo_id* and return them.

    Tries the sharded layout first (``model.safetensors.index.json``); when
    the index file is absent, falls back to a single ``model.safetensors``.
    Unlike the previous implementation, a failure while reading an existing
    shard is NOT masked by falling through to the single-file path — only a
    missing index triggers the fallback.

    Args:
        repo_id: Hugging Face repository id holding the weights.

    Returns:
        Tuple ``(state_dict, layout)`` where ``layout`` is ``"sharded"`` or
        ``"single file"`` (used for the progress message).

    Raises:
        RuntimeError: if the single-file weights cannot be downloaded/read.
    """
    try:
        index_path = hf_hub_download(
            repo_id=repo_id,
            filename="model.safetensors.index.json",
            local_dir=None
        )
    except Exception:
        # No shard index in the repo — fall back to the single-file layout.
        index_path = None

    if index_path is not None:
        # Sharded layout: the index maps each tensor name to its shard file.
        from safetensors import safe_open
        import json
        with open(index_path, 'r') as f:
            weight_map = json.load(f)["weight_map"]
        state_dict = {}
        # sorted() makes the download order deterministic.
        for weight_file in sorted(set(weight_map.values())):
            file_path = hf_hub_download(
                repo_id=repo_id,
                filename=weight_file,
                local_dir=None
            )
            with safe_open(file_path, framework="pt", device="cpu") as f:
                for key in f.keys():
                    state_dict[key] = f.get_tensor(key)
        return state_dict, "sharded"

    try:
        weights_path = hf_hub_download(
            repo_id=repo_id,
            filename="model.safetensors",
            local_dir=None
        )
        return load_file(weights_path), "single file"
    except Exception as e:
        raise RuntimeError(f"Failed to load weights from Hugging Face: {e}") from e


def load_model(
    device="cuda" if torch.cuda.is_available() else "cpu",
    torch_dtype=torch.bfloat16
):
    """
    Load Fairy2i-W2 model: standard architecture + custom weights + QAT linear layer replacement.

    Load weights and tokenizer from the Hugging Face repository.

    Args:
        device: Device map for the model; defaults to CUDA when available.
        torch_dtype: Parameter dtype, default torch.bfloat16.

    Returns:
        (model, tokenizer)

    Raises:
        RuntimeError: if the weights cannot be downloaded from the Hub.
    """
    # Configuration parameters
    base_model_id = "meta-llama/Llama-2-7b-hf"
    weights_repo_id = "PKU-DS-LAB/Fairy2i-W2"
    quant_method = "complex_phase_v2"
    skip_lm_head = False

    print("=" * 70)
    print("Loading Fairy2i-W2 Model")
    print("=" * 70)

    # Step 1: Load standard model architecture
    print(f"\n📥 Step 1/4: Loading standard model architecture: {base_model_id}")
    model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        torch_dtype=torch_dtype,
        device_map=device,
        trust_remote_code=False
    )
    print("✅ Standard model architecture loaded")

    # Step 2: Load custom weights
    print(f"\n💾 Step 2/4: Loading weights from Hugging Face repository: {weights_repo_id}")
    state_dict, layout = _fetch_state_dict(weights_repo_id)
    # strict=False because the QAT checkpoint may omit tied/derived tensors;
    # surface the mismatch report instead of silently discarding it.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing:
        print(f"⚠️ Missing keys (kept from base model): {len(missing)}")
    if unexpected:
        print(f"⚠️ Unexpected keys (ignored): {len(unexpected)}")
    del state_dict  # release host memory before the QAT rewrite
    print(f"✅ Weights loaded ({layout})")

    # Step 3: Apply QAT replacement
    print(f"\n🔧 Step 3/4: Applying QAT replacement ({quant_method})...")
    replace_modules_for_qat(model, quant_method, skip_lm_head=skip_lm_head)
    print("✅ QAT replacement completed")

    # Step 4: Load tokenizer
    print(f"\n📝 Step 4/4: Loading Tokenizer from Hugging Face repository: {weights_repo_id}")
    tokenizer = AutoTokenizer.from_pretrained(weights_repo_id)
    print("✅ Tokenizer loaded")

    print("\n" + "=" * 70)
    print("✅ Model loading completed!")
    print("=" * 70)
    return model, tokenizer
if __name__ == "__main__":
    # Demo: load the model, then run one short sampled generation to verify
    # the checkpoint and QAT layers end-to-end.
    model, tokenizer = load_model()

    print("\n🧪 Testing generation...")
    prompt = "Hello, how are you?"
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    gen_kwargs = {
        "max_new_tokens": 50,
        "do_sample": True,
        "temperature": 0.7,
    }
    with torch.no_grad():
        generated = model.generate(**encoded, **gen_kwargs)
    response = tokenizer.decode(generated[0], skip_special_tokens=True)
    print(f"Prompt: {prompt}")
    print(f"Response: {response}")