| | """模型切换 API""" |
| | import gc |
| | import os |
| | from typing import Optional |
| |
|
| | import torch |
| | from model_paths import MODEL_PATHS |
| | from backend.model_manager import project_registry |
| | from backend.app_context import get_app_context |
| | from backend.api.utils import require_admin |
| |
|
| |
|
def get_available_models():
    """Return every model name known to the registry.

    Returns:
        (response dict, HTTP status) tuple; the dict carries a
        'success' flag and a 'models' list of registry keys.
    """
    response = {
        'success': True,
        'models': list(MODEL_PATHS.keys()),
    }
    return response, 200
| |
|
| |
|
| | def _get_device_type() -> str: |
| | """获取当前设备类型""" |
| | if torch.cuda.is_available(): |
| | return "cuda" |
| | elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): |
| | return "mps" |
| | else: |
| | return "cpu" |
| |
|
| |
|
| | def _restore_env_vars(old_force_int8: Optional[str], old_force_bfloat16: Optional[str]) -> None: |
| | """恢复环境变量配置""" |
| | if old_force_int8 is not None: |
| | os.environ['FORCE_INT8'] = old_force_int8 |
| | else: |
| | os.environ.pop('FORCE_INT8', None) |
| | |
| | if old_force_bfloat16 is not None: |
| | os.environ['CPU_FORCE_BFLOAT16'] = old_force_bfloat16 |
| | else: |
| | os.environ.pop('CPU_FORCE_BFLOAT16', None) |
| |
|
| |
|
def get_current_model():
    """Report the active model and its quantization configuration.

    Returns:
        (response dict, HTTP status) tuple with the current model name,
        whether a load is in progress, the detected device type, and the
        INT8 / bfloat16 quantization flags read from the environment.
    """
    context = get_app_context(prefer_module_context=True)
    payload = {
        'success': True,
        'model': context.model_name,
        'loading': context.model_loading,
        'device_type': _get_device_type(),
        'use_int8': os.environ.get('FORCE_INT8') == '1',
        'use_bfloat16': os.environ.get('CPU_FORCE_BFLOAT16') == '1',
    }
    return payload, 200
| |
|
| |
|
def _validate_switch_request(target_model, use_int8, use_bfloat16, device_type):
    """Validate a switch request; return an (error dict, status) tuple or None when valid."""
    if not target_model:
        return {
            'success': False,
            'message': 'Missing model parameter'
        }, 400

    if target_model not in MODEL_PATHS:
        available_models = list(MODEL_PATHS.keys())
        return {
            'success': False,
            'message': f'Model {target_model} does not exist. Available models: {", ".join(available_models)}'
        }, 404

    if use_int8 and device_type == "mps":
        return {
            'success': False,
            'message': 'INT8 quantization is not supported on MPS device'
        }, 400

    if use_bfloat16 and device_type != "cpu":
        return {
            'success': False,
            'message': 'bfloat16 quantization is only supported on CPU device'
        }, 400

    if use_int8 and use_bfloat16:
        return {
            'success': False,
            'message': 'Cannot enable both INT8 and bfloat16 quantization'
        }, 400

    return None


def _apply_quant_env(use_int8, use_bfloat16):
    """Set or clear the quantization environment variables for the new model load."""
    if use_int8:
        os.environ['FORCE_INT8'] = '1'
        print(" 设置量化: INT8")
    else:
        os.environ.pop('FORCE_INT8', None)

    if use_bfloat16:
        os.environ['CPU_FORCE_BFLOAT16'] = '1'
        print(" 设置量化: bfloat16")
    else:
        os.environ.pop('CPU_FORCE_BFLOAT16', None)


def _unload_current_model(current_model, device_type):
    """Unload the current model (if registered) and release accelerator memory."""
    if current_model and current_model in project_registry:
        print(f" 卸载旧模型: {current_model}")
        project_registry.unload(current_model)
        gc.collect()
        if device_type == "cuda":
            torch.cuda.empty_cache()
        elif device_type == "mps":
            torch.mps.empty_cache()


def _rollback_switch(context, current_model, old_force_int8, old_force_bfloat16):
    """Best-effort rollback: restore env vars and re-load the previous model.

    Swallows rollback errors (logging them) so the caller can still return
    its own failure response to the client.
    """
    try:
        _restore_env_vars(old_force_int8, old_force_bfloat16)
        if current_model:
            project_registry.ensure_loaded(current_model)
        context.set_current_model(current_model)
        print(f"✅ 已回滚到旧模型: {current_model}")
    except Exception as rollback_error:
        print(f"⚠️ 回滚失败: {rollback_error}")


@require_admin
def switch_model(switch_request):
    """
    Switch the active model (admin only).

    Args:
        switch_request: Request dict containing:
            - model: Target model name (required).
            - use_int8: Enable INT8 quantization (optional).
            - use_bfloat16: Enable bfloat16 (optional, CPU only).

    Returns:
        (response dict, HTTP status) tuple.
    """
    target_model = switch_request.get('model')
    use_int8 = switch_request.get('use_int8', False)
    use_bfloat16 = switch_request.get('use_bfloat16', False)

    device_type = _get_device_type()
    error = _validate_switch_request(target_model, use_int8, use_bfloat16, device_type)
    if error is not None:
        return error

    context = get_app_context(prefer_module_context=True)
    current_model = context.model_name

    # Snapshot quantization env so a failed switch can be rolled back.
    old_force_int8 = os.environ.get('FORCE_INT8')
    old_force_bfloat16 = os.environ.get('CPU_FORCE_BFLOAT16')

    current_int8 = old_force_int8 == '1'
    current_bfloat16 = old_force_bfloat16 == '1'

    # No-op when the requested model and quantization already match.
    if (current_model == target_model and
            current_int8 == use_int8 and
            current_bfloat16 == use_bfloat16):
        return {
            'success': True,
            'message': f'Already using model {target_model} (same quantization configuration)',
            'model': target_model
        }, 200

    if context.model_loading:
        return {
            'success': False,
            'message': 'Model is currently loading, please try again later'
        }, 503

    try:
        context.set_model_loading(True)
        print(f"🔄 开始切换模型: {current_model} -> {target_model}")

        _apply_quant_env(use_int8, use_bfloat16)

        # Unload first to free memory before loading the (possibly large) target.
        _unload_current_model(current_model, device_type)

        print(f" 加载新模型: {target_model}")
        project_registry.ensure_loaded(target_model)

        context.set_current_model(target_model)
        print(f"✅ 模型切换成功: {target_model}")

        return {
            'success': True,
            'message': f'Model switched to {target_model}',
            'model': target_model
        }, 200

    except KeyError:
        print(f"❌ 模型切换失败: 模型 {target_model} 未注册")
        # Bug fix: the old handler only restored env vars and the model *name*;
        # the old model stayed unloaded after _unload_current_model ran above.
        # Use the same rollback path as the generic handler, which re-loads it.
        _rollback_switch(context, current_model, old_force_int8, old_force_bfloat16)
        return {
            'success': False,
            'message': f'Model {target_model} is not registered'
        }, 404

    except Exception as e:
        print(f"❌ 模型切换失败: {e}")
        print(f" 尝试回滚到旧模型: {current_model}")
        _rollback_switch(context, current_model, old_force_int8, old_force_bfloat16)
        return {
            'success': False,
            'message': f'Model switch failed: {str(e)}'
        }, 500

    finally:
        # Always clear the loading flag, even when load or rollback failed.
        context.set_model_loading(False)
        gc.collect()
| |
|