"""模型切换 API"""
import gc
import os
from typing import Optional
import torch
from model_paths import MODEL_PATHS
from backend.model_manager import project_registry
from backend.app_context import get_app_context
from backend.api.utils import require_admin
def get_available_models():
    """Return every model name that can be switched to.

    Returns:
        (response dict, HTTP status) tuple; the response carries the
        registered model names under 'models'.
    """
    names = [name for name in MODEL_PATHS]
    payload = {'success': True, 'models': names}
    return payload, 200
def _get_device_type() -> str:
"""获取当前设备类型"""
if torch.cuda.is_available():
return "cuda"
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
return "mps"
else:
return "cpu"
def _restore_env_vars(old_force_int8: Optional[str], old_force_bfloat16: Optional[str]) -> None:
"""恢复环境变量配置"""
if old_force_int8 is not None:
os.environ['FORCE_INT8'] = old_force_int8
else:
os.environ.pop('FORCE_INT8', None)
if old_force_bfloat16 is not None:
os.environ['CPU_FORCE_BFLOAT16'] = old_force_bfloat16
else:
os.environ.pop('CPU_FORCE_BFLOAT16', None)
def get_current_model():
    """Report the active model plus its quantization configuration.

    Returns:
        (response dict, HTTP status) tuple with the model name, the
        loading flag, the detected device type, and whether the INT8 /
        bfloat16 environment flags are currently set.
    """
    # Module-level context holds the persisted model state across requests.
    context = get_app_context(prefer_module_context=True)
    payload = {
        'success': True,
        'model': context.model_name,
        'loading': context.model_loading,
        'device_type': _get_device_type(),
        'use_int8': os.environ.get('FORCE_INT8') == '1',
        'use_bfloat16': os.environ.get('CPU_FORCE_BFLOAT16') == '1',
    }
    return payload, 200
@require_admin
def switch_model(switch_request):
    """
    Switch the active model (admin only).

    Args:
        switch_request: request dict containing:
            - model: target model name (required)
            - use_int8: enable INT8 quantization (optional, default False)
            - use_bfloat16: enable bfloat16 (optional, CPU only, default False)

    Returns:
        (response dict, HTTP status) tuple. 400 on bad/conflicting
        parameters, 404 on unknown model, 503 while another load is in
        progress, 500 when loading fails (after a best-effort rollback).
    """
    target_model = switch_request.get('model')
    use_int8 = switch_request.get('use_int8', False)
    use_bfloat16 = switch_request.get('use_bfloat16', False)

    # Validate request payload.
    if not target_model:
        return {
            'success': False,
            'message': 'Missing model parameter'
        }, 400

    # Reject models that are not registered at all.
    if target_model not in MODEL_PATHS:
        available_models = list(MODEL_PATHS.keys())
        return {
            'success': False,
            'message': f'Model {target_model} does not exist. Available models: {", ".join(available_models)}'
        }, 404

    # Validate quantization flags against the current device.
    device_type = _get_device_type()
    if use_int8 and device_type == "mps":
        return {
            'success': False,
            'message': 'INT8 quantization is not supported on MPS device'
        }, 400
    if use_bfloat16 and device_type != "cpu":
        return {
            'success': False,
            'message': 'bfloat16 quantization is only supported on CPU device'
        }, 400
    if use_int8 and use_bfloat16:
        return {
            'success': False,
            'message': 'Cannot enable both INT8 and bfloat16 quantization'
        }, 400

    # Module-level context so the state change survives later requests.
    context = get_app_context(prefer_module_context=True)
    current_model = context.model_name

    # Snapshot env config for rollback; also used for the no-op check below.
    old_force_int8 = os.environ.get('FORCE_INT8')
    old_force_bfloat16 = os.environ.get('CPU_FORCE_BFLOAT16')
    current_int8 = old_force_int8 == '1'
    current_bfloat16 = old_force_bfloat16 == '1'

    # No-op: already running the target model with the same quantization.
    if (current_model == target_model and
            current_int8 == use_int8 and
            current_bfloat16 == use_bfloat16):
        return {
            'success': True,
            'message': f'Already using model {target_model} (same quantization configuration)',
            'model': target_model
        }, 200

    # Refuse concurrent switches (initial load or another switch in flight).
    if context.model_loading:
        return {
            'success': False,
            'message': 'Model is currently loading, please try again later'
        }, 503

    def _rollback_to_old_model() -> None:
        """Best-effort restore: env vars, old model weights, and model name.

        Reloading matters because the old model may already have been
        unloaded before the failure; restoring only the name would leave
        the registry empty while the context claims a model is active.
        """
        _restore_env_vars(old_force_int8, old_force_bfloat16)
        if current_model:
            project_registry.ensure_loaded(current_model)
        context.set_current_model(current_model)

    try:
        # Mark the load in progress so concurrent requests get a 503.
        context.set_model_loading(True)
        print(f"🔄 开始切换模型: {current_model} -> {target_model}")

        # Apply the new quantization env vars before loading.
        if use_int8:
            os.environ['FORCE_INT8'] = '1'
            print(f"   设置量化: INT8")
        else:
            os.environ.pop('FORCE_INT8', None)
        if use_bfloat16:
            os.environ['CPU_FORCE_BFLOAT16'] = '1'
            print(f"   设置量化: bfloat16")
        else:
            os.environ.pop('CPU_FORCE_BFLOAT16', None)

        # Unload the old model and release device memory before loading.
        if current_model and current_model in project_registry:
            print(f"   卸载旧模型: {current_model}")
            project_registry.unload(current_model)
            gc.collect()
            if device_type == "cuda":
                torch.cuda.empty_cache()
            elif device_type == "mps":
                torch.mps.empty_cache()

        # Load the new model and persist the switch.
        print(f"   加载新模型: {target_model}")
        project_registry.ensure_loaded(target_model)
        context.set_current_model(target_model)
        print(f"✅ 模型切换成功: {target_model}")
        return {
            'success': True,
            'message': f'Model switched to {target_model}',
            'model': target_model
        }, 200
    except KeyError:
        # Model not registered (pre-checked above, but guard anyway).
        print(f"❌ 模型切换失败: 模型 {target_model} 未注册")
        # Roll back env vars AND reload the old model — it may already
        # have been unloaded before ensure_loaded() raised.
        try:
            _rollback_to_old_model()
        except Exception as rollback_error:
            print(f"⚠️ 回滚失败: {rollback_error}")
        return {
            'success': False,
            'message': f'Model {target_model} is not registered'
        }, 404
    except Exception as e:
        # Load failed for any other reason; attempt a rollback.
        print(f"❌ 模型切换失败: {e}")
        print(f"   尝试回滚到旧模型: {current_model}")
        try:
            _rollback_to_old_model()
            print(f"✅ 已回滚到旧模型: {current_model}")
        except Exception as rollback_error:
            print(f"⚠️ 回滚失败: {rollback_error}")
        return {
            'success': False,
            'message': f'Model switch failed: {str(e)}'
        }, 500
    finally:
        # Always clear the loading flag, success or failure.
        context.set_model_loading(False)
        gc.collect()