# Snap-Solver / models/factory.py
# renxsh
# init
# f1b4581
from typing import Dict, Type, Any, Optional
import json
import os
import importlib
from .base import BaseModel
from .mathpix import MathpixModel # MathpixModel需要直接导入,因为它是特殊OCR工具
from .baidu_ocr import BaiduOCRModel # 百度OCR也是特殊OCR工具,直接导入
class ModelFactory:
    """Registry and factory for the model classes used by the application.

    Model metadata lives in class-level dictionaries that are filled by
    ``initialize()`` (from ``config/models.json``) and ``register_model()``.
    ``create_model()`` instantiates the registered class for a model id.
    """

    # model_id -> metadata dict with keys 'class', 'is_multimodal',
    # 'is_reasoning', 'display_name', 'description'; config-loaded models also
    # carry 'provider_id', and the built-in OCR tools carry 'is_ocr_only'.
    _models: Dict[str, Dict[str, Any]] = {}
    # provider_id (as written in the config) -> model class for that provider
    _class_map: Dict[str, Type[BaseModel]] = {}

    @classmethod
    def initialize(cls) -> None:
        """Load provider and model metadata from ``config/models.json``.

        For every provider entry that declares a ``class_name``, the module
        ``.<provider_id lower-cased>`` of this package is imported and the
        named class is stored in ``_class_map``. Every model entry whose
        ``provider`` resolved to a class is then stored in ``_models``.
        The built-in OCR tools (Mathpix, Baidu OCR) are always added.

        A provider whose module or class cannot be loaded is skipped (its
        models are left out) instead of aborting the whole load. If the
        config cannot be read or parsed, ``_initialize_defaults`` is used.
        """
        try:
            config_path = os.path.join(os.path.dirname(__file__), '..', 'config', 'models.json')
            with open(config_path, 'r', encoding='utf-8') as f:
                config = json.load(f)

            # Load provider info and the provider -> class mapping.
            for provider_id, provider_info in config.get('providers', {}).items():
                class_name = provider_info.get('class_name')
                if not class_name:
                    continue
                try:
                    # Dynamically import the provider's module from this package.
                    module = importlib.import_module(f'.{provider_id.lower()}', package=__package__)
                    cls._class_map[provider_id] = getattr(module, class_name)
                except (ImportError, AttributeError) as e:
                    # One broken/missing provider (e.g. an SDK that is not
                    # installed) must not take every other model down with it.
                    print(f"加载模型提供商 {provider_id} 失败,已跳过: {str(e)}")

            # Load model info for every model whose provider class is available.
            for model_id, model_info in config.get('models', {}).items():
                provider_id = model_info.get('provider')
                if provider_id and provider_id in cls._class_map:
                    cls._models[model_id] = {
                        'class': cls._class_map[provider_id],
                        'provider_id': provider_id,
                        'is_multimodal': model_info.get('supportsMultimodal', False),
                        'is_reasoning': model_info.get('isReasoning', False),
                        'display_name': model_info.get('name', model_id),
                        'description': model_info.get('description', '')
                    }

            # Special OCR tools are not defined in the config file.
            cls._register_ocr_tools()
            print(f"已从配置加载 {len(cls._models)} 个模型")
        except Exception as e:
            print(f"加载模型配置失败: {str(e)}")
            cls._initialize_defaults()

    @classmethod
    def _register_ocr_tools(cls) -> None:
        """Add the built-in OCR-only tools (Mathpix, Baidu OCR) to ``_models``."""
        # Mathpix OCR tool
        cls._models['mathpix'] = {
            'class': MathpixModel,
            'is_multimodal': True,
            'is_reasoning': False,
            'display_name': 'Mathpix OCR',
            'description': '数学公式识别工具,适用于复杂数学内容',
            'is_ocr_only': True
        }
        # Baidu OCR tool
        cls._models['baidu-ocr'] = {
            'class': BaiduOCRModel,
            'is_multimodal': True,
            'is_reasoning': False,
            'display_name': '百度OCR',
            'description': '通用文字识别工具,支持中文识别',
            'is_ocr_only': True
        }

    @classmethod
    def _initialize_defaults(cls) -> None:
        """Fallback used when the model config cannot be loaded.

        Replaces the model table with an empty one and re-adds only the
        built-in OCR tools; no other models are hard-coded.
        """
        print("配置加载失败,使用空模型列表")
        cls._models = {}
        # MathpixModel / BaiduOCRModel are already bound by the module-level
        # imports, so no re-import is needed here.
        cls._register_ocr_tools()

    @classmethod
    def create_model(cls, model_name: str, api_key: str, temperature: float = 0.7,
                     system_prompt: Optional[str] = None, language: Optional[str] = None,
                     api_base_url: Optional[str] = None) -> BaseModel:
        """
        Create a model instance based on the model name.

        Args:
            model_name: The identifier for the model
            api_key: The API key for the model service
            temperature: The temperature to use for generation
            system_prompt: The system prompt to use
            language: The preferred language for responses
            api_base_url: The base URL for API requests

        Returns:
            A model instance

        Raises:
            ValueError: If ``model_name`` is not registered.
        """
        if model_name not in cls._models:
            raise ValueError(f"Unknown model: {model_name}")

        model_info = cls._models[model_name]
        model_class = model_info['class']
        provider_id = model_info.get('provider_id')
        lowered = model_name.lower()

        # Arguments every model class receives.
        kwargs: Dict[str, Any] = {
            'api_key': api_key,
            'temperature': temperature,
            'system_prompt': system_prompt,
        }

        # Branch order matters: names are matched by substring, first hit wins.
        if provider_id == 'openai':
            kwargs.update(language=language, api_base_url=api_base_url,
                          model_identifier=model_name)
        elif 'deepseek' in lowered:
            # DeepSeek classes need the concrete model name.
            kwargs.update(language=language, model_name=model_name,
                          api_base_url=api_base_url)
        elif 'qwen' in lowered or 'qvq' in lowered or 'alibaba' in lowered:
            # Alibaba classes need the model name; api_base_url is not passed.
            # NOTE(review): confirm omitting api_base_url here is intentional.
            kwargs.update(language=language, model_name=model_name)
        elif 'gemini' in lowered or 'google' in lowered or 'doubao' in lowered:
            # Google and Doubao classes need the model name and base URL.
            kwargs.update(language=language, model_name=model_name,
                          api_base_url=api_base_url)
        elif model_name in ('mathpix', 'baidu-ocr'):
            # OCR tools get only the common arguments (no language/base URL).
            # Baidu OCR's api_key may use the "API_KEY:SECRET_KEY" format.
            pass
        elif 'claude' in lowered or 'anthropic' in lowered:
            # Anthropic classes take the model id as model_identifier.
            kwargs.update(language=language, api_base_url=api_base_url,
                          model_identifier=model_name)
        else:
            # Any other model gets only the standard arguments.
            kwargs.update(language=language, api_base_url=api_base_url)

        return model_class(**kwargs)

    @classmethod
    def get_available_models(cls) -> list[Dict[str, Any]]:
        """Return a list of available models with their information.

        OCR-only tools are excluded.
        """
        return [
            {
                'id': model_id,
                'display_name': info.get('display_name', model_id),
                'description': info.get('description', ''),
                'is_multimodal': info.get('is_multimodal', False),
                'is_reasoning': info.get('is_reasoning', False)
            }
            for model_id, info in cls._models.items()
            if not info.get('is_ocr_only', False)
        ]

    @classmethod
    def get_model_ids(cls) -> list[str]:
        """Return a list of available model identifiers (OCR-only tools excluded)."""
        return [model_id for model_id, info in cls._models.items()
                if not info.get('is_ocr_only', False)]

    @classmethod
    def is_multimodal(cls, model_name: str) -> bool:
        """Return whether the model accepts multimodal (image) input; False if unknown."""
        return cls._models.get(model_name, {}).get('is_multimodal', False)

    @classmethod
    def is_reasoning(cls, model_name: str) -> bool:
        """Return whether the model is a reasoning model; False if unknown."""
        return cls._models.get(model_name, {}).get('is_reasoning', False)

    @classmethod
    def get_model_display_name(cls, model_name: str) -> str:
        """Return the model's display name, or ``model_name`` itself if unknown."""
        return cls._models.get(model_name, {}).get('display_name', model_name)

    @classmethod
    def register_model(cls, model_name: str, model_class: Type[BaseModel],
                       is_multimodal: bool = False, is_reasoning: bool = False,
                       display_name: Optional[str] = None, description: Optional[str] = None) -> None:
        """
        Register a new model type with the factory.

        The entry has no 'provider_id', so ``create_model`` selects the
        constructor arguments from the model name alone.

        Args:
            model_name: The identifier for the model
            model_class: The model class to register
            is_multimodal: Whether the model supports image input
            is_reasoning: Whether the model provides reasoning process
            display_name: Human-readable name for the model
            description: Description of the model
        """
        cls._models[model_name] = {
            'class': model_class,
            'is_multimodal': is_multimodal,
            'is_reasoning': is_reasoning,
            'display_name': display_name or model_name,
            'description': description or ''
        }