Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
| """ | |
| Causal LM 模型加载:设备策略与加载逻辑统一封装 | |
| 供 language_checker.QwenLM(信息密度分析)与 model_manager.ensure_model_loaded 共用, | |
| 消除重复的设备分支、量化配置、加载后处理等逻辑。 | |
| 加载策略说明: | |
| - INT8 量化:bitsandbytes 8bit,device_map="cpu"/"auto",减少约 4 倍内存 | |
| - CPU 手动模式:无 device_map,.to(device),默认 float32 | |
| - GPU/MPS 自动模式:device_map="auto",float16 | |
| dtype/设备与因果 LM 在「仅前缀 forward」vs「整段 forward」同位置 logits 的逐元素对比: | |
| float32(CPU)常完全一致;float16(MPS/CUDA)可能因实现路径出现约 1e-2 量级差,非掩码错误。 | |
| 复现与说明见 scripts/reproduce_logits_triple_path.py、scripts/prove_fp16_gemm_shape_sensitivity.py。 | |
| """ | |
| import os | |
| import time | |
| from typing import Any, Dict, Optional | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from transformers.utils import is_flash_attn_2_available | |
| from .device import DeviceManager | |
| from .load_utils import resolve_and_load | |
| from backend.platform.quantization_config import get_quantization_config | |
| def get_device_load_strategy(device: torch.device) -> Dict[str, Any]: | |
| """ | |
| 根据设备推断加载策略(device_map、dtype、use_int8 等)。 | |
| 打印设备模式说明,与 QwenLM 风格一致。 | |
| 环境变量:FORCE_INT8=1 / CPU_FORCE_BFLOAT16=1 | |
| 返回供 load_causal_lm 使用的参数字典。 | |
| """ | |
| qconfig = get_quantization_config(device) | |
| use_int8 = qconfig.use_int8 | |
| device_map = None | |
| dtype = qconfig.dtype | |
| use_low_cpu_mem = False | |
| if device.type == "cpu": | |
| print("🔧 CPU 模式:手动控制设备分配") | |
| if use_int8: | |
| device_map = "cpu" | |
| print("⚠️ 启用 INT8 量化(FORCE_INT8=1,实验性,在某些情况下会降低性能)") | |
| elif dtype == torch.bfloat16: | |
| use_low_cpu_mem = True | |
| print("⚠️ 启用 bfloat16(CPU_FORCE_BFLOAT16=1,需硬件支持 AVX-512_BF16 或 AMX,否则可能极慢)") | |
| else: | |
| use_low_cpu_mem = True | |
| print("🔧 dtype: float32") # 默认: float32 | |
| elif device.type == "cuda": | |
| print("🔧 CUDA 模式:自动设备分配") | |
| device_map = "auto" | |
| use_low_cpu_mem = True | |
| if use_int8: | |
| print("⚠️ 启用 INT8 量化(FORCE_INT8=1)") | |
| else: | |
| print("🔧 dtype: float16") | |
| print("🔧 device_map: auto") | |
| else: | |
| # MPS 模式:自动设备分配 + float16(MPS 不支持 INT8 量化) | |
| print(f"🔧 {device.type.upper()} 模式:自动设备分配") | |
| if os.environ.get("FORCE_INT8") == "1": | |
| print("⚠️ MPS 不支持 INT8 量化,已忽略 FORCE_INT8=1 环境变量") | |
| device_map = "auto" | |
| use_low_cpu_mem = True | |
| print("🔧 dtype: float16") | |
| print("🔧 device_map: auto") | |
| return { | |
| "device_map": device_map, | |
| "dtype": dtype, | |
| "use_low_cpu_mem": use_low_cpu_mem, | |
| "use_int8": use_int8, | |
| } | |
| def attn_implementation_for_device(device: torch.device) -> str: | |
| """ | |
| 非 CUDA:eager,兼容性最好(CPU / MPS 等)。 | |
| CUDA:已安装 flash-attn 时用 flash_attention_2;否则 eager(不使用 sdpa)。 | |
| """ | |
| if device.type != "cuda": | |
| return "eager" | |
| if is_flash_attn_2_available(): | |
| return "flash_attention_2" | |
| return "eager" | |
| def load_causal_lm( | |
| model_path: str, | |
| device: torch.device, | |
| *, | |
| attn_implementation: Optional[str] = None, | |
| extra_model_kwargs: Optional[Dict[str, Any]] = None, | |
| ) -> torch.nn.Module: | |
| """ | |
| 加载 Causal LM 模型,统一处理设备策略、量化、加载后处理。 | |
| Args: | |
| model_path: HuggingFace 模型路径或本地路径 | |
| device: 目标设备 | |
| attn_implementation: 可选;未传时可在外层用 attn_implementation_for_device(device) | |
| extra_model_kwargs: 可选,额外传给 from_pretrained 的参数 | |
| Returns: | |
| 已 eval() 的模型 | |
| """ | |
| strategy = get_device_load_strategy(device) | |
| device_map = strategy["device_map"] | |
| dtype = strategy["dtype"] | |
| use_low_cpu_mem = strategy["use_low_cpu_mem"] | |
| use_int8 = strategy["use_int8"] | |
| load_kw: Dict[str, Any] = { | |
| "trust_remote_code": True, | |
| "low_cpu_mem_usage": use_low_cpu_mem or use_int8, | |
| } | |
| if attn_implementation is not None: | |
| load_kw["attn_implementation"] = attn_implementation | |
| if extra_model_kwargs: | |
| load_kw.update(extra_model_kwargs) | |
| def _load(path: str, lf: bool): | |
| kw = dict(local_files_only=lf, **load_kw) | |
| if use_int8: | |
| from transformers import BitsAndBytesConfig | |
| return AutoModelForCausalLM.from_pretrained( | |
| path, | |
| quantization_config=BitsAndBytesConfig(load_in_8bit=True), | |
| device_map=device_map, | |
| **kw, | |
| ) | |
| if device_map: | |
| return AutoModelForCausalLM.from_pretrained( | |
| path, | |
| device_map=device_map, | |
| dtype=dtype, | |
| **kw, | |
| ) | |
| return AutoModelForCausalLM.from_pretrained( | |
| path, dtype=dtype, **kw | |
| ).to(device) | |
| t0 = time.perf_counter() | |
| model = resolve_and_load(model_path, _load) | |
| load_time = time.perf_counter() - t0 | |
| DeviceManager.print_model_load_stats(model, load_time) | |
| model.eval() | |
| if device.type == "cuda": | |
| device_idx = device.index if device.index is not None else 0 | |
| DeviceManager.print_cuda_memory_summary(device=device_idx) | |
| return model | |
| def load_tokenizer(model_path: str): | |
| """加载 tokenizer。本地优先时先解析为缓存路径,避免 tokenizer 内部 model_info 联网。""" | |
| def _load(path: str, lf: bool): | |
| return AutoTokenizer.from_pretrained( | |
| path, trust_remote_code=True, local_files_only=lf | |
| ) | |
| return resolve_and_load(model_path, _load) | |