"""
Model loader for safetensors-based quantized models
"""

import json
import os

import torch
from safetensors.torch import load_file
from transformers import AutoModel

from .quantization import QuantizedLinear

# Repository providing the architecture that the quantized layers are
# grafted onto, and the two file names expected next to each other in a
# local directory or a Hugging Face repo.
_BASE_MODEL_ID = "deepseek-ai/DeepSeek-OCR"
_WEIGHTS_FILENAME = "model.safetensors"
_METADATA_FILENAME = "quantization_config.json"


def _resolve_model_files(model_path):
    """Return ``(safetensors_path, metadata_path)`` for *model_path*.

    An existing local directory is used directly; anything else is treated
    as a Hugging Face repo id and both files are fetched with
    ``hf_hub_download``.
    """
    if os.path.isdir(model_path):
        return (
            os.path.join(model_path, _WEIGHTS_FILENAME),
            os.path.join(model_path, _METADATA_FILENAME),
        )

    # Imported lazily so huggingface_hub is only needed for repo loading.
    from huggingface_hub import hf_hub_download

    return (
        hf_hub_download(repo_id=model_path, filename=_WEIGHTS_FILENAME),
        hf_hub_download(repo_id=model_path, filename=_METADATA_FILENAME),
    )


def _require_tensor(state_dict, layer_name, suffix):
    """Fetch ``"{layer_name}.{suffix}"`` from *state_dict* or raise a clear KeyError."""
    key = f"{layer_name}.{suffix}"
    try:
        return state_dict[key]
    except KeyError:
        raise KeyError(
            f"Tensor '{key}' required by quantized layer '{layer_name}' "
            f"is missing from {_WEIGHTS_FILENAME}"
        ) from None


def _replace_submodule(root, dotted_name, new_module):
    """Replace the attribute addressed by *dotted_name* under *root* with *new_module*."""
    *parent_path, leaf = dotted_name.split('.')
    parent = root
    for part in parent_path:
        parent = getattr(parent, part)
    setattr(parent, leaf, new_module)


def load_quantized_model(model_path, device="cuda", trust_remote_code=True):
    """
    Load quantized model from safetensors + metadata

    The base architecture is instantiated from ``deepseek-ai/DeepSeek-OCR``
    in bfloat16, then every layer listed under ``quantized_layers`` in
    ``quantization_config.json`` is replaced by a ``QuantizedLinear`` built
    from the tensors stored in ``model.safetensors``.

    Args:
        model_path: Path to model directory or HF repo
        device: Device to load on (passed to ``model.to``)
        trust_remote_code: Forwarded to ``AutoModel.from_pretrained``.
            Required for custom code.

    Returns:
        Loaded quantized model, moved to ``device``

    Raises:
        KeyError: If the metadata has no ``quantized_layers`` entry, a layer
            entry lacks ``in_features``/``out_features``/``bits``, or a
            required tensor (``weight_quantized``, ``scale``,
            ``zero_point``) is missing from the safetensors file.
        AttributeError: If a layer name from the metadata does not exist in
            the base model.
    """
    # Load base model architecture
    print("Loading base model architecture...")
    base_model = AutoModel.from_pretrained(
        _BASE_MODEL_ID,
        # Honour the caller's choice instead of hard-coding True.
        trust_remote_code=trust_remote_code,
        torch_dtype=torch.bfloat16
    )

    # Load safetensors
    print("Loading quantized weights from safetensors...")
    safetensors_path, metadata_path = _resolve_model_files(model_path)
    state_dict = load_file(safetensors_path)

    # Load quantization metadata
    with open(metadata_path, 'r', encoding='utf-8') as f:
        quant_metadata = json.load(f)

    quantized_layers = quant_metadata['quantized_layers']

    # Reconstruct QuantizedLinear layers
    print(f"Reconstructing {len(quantized_layers)} quantized layers...")

    for layer_name, layer_info in quantized_layers.items():
        quantized_linear = QuantizedLinear(
            in_features=layer_info['in_features'],
            out_features=layer_info['out_features'],
            bits=layer_info['bits'],
            weight_data=_require_tensor(state_dict, layer_name, "weight_quantized"),
            scale=_require_tensor(state_dict, layer_name, "scale"),
            zero_point=_require_tensor(state_dict, layer_name, "zero_point"),
            # Bias is optional; layers saved without one get None.
            bias=state_dict.get(f"{layer_name}.bias")
        )

        # Replace in model
        _replace_submodule(base_model, layer_name, quantized_linear)

    print("✅ Model loaded successfully!")
    # The stats are informational only, so a metadata file without them
    # must not abort an otherwise successful load.
    stats = quant_metadata.get('stats') or {}
    if 'compression_ratio' in stats:
        print(f" Compression: {stats['compression_ratio']:.2f}x")
    if 'compressed_size_mb' in stats:
        print(f" Size: {stats['compressed_size_mb']:.0f} MB")

    return base_model.to(device)


# Alias for HF compatibility
def from_pretrained(model_path, **kwargs):
    """HuggingFace-style loading

    Recognised keyword arguments are ``device`` (default ``'cuda'``) and
    ``trust_remote_code`` (default ``True``); both are forwarded to
    ``load_quantized_model``. Any other keyword arguments are ignored.
    """
    device = kwargs.get('device', 'cuda')
    trust_remote_code = kwargs.get('trust_remote_code', True)
    return load_quantized_model(
        model_path,
        device=device,
        trust_remote_code=trust_remote_code,
    )