| """模型管理:base / instruct 双槽位,HF 权重缓存共用。 |
| |
| 加载约定(由简到繁): |
| - ``ensure_slot_weights_loaded(slot)``:仅保证该槽位 HF 权重在 ``_hf_loaded`` 中(归因、tokenize)。 |
| - ``ensure_slot_ready(slot)``:槽位可推理;base 另挂 ``project_registry`` / QwenLM(信息密度)。 |
| - 业务入口:信息密度 ``ensure_base_slot_ready()``;语义分析默认 instruct;续写由请求 ``model`` 选槽位(``ensure_slot_ready``)。 |
| """ |
| from enum import Enum |
| import threading |
|
|
| from backend.models import REGISTERED_MODELS |
| from backend.models.project_registry import ModelRegistry |
| from backend.models.device import DeviceManager |
| from backend.models.model_loader import attn_implementation_for_device, load_causal_lm, load_tokenizer |
|
|
| from model_paths import DEFAULT_BASE_MODEL, DEFAULT_INSTRUCT_MODEL, resolve_hf_path |
|
|
| project_registry = ModelRegistry(REGISTERED_MODELS) |
| _init_lock = threading.Lock() |
|
|
| |
| inference_lock = threading.Lock() |
|
|
| |
| _hf_load_lock = threading.Lock() |
| _hf_loaded: dict[str, tuple] = {} |
|
|
|
|
| class ModelSlot(str, Enum): |
| """与 CLI --base_model / --instruct_model 对应的两个对等槽位。""" |
|
|
| BASE = "base" |
| INSTRUCT = "instruct" |
|
|
|
|
| CONFIGURED_SLOTS: tuple[ModelSlot, ...] = (ModelSlot.BASE, ModelSlot.INSTRUCT) |
|
|
|
|
| def _resolved_hf_path_for_slot(slot: ModelSlot) -> str: |
| """由应用上下文解析槽位对应的 HuggingFace 路径(或本地路径字符串)。""" |
| from backend.platform.app_context import get_app_context |
|
|
| try: |
| context = get_app_context(prefer_module_context=True) |
| except RuntimeError: |
| if slot == ModelSlot.BASE: |
| return resolve_hf_path(DEFAULT_BASE_MODEL) |
| if slot == ModelSlot.INSTRUCT: |
| return resolve_hf_path(DEFAULT_INSTRUCT_MODEL) |
| raise ValueError(f"unknown ModelSlot: {slot!r}") from None |
|
|
| if slot == ModelSlot.BASE: |
| return resolve_hf_path(context.base_model_id or DEFAULT_BASE_MODEL) |
| if slot == ModelSlot.INSTRUCT: |
| return resolve_hf_path(context.instruct_model_id or DEFAULT_INSTRUCT_MODEL) |
| raise ValueError(f"unknown ModelSlot: {slot!r}") |
|
|
|
|
| def ensure_slot_weights_loaded(slot: ModelSlot): |
| """ |
| 加载指定槽位权重(若未缓存)。 |
| 返回 (tokenizer, model, device)。 |
| """ |
| return ensure_model_loaded(_resolved_hf_path_for_slot(slot)) |
|
|
|
|
| def ensure_model_loaded(resolved_hf_path: str): |
| """ |
| 唯一底层加载入口:保证 resolved_hf_path 对应权重已加载。 |
| 返回 (tokenizer, model, device),其中 device 为模型参数所在 device。 |
| """ |
| with _hf_load_lock: |
| hit = _hf_loaded.get(resolved_hf_path) |
| if hit is not None: |
| return hit |
|
|
| device = DeviceManager.get_device() |
| display = resolved_hf_path.split("/")[-1] if "/" in resolved_hf_path else resolved_hf_path |
| print(f"📦 正在加载模型权重: {display}") |
| tokenizer = load_tokenizer(resolved_hf_path) |
| model = load_causal_lm( |
| resolved_hf_path, |
| device, |
| attn_implementation=attn_implementation_for_device(device), |
| ) |
| for p in model.parameters(): |
| p.requires_grad_(False) |
| model_device = next(model.parameters()).device |
| device_name = DeviceManager.get_device_name(device) |
| print(f"✓ {display} 已加载 ({device_name})") |
| out = (tokenizer, model, model_device) |
| _hf_loaded[resolved_hf_path] = out |
| return out |
|
|
|
|
| def ensure_project_loaded(project_name: str): |
| """确保项目已加载,如果未加载则加载它""" |
| if not project_name: |
| raise ValueError("model name is required") |
| if not project_registry.is_available(project_name): |
| raise KeyError(project_name) |
| try: |
| return project_registry.ensure_loaded(project_name) |
| except KeyError: |
| raise |
| except Exception as exc: |
| raise RuntimeError(f"模型 '{project_name}' 加载失败: {exc}") from exc |
|
|
|
|
| def _register_base_qwenlm_if_needed(): |
| """ |
| 信息密度路径:在 base 槽位权重已就绪后,注册 project_registry 中的 QwenLM 实例。 |
| instruct 槽位无对应 registry 包装。 |
| """ |
| from backend.platform.app_context import get_app_context |
|
|
| context = get_app_context(prefer_module_context=True) |
| selected_name = context.base_model_id |
|
|
| if not selected_name: |
| raise ValueError("未指定 base 模型 id") |
|
|
| if selected_name in project_registry: |
| _ensure_default_project_ready(selected_name) |
| return |
|
|
| if not project_registry.is_available(selected_name): |
| raise KeyError(f"模型 '{selected_name}' 未找到,可用模型: {list(REGISTERED_MODELS.keys())}") |
|
|
| try: |
| project_registry.load(selected_name) |
| _ensure_default_project_ready(selected_name) |
| except Exception as exc: |
| raise RuntimeError(f"模型 '{selected_name}' 加载失败: {exc}") from exc |
|
|
|
|
| def preload_all_slots(): |
| """ |
| 启动预载(非 --no_auto_load):对 CONFIGURED_SLOTS 各解析 HF 路径,去重后加载全部权重, |
| 再注册 base 槽位 QwenLM 项目。 |
| """ |
| from backend.platform.app_context import get_app_context |
|
|
| get_app_context(prefer_module_context=True) |
|
|
| paths = {_resolved_hf_path_for_slot(s) for s in CONFIGURED_SLOTS} |
|
|
| with _init_lock: |
| for path in paths: |
| ensure_model_loaded(path) |
| _register_base_qwenlm_if_needed() |
|
|
|
|
| def ensure_slot_ready(slot: ModelSlot): |
| """ |
| 槽位业务就绪:保证该槽位后续推理所需状态已备好。 |
| |
| - 两槽位均先保证 HF 权重已加载,返回 (tokenizer, model, device)。 |
| - base 另需将 QwenLM 挂入 project_registry(信息密度);instruct 无 registry 步骤。 |
| """ |
| from backend.platform.app_context import get_app_context |
|
|
| get_app_context(prefer_module_context=True) |
|
|
| if slot == ModelSlot.BASE: |
| with _init_lock: |
| out = ensure_slot_weights_loaded(ModelSlot.BASE) |
| _register_base_qwenlm_if_needed() |
| return out |
| if slot == ModelSlot.INSTRUCT: |
| return ensure_slot_weights_loaded(ModelSlot.INSTRUCT) |
| raise ValueError(f"unknown ModelSlot: {slot!r}") |
|
|
|
|
| def ensure_base_slot_ready(): |
| """信息密度等业务:``ensure_slot_ready(ModelSlot.BASE)``。""" |
| return ensure_slot_ready(ModelSlot.BASE) |
|
|
|
|
| def ensure_instruct_slot_ready(): |
| """语义分析 / 续写:``ensure_slot_ready(ModelSlot.INSTRUCT)``。""" |
| return ensure_slot_ready(ModelSlot.INSTRUCT) |
|
|
|
|
| def get_current_model_max_token_length() -> int: |
| """ |
| 查询当前生效 base 模型的 max_token_length 参数。 |
| 优先从已加载的模型实例获取,未加载时取 default_model.default_cpu_machine 配置。 |
| """ |
| from backend.platform.app_context import get_app_context |
| from backend.platform.runtime_config import RUNTIME_CONFIGS |
|
|
| try: |
| context = get_app_context(prefer_module_context=True) |
| model_name = context.base_model_id or DEFAULT_BASE_MODEL |
| except RuntimeError: |
| model_name = "default_model" |
|
|
| project = project_registry.get(model_name) |
| if project is not None and hasattr(project.lm, "max_length"): |
| return project.lm.max_length |
| return RUNTIME_CONFIGS["default_model"]["default_cpu_machine"]["max_token_length"] |
|
|
|
|
| def _ensure_default_project_ready(selected_name: str): |
| """确保默认项目已准备好""" |
| if not selected_name: |
| return |
| if selected_name in project_registry: |
| return |
| print(f"⚠️ 默认模型未缓存,正在预加载: {selected_name}") |
| project_registry.ensure_loaded(selected_name) |
|
|
|
|
| def get_instruct_model_display_name() -> str: |
| """返回 instruct 槽位 HuggingFace 路径(用于结果中的 model 字段)。""" |
| return _resolved_hf_path_for_slot(ModelSlot.INSTRUCT) |
|
|
|
|
| def get_base_model_display_name() -> str: |
| """返回 base 槽位 HuggingFace 路径(用于结果中的 model 字段)。""" |
| return _resolved_hf_path_for_slot(ModelSlot.BASE) |
|
|