File size: 7,310 Bytes
c8f3989
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
"""模型切换 API"""
import gc
import os
from typing import Optional

import torch
from model_paths import MODEL_PATHS
from backend.model_manager import project_registry
from backend.app_context import get_app_context
from backend.api.utils import require_admin


def get_available_models():
    """获取所有可用的模型列表"""
    return {
        'success': True,
        'models': list(MODEL_PATHS.keys())
    }, 200


def _get_device_type() -> str:
    """获取当前设备类型"""
    if torch.cuda.is_available():
        return "cuda"
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        return "mps"
    else:
        return "cpu"


def _restore_env_vars(old_force_int8: Optional[str], old_force_bfloat16: Optional[str]) -> None:
    """恢复环境变量配置"""
    if old_force_int8 is not None:
        os.environ['FORCE_INT8'] = old_force_int8
    else:
        os.environ.pop('FORCE_INT8', None)
    
    if old_force_bfloat16 is not None:
        os.environ['CPU_FORCE_BFLOAT16'] = old_force_bfloat16
    else:
        os.environ.pop('CPU_FORCE_BFLOAT16', None)


def get_current_model():
    """获取当前使用的模型及量化配置"""
    # 使用模块级上下文以获取持久化的模型状态
    context = get_app_context(prefer_module_context=True)
    device_type = _get_device_type()
    
    return {
        'success': True,
        'model': context.model_name,
        'loading': context.model_loading,
        'device_type': device_type,
        'use_int8': os.environ.get('FORCE_INT8') == '1',
        'use_bfloat16': os.environ.get('CPU_FORCE_BFLOAT16') == '1'
    }, 200


@require_admin
def switch_model(switch_request):
    """
    切换模型(需要管理员权限)
    
    Args:
        switch_request: 切换请求字典,包含:
            - model: 目标模型名称
            - use_int8: 是否使用 INT8 量化(可选)
            - use_bfloat16: 是否使用 bfloat16(可选,仅CPU)
    
    Returns:
        (响应字典, 状态码) 元组
    """
    target_model = switch_request.get('model')
    use_int8 = switch_request.get('use_int8', False)
    use_bfloat16 = switch_request.get('use_bfloat16', False)
    
    # 验证请求
    if not target_model:
        return {
            'success': False,
            'message': 'Missing model parameter'
        }, 400
    
    # 检查模型是否可用
    if target_model not in MODEL_PATHS:
        available_models = list(MODEL_PATHS.keys())
        return {
            'success': False,
            'message': f'Model {target_model} does not exist. Available models: {", ".join(available_models)}'
        }, 404
    
    # 获取设备类型
    device_type = _get_device_type()
    
    # 验证量化参数与设备兼容性
    if use_int8 and device_type == "mps":
        return {
            'success': False,
            'message': 'INT8 quantization is not supported on MPS device'
        }, 400
    
    if use_bfloat16 and device_type != "cpu":
        return {
            'success': False,
            'message': 'bfloat16 quantization is only supported on CPU device'
        }, 400
    
    if use_int8 and use_bfloat16:
        return {
            'success': False,
            'message': 'Cannot enable both INT8 and bfloat16 quantization'
        }, 400
    
    # 使用模块级上下文以确保状态修改持久化(不会被后续请求重置)
    context = get_app_context(prefer_module_context=True)
    current_model = context.model_name
    
    # 保存当前环境变量配置(用于回滚)
    old_force_int8 = os.environ.get('FORCE_INT8')
    old_force_bfloat16 = os.environ.get('CPU_FORCE_BFLOAT16')
    
    # 检查是否已经是目标模型且量化配置相同
    current_int8 = os.environ.get('FORCE_INT8') == '1'
    current_bfloat16 = os.environ.get('CPU_FORCE_BFLOAT16') == '1'
    
    if (current_model == target_model and 
        current_int8 == use_int8 and 
        current_bfloat16 == use_bfloat16):
        return {
            'success': True,
            'message': f'Already using model {target_model} (same quantization configuration)',
            'model': target_model
        }, 200
    
    # 检查模型是否正在加载中(初始加载或切换)
    if context.model_loading:
        return {
            'success': False,
            'message': 'Model is currently loading, please try again later'
        }, 503
    
    try:
        # 标记开始加载
        context.set_model_loading(True)
        print(f"🔄 开始切换模型: {current_model} -> {target_model}")
        
        # 设置新的量化环境变量
        if use_int8:
            os.environ['FORCE_INT8'] = '1'
            print(f"   设置量化: INT8")
        else:
            os.environ.pop('FORCE_INT8', None)
        
        if use_bfloat16:
            os.environ['CPU_FORCE_BFLOAT16'] = '1'
            print(f"   设置量化: bfloat16")
        else:
            os.environ.pop('CPU_FORCE_BFLOAT16', None)
        
        # 卸载旧模型
        if current_model and current_model in project_registry:
            print(f"   卸载旧模型: {current_model}")
            project_registry.unload(current_model)
            gc.collect()
            if device_type == "cuda":
                torch.cuda.empty_cache()
            elif device_type == "mps":
                torch.mps.empty_cache()
        
        # 加载新模型
        print(f"   加载新模型: {target_model}")
        project_registry.ensure_loaded(target_model)
        
        # 更新当前模型
        context.set_current_model(target_model)
        
        print(f"✅ 模型切换成功: {target_model}")
        
        return {
            'success': True,
            'message': f'Model switched to {target_model}',
            'model': target_model
        }, 200
        
    except KeyError as e:
        # 模型不存在(虽然前面已经检查过,但以防万一)
        print(f"❌ 模型切换失败: 模型 {target_model} 未注册")
        # 回滚:恢复旧模型名称和环境变量
        context.set_current_model(current_model)
        _restore_env_vars(old_force_int8, old_force_bfloat16)
        return {
            'success': False,
            'message': f'Model {target_model} is not registered'
        }, 404
        
    except Exception as e:
        # 加载失败,尝试回滚
        print(f"❌ 模型切换失败: {e}")
        print(f"   尝试回滚到旧模型: {current_model}")
        
        try:
            # 回滚:恢复环境变量和重新加载旧模型
            _restore_env_vars(old_force_int8, old_force_bfloat16)
            if current_model:
                project_registry.ensure_loaded(current_model)
                context.set_current_model(current_model)
                print(f"✅ 已回滚到旧模型: {current_model}")
        except Exception as rollback_error:
            print(f"⚠️  回滚失败: {rollback_error}")
        
        return {
            'success': False,
            'message': f'Model switch failed: {str(e)}'
        }, 500
        
    finally:
        # 无论成功还是失败,都要清除加载标志
        context.set_model_loading(False)
        gc.collect()