InfoLens / backend /models /model_manager.py
dqy08's picture
动画交互改进,UI改进;chat页面改进
2764e14
Raw
History Blame Contribute Delete
8.3 kB
"""模型管理:base / instruct 双槽位,HF 权重缓存共用。
加载约定(由简到繁):
- ``ensure_slot_weights_loaded(slot)``:仅保证该槽位 HF 权重在 ``_hf_loaded`` 中(归因、tokenize)。
- ``ensure_slot_ready(slot)``:槽位可推理;base 另挂 ``project_registry`` / QwenLM(信息密度)。
- 业务入口:信息密度 ``ensure_base_slot_ready()``;语义分析默认 instruct;续写由请求 ``model`` 选槽位(``ensure_slot_ready``)。
"""
from enum import Enum
import threading
from backend.models import REGISTERED_MODELS
from backend.models.project_registry import ModelRegistry
from backend.models.device import DeviceManager
from backend.models.model_loader import attn_implementation_for_device, load_causal_lm, load_tokenizer
from model_paths import DEFAULT_BASE_MODEL, DEFAULT_INSTRUCT_MODEL, resolve_hf_path
project_registry = ModelRegistry(REGISTERED_MODELS)
_init_lock = threading.Lock()
# 统一推理锁:信息密度分析与 instruct 路径共用,确保模型推理串行执行
inference_lock = threading.Lock()
# 按 HuggingFace 路径去重的已加载模型缓存(两槽位共用)
_hf_load_lock = threading.Lock()
_hf_loaded: dict[str, tuple] = {}
class ModelSlot(str, Enum):
"""与 CLI --base_model / --instruct_model 对应的两个对等槽位。"""
BASE = "base"
INSTRUCT = "instruct"
CONFIGURED_SLOTS: tuple[ModelSlot, ...] = (ModelSlot.BASE, ModelSlot.INSTRUCT)
def _resolved_hf_path_for_slot(slot: ModelSlot) -> str:
"""由应用上下文解析槽位对应的 HuggingFace 路径(或本地路径字符串)。"""
from backend.platform.app_context import get_app_context
try:
context = get_app_context(prefer_module_context=True)
except RuntimeError:
if slot == ModelSlot.BASE:
return resolve_hf_path(DEFAULT_BASE_MODEL)
if slot == ModelSlot.INSTRUCT:
return resolve_hf_path(DEFAULT_INSTRUCT_MODEL)
raise ValueError(f"unknown ModelSlot: {slot!r}") from None
if slot == ModelSlot.BASE:
return resolve_hf_path(context.base_model_id or DEFAULT_BASE_MODEL)
if slot == ModelSlot.INSTRUCT:
return resolve_hf_path(context.instruct_model_id or DEFAULT_INSTRUCT_MODEL)
raise ValueError(f"unknown ModelSlot: {slot!r}")
def ensure_slot_weights_loaded(slot: ModelSlot):
"""
加载指定槽位权重(若未缓存)。
返回 (tokenizer, model, device)。
"""
return ensure_model_loaded(_resolved_hf_path_for_slot(slot))
def ensure_model_loaded(resolved_hf_path: str):
"""
唯一底层加载入口:保证 resolved_hf_path 对应权重已加载。
返回 (tokenizer, model, device),其中 device 为模型参数所在 device。
"""
with _hf_load_lock:
hit = _hf_loaded.get(resolved_hf_path)
if hit is not None:
return hit
device = DeviceManager.get_device()
display = resolved_hf_path.split("/")[-1] if "/" in resolved_hf_path else resolved_hf_path
print(f"📦 正在加载模型权重: {display}")
tokenizer = load_tokenizer(resolved_hf_path)
model = load_causal_lm(
resolved_hf_path,
device,
attn_implementation=attn_implementation_for_device(device),
)
for p in model.parameters():
p.requires_grad_(False)
model_device = next(model.parameters()).device
device_name = DeviceManager.get_device_name(device)
print(f"✓ {display} 已加载 ({device_name})")
out = (tokenizer, model, model_device)
_hf_loaded[resolved_hf_path] = out
return out
def ensure_project_loaded(project_name: str):
"""确保项目已加载,如果未加载则加载它"""
if not project_name:
raise ValueError("model name is required")
if not project_registry.is_available(project_name):
raise KeyError(project_name)
try:
return project_registry.ensure_loaded(project_name)
except KeyError:
raise
except Exception as exc: # noqa: BLE001 - propagate detailed message
raise RuntimeError(f"模型 '{project_name}' 加载失败: {exc}") from exc
def _register_base_qwenlm_if_needed():
"""
信息密度路径:在 base 槽位权重已就绪后,注册 project_registry 中的 QwenLM 实例。
instruct 槽位无对应 registry 包装。
"""
from backend.platform.app_context import get_app_context
context = get_app_context(prefer_module_context=True)
selected_name = context.base_model_id
if not selected_name:
raise ValueError("未指定 base 模型 id")
if selected_name in project_registry:
_ensure_default_project_ready(selected_name)
return
if not project_registry.is_available(selected_name):
raise KeyError(f"模型 '{selected_name}' 未找到,可用模型: {list(REGISTERED_MODELS.keys())}")
try:
project_registry.load(selected_name)
_ensure_default_project_ready(selected_name)
except Exception as exc: # noqa: BLE001
raise RuntimeError(f"模型 '{selected_name}' 加载失败: {exc}") from exc
def preload_all_slots():
"""
启动预载(非 --no_auto_load):对 CONFIGURED_SLOTS 各解析 HF 路径,去重后加载全部权重,
再注册 base 槽位 QwenLM 项目。
"""
from backend.platform.app_context import get_app_context
get_app_context(prefer_module_context=True)
paths = {_resolved_hf_path_for_slot(s) for s in CONFIGURED_SLOTS}
with _init_lock:
for path in paths:
ensure_model_loaded(path)
_register_base_qwenlm_if_needed()
def ensure_slot_ready(slot: ModelSlot):
"""
槽位业务就绪:保证该槽位后续推理所需状态已备好。
- 两槽位均先保证 HF 权重已加载,返回 (tokenizer, model, device)。
- base 另需将 QwenLM 挂入 project_registry(信息密度);instruct 无 registry 步骤。
"""
from backend.platform.app_context import get_app_context
get_app_context(prefer_module_context=True)
if slot == ModelSlot.BASE:
with _init_lock:
out = ensure_slot_weights_loaded(ModelSlot.BASE)
_register_base_qwenlm_if_needed()
return out
if slot == ModelSlot.INSTRUCT:
return ensure_slot_weights_loaded(ModelSlot.INSTRUCT)
raise ValueError(f"unknown ModelSlot: {slot!r}")
def ensure_base_slot_ready():
"""信息密度等业务:``ensure_slot_ready(ModelSlot.BASE)``。"""
return ensure_slot_ready(ModelSlot.BASE)
def ensure_instruct_slot_ready():
"""语义分析 / 续写:``ensure_slot_ready(ModelSlot.INSTRUCT)``。"""
return ensure_slot_ready(ModelSlot.INSTRUCT)
def get_current_model_max_token_length() -> int:
"""
查询当前生效 base 模型的 max_token_length 参数。
优先从已加载的模型实例获取,未加载时取 default_model.default_cpu_machine 配置。
"""
from backend.platform.app_context import get_app_context
from backend.platform.runtime_config import RUNTIME_CONFIGS
try:
context = get_app_context(prefer_module_context=True)
model_name = context.base_model_id or DEFAULT_BASE_MODEL
except RuntimeError:
model_name = "default_model"
project = project_registry.get(model_name)
if project is not None and hasattr(project.lm, "max_length"):
return project.lm.max_length
return RUNTIME_CONFIGS["default_model"]["default_cpu_machine"]["max_token_length"]
def _ensure_default_project_ready(selected_name: str):
"""确保默认项目已准备好"""
if not selected_name:
return
if selected_name in project_registry:
return
print(f"⚠️ 默认模型未缓存,正在预加载: {selected_name}")
project_registry.ensure_loaded(selected_name)
def get_instruct_model_display_name() -> str:
"""返回 instruct 槽位 HuggingFace 路径(用于结果中的 model 字段)。"""
return _resolved_hf_path_for_slot(ModelSlot.INSTRUCT)
def get_base_model_display_name() -> str:
"""返回 base 槽位 HuggingFace 路径(用于结果中的 model 字段)。"""
return _resolved_hf_path_for_slot(ModelSlot.BASE)