|
|
""" |
|
|
Model loading script |
|
|
|
|
|
Load Fairy2i-W2 model from Hugging Face repository. |
|
|
|
|
|
Usage: |
|
|
from load_model import load_model |
|
|
model, tokenizer = load_model() |
|
|
""" |
|
|
|
|
|
import torch |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
from safetensors.torch import load_file |
|
|
from huggingface_hub import hf_hub_download |
|
|
import os |
|
|
import sys |
|
|
|
|
|
|
|
|
# Make this file's directory importable so the sibling ``qat_modules`` module
# resolves no matter what the current working directory is at launch time.
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    # Prepend (not append) so the local module wins over any same-named
    # package elsewhere on sys.path.
    sys.path.insert(0, current_dir)

# Local module providing the QAT (quantization-aware training) linear-layer swap.
from qat_modules import replace_modules_for_qat
|
|
|
|
|
|
|
|
def _download_sharded_state_dict(repo_id):
    """Download every shard listed in ``model.safetensors.index.json`` from
    *repo_id* and merge them into a single CPU state dict.

    Raises whatever ``hf_hub_download`` raises when the index file is absent,
    which the caller uses to fall back to the single-file layout.
    """
    import json
    from safetensors import safe_open

    index_path = hf_hub_download(
        repo_id=repo_id,
        filename="model.safetensors.index.json",
        local_dir=None
    )
    with open(index_path, 'r') as f:
        weight_map = json.load(f)["weight_map"]

    state_dict = {}
    # The index maps parameter name -> shard filename; download each unique
    # shard only once.
    for weight_file in set(weight_map.values()):
        file_path = hf_hub_download(
            repo_id=repo_id,
            filename=weight_file,
            local_dir=None
        )
        with safe_open(file_path, framework="pt", device="cpu") as shard:
            for key in shard.keys():
                state_dict[key] = shard.get_tensor(key)
    return state_dict


def _download_single_state_dict(repo_id):
    """Download a single-file ``model.safetensors`` checkpoint from *repo_id*
    and return its state dict."""
    weights_path = hf_hub_download(
        repo_id=repo_id,
        filename="model.safetensors",
        local_dir=None
    )
    return load_file(weights_path)


def _apply_state_dict(model, state_dict):
    """Load *state_dict* into *model* non-strictly and report any mismatches.

    strict=False is deliberate: the checkpoint targets the QAT-replaced
    architecture, so some keys may not line up with the vanilla model.
    Surfacing the counts makes silent mismatches visible.
    """
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing:
        print(f"⚠️  {len(missing)} model keys missing from checkpoint")
    if unexpected:
        print(f"⚠️  {len(unexpected)} unexpected keys in checkpoint")


def load_model(
    device="cuda" if torch.cuda.is_available() else "cpu",
    torch_dtype=torch.bfloat16
):
    """
    Load Fairy2i-W2 model: standard architecture + custom weights + QAT linear layer replacement.

    Weights and tokenizer are fetched from the Hugging Face Hub. A sharded
    checkpoint (``model.safetensors.index.json``) is tried first, falling back
    to a single ``model.safetensors`` file.

    Args:
        device: Device passed as ``device_map``; default auto-selects CUDA
            when available. (Evaluated once at import time.)
        torch_dtype: Parameter data type, default ``torch.bfloat16``.

    Returns:
        model, tokenizer

    Raises:
        RuntimeError: If neither checkpoint layout could be downloaded/loaded.
    """
    base_model_id = "meta-llama/Llama-2-7b-hf"
    weights_repo_id = "PKU-DS-LAB/Fairy2i-W2"
    quant_method = "complex_phase_v2"
    skip_lm_head = False

    print("=" * 70)
    print("Loading Fairy2i-W2 Model")
    print("=" * 70)

    # Step 1: instantiate the unmodified base architecture; the Fairy2i
    # weights are loaded over it afterwards.
    print(f"\n📥 Step 1/4: Loading standard model architecture: {base_model_id}")
    model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        torch_dtype=torch_dtype,
        device_map=device,
        trust_remote_code=False
    )
    print("✅ Standard model architecture loaded")

    # Step 2: fetch the Fairy2i-W2 weights, preferring the sharded layout.
    print(f"\n💾 Step 2/4: Loading weights from Hugging Face repository: {weights_repo_id}")
    try:
        state_dict = _download_sharded_state_dict(weights_repo_id)
        _apply_state_dict(model, state_dict)
        print("✅ Weights loaded (sharded)")
    except Exception as shard_err:
        # No shard index (or a shard failed) — try the single-file layout.
        try:
            state_dict = _download_single_state_dict(weights_repo_id)
            _apply_state_dict(model, state_dict)
            print("✅ Weights loaded (single file)")
        except Exception as e:
            # Chain the cause so both failures are visible in the traceback.
            raise RuntimeError(
                f"Failed to load weights from Hugging Face: {e} "
                f"(sharded attempt failed first: {shard_err})"
            ) from e

    # Step 3: swap linear layers for their QAT counterparts so the quantized
    # weights are interpreted correctly at inference time.
    print(f"\n🔧 Step 3/4: Applying QAT replacement ({quant_method})...")
    replace_modules_for_qat(model, quant_method, skip_lm_head=skip_lm_head)
    print("✅ QAT replacement completed")

    # Step 4: the tokenizer comes from the weights repo, not the base model,
    # so any Fairy2i-specific tokenizer config is honored.
    print(f"\n📝 Step 4/4: Loading Tokenizer from Hugging Face repository: {weights_repo_id}")
    tokenizer = AutoTokenizer.from_pretrained(weights_repo_id)
    print("✅ Tokenizer loaded")

    print("\n" + "=" * 70)
    print("✅ Model loading completed!")
    print("=" * 70)

    return model, tokenizer
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: load the model once, then run a short sampled generation.
    model, tokenizer = load_model()

    print("\n🧪 Testing generation...")
    prompt = "Hello, how are you?"
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Sampling configuration for the quick check.
    sampling_args = dict(max_new_tokens=50, do_sample=True, temperature=0.7)
    with torch.no_grad():
        generated = model.generate(**encoded, **sampling_args)

    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    print(f"Prompt: {prompt}")
    print(f"Response: {decoded}")