"""模型切换 API""" import gc import os from typing import Optional import torch from model_paths import MODEL_PATHS from backend.model_manager import project_registry from backend.app_context import get_app_context from backend.api.utils import require_admin def get_available_models(): """获取所有可用的模型列表""" return { 'success': True, 'models': list(MODEL_PATHS.keys()) }, 200 def _get_device_type() -> str: """获取当前设备类型""" if torch.cuda.is_available(): return "cuda" elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): return "mps" else: return "cpu" def _restore_env_vars(old_force_int8: Optional[str], old_force_bfloat16: Optional[str]) -> None: """恢复环境变量配置""" if old_force_int8 is not None: os.environ['FORCE_INT8'] = old_force_int8 else: os.environ.pop('FORCE_INT8', None) if old_force_bfloat16 is not None: os.environ['CPU_FORCE_BFLOAT16'] = old_force_bfloat16 else: os.environ.pop('CPU_FORCE_BFLOAT16', None) def get_current_model(): """获取当前使用的模型及量化配置""" # 使用模块级上下文以获取持久化的模型状态 context = get_app_context(prefer_module_context=True) device_type = _get_device_type() return { 'success': True, 'model': context.model_name, 'loading': context.model_loading, 'device_type': device_type, 'use_int8': os.environ.get('FORCE_INT8') == '1', 'use_bfloat16': os.environ.get('CPU_FORCE_BFLOAT16') == '1' }, 200 @require_admin def switch_model(switch_request): """ 切换模型(需要管理员权限) Args: switch_request: 切换请求字典,包含: - model: 目标模型名称 - use_int8: 是否使用 INT8 量化(可选) - use_bfloat16: 是否使用 bfloat16(可选,仅CPU) Returns: (响应字典, 状态码) 元组 """ target_model = switch_request.get('model') use_int8 = switch_request.get('use_int8', False) use_bfloat16 = switch_request.get('use_bfloat16', False) # 验证请求 if not target_model: return { 'success': False, 'message': 'Missing model parameter' }, 400 # 检查模型是否可用 if target_model not in MODEL_PATHS: available_models = list(MODEL_PATHS.keys()) return { 'success': False, 'message': f'Model {target_model} does not exist. Available models: {", ".join(available_models)}' }, 404 # 获取设备类型 device_type = _get_device_type() # 验证量化参数与设备兼容性 if use_int8 and device_type == "mps": return { 'success': False, 'message': 'INT8 quantization is not supported on MPS device' }, 400 if use_bfloat16 and device_type != "cpu": return { 'success': False, 'message': 'bfloat16 quantization is only supported on CPU device' }, 400 if use_int8 and use_bfloat16: return { 'success': False, 'message': 'Cannot enable both INT8 and bfloat16 quantization' }, 400 # 使用模块级上下文以确保状态修改持久化(不会被后续请求重置) context = get_app_context(prefer_module_context=True) current_model = context.model_name # 保存当前环境变量配置(用于回滚) old_force_int8 = os.environ.get('FORCE_INT8') old_force_bfloat16 = os.environ.get('CPU_FORCE_BFLOAT16') # 检查是否已经是目标模型且量化配置相同 current_int8 = os.environ.get('FORCE_INT8') == '1' current_bfloat16 = os.environ.get('CPU_FORCE_BFLOAT16') == '1' if (current_model == target_model and current_int8 == use_int8 and current_bfloat16 == use_bfloat16): return { 'success': True, 'message': f'Already using model {target_model} (same quantization configuration)', 'model': target_model }, 200 # 检查模型是否正在加载中(初始加载或切换) if context.model_loading: return { 'success': False, 'message': 'Model is currently loading, please try again later' }, 503 try: # 标记开始加载 context.set_model_loading(True) print(f"🔄 开始切换模型: {current_model} -> {target_model}") # 设置新的量化环境变量 if use_int8: os.environ['FORCE_INT8'] = '1' print(f" 设置量化: INT8") else: os.environ.pop('FORCE_INT8', None) if use_bfloat16: os.environ['CPU_FORCE_BFLOAT16'] = '1' print(f" 设置量化: bfloat16") else: os.environ.pop('CPU_FORCE_BFLOAT16', None) # 卸载旧模型 if current_model and current_model in project_registry: print(f" 卸载旧模型: {current_model}") project_registry.unload(current_model) gc.collect() if device_type == "cuda": torch.cuda.empty_cache() elif device_type == "mps": torch.mps.empty_cache() # 加载新模型 print(f" 加载新模型: {target_model}") project_registry.ensure_loaded(target_model) # 更新当前模型 context.set_current_model(target_model) print(f"✅ 模型切换成功: {target_model}") return { 'success': True, 'message': f'Model switched to {target_model}', 'model': target_model }, 200 except KeyError as e: # 模型不存在(虽然前面已经检查过,但以防万一) print(f"❌ 模型切换失败: 模型 {target_model} 未注册") # 回滚:恢复旧模型名称和环境变量 context.set_current_model(current_model) _restore_env_vars(old_force_int8, old_force_bfloat16) return { 'success': False, 'message': f'Model {target_model} is not registered' }, 404 except Exception as e: # 加载失败,尝试回滚 print(f"❌ 模型切换失败: {e}") print(f" 尝试回滚到旧模型: {current_model}") try: # 回滚:恢复环境变量和重新加载旧模型 _restore_env_vars(old_force_int8, old_force_bfloat16) if current_model: project_registry.ensure_loaded(current_model) context.set_current_model(current_model) print(f"✅ 已回滚到旧模型: {current_model}") except Exception as rollback_error: print(f"⚠️ 回滚失败: {rollback_error}") return { 'success': False, 'message': f'Model switch failed: {str(e)}' }, 500 finally: # 无论成功还是失败,都要清除加载标志 context.set_model_loading(False) gc.collect()