""" Model loading script Load Fairy2i-W2 model from Hugging Face repository. Usage: from load_model import load_model model, tokenizer = load_model() """ import torch from transformers import AutoModelForCausalLM, AutoTokenizer from safetensors.torch import load_file from huggingface_hub import hf_hub_download import os import sys # Add current directory to path for importing qat_modules current_dir = os.path.dirname(os.path.abspath(__file__)) if current_dir not in sys.path: sys.path.insert(0, current_dir) from qat_modules import replace_modules_for_qat def load_model( device="cuda" if torch.cuda.is_available() else "cpu", torch_dtype=torch.bfloat16 ): """ Load Fairy2i-W2 model: standard architecture + custom weights + QAT linear layer replacement Load weights and tokenizer from Hugging Face repository. Args: device: Device, default auto-select torch_dtype: Data type, default torch.bfloat16 Returns: model, tokenizer """ # Configuration parameters base_model_id = "meta-llama/Llama-2-7b-hf" weights_repo_id = "PKU-DS-LAB/Fairy2i-W2" quant_method = "complex_phase_v2" skip_lm_head = False print("=" * 70) print("Loading Fairy2i-W2 Model") print("=" * 70) # Step 1: Load standard model architecture print(f"\n๐Ÿ“ฅ Step 1/4: Loading standard model architecture: {base_model_id}") model = AutoModelForCausalLM.from_pretrained( base_model_id, torch_dtype=torch_dtype, device_map=device, trust_remote_code=False ) print("โœ… Standard model architecture loaded") # Step 2: Load custom weights print(f"\n๐Ÿ’พ Step 2/4: Loading weights from Hugging Face repository: {weights_repo_id}") # Check for sharded weights try: index_path = hf_hub_download( repo_id=weights_repo_id, filename="model.safetensors.index.json", local_dir=None ) # Sharded weights from safetensors import safe_open import json with open(index_path, 'r') as f: weight_map = json.load(f)["weight_map"] state_dict = {} for weight_file in set(weight_map.values()): file_path = hf_hub_download( repo_id=weights_repo_id, filename=weight_file, local_dir=None ) with safe_open(file_path, framework="pt", device="cpu") as f: for key in f.keys(): state_dict[key] = f.get_tensor(key) model.load_state_dict(state_dict, strict=False) print(f"โœ… Weights loaded (sharded)") except Exception: # Single weight file try: weights_path = hf_hub_download( repo_id=weights_repo_id, filename="model.safetensors", local_dir=None ) state_dict = load_file(weights_path) model.load_state_dict(state_dict, strict=False) print(f"โœ… Weights loaded (single file)") except Exception as e: raise RuntimeError(f"Failed to load weights from Hugging Face: {e}") # Step 3: Apply QAT replacement print(f"\n๐Ÿ”ง Step 3/4: Applying QAT replacement ({quant_method})...") replace_modules_for_qat(model, quant_method, skip_lm_head=skip_lm_head) print("โœ… QAT replacement completed") # Step 4: Load tokenizer print(f"\n๐Ÿ“ Step 4/4: Loading Tokenizer from Hugging Face repository: {weights_repo_id}") tokenizer = AutoTokenizer.from_pretrained(weights_repo_id) print("โœ… Tokenizer loaded") print("\n" + "=" * 70) print("โœ… Model loading completed!") print("=" * 70) return model, tokenizer if __name__ == "__main__": # Example: Load model model, tokenizer = load_model() # Test generation print("\n๐Ÿงช Testing generation...") prompt = "Hello, how are you?" inputs = tokenizer(prompt, return_tensors="pt").to(model.device) with torch.no_grad(): outputs = model.generate( **inputs, max_new_tokens=50, do_sample=True, temperature=0.7 ) response = tokenizer.decode(outputs[0], skip_special_tokens=True) print(f"Prompt: {prompt}") print(f"Response: {response}")