# Spaces: Sleeping | File size: 2,314 Bytes | a2cbcac
from typing import Dict, Any, Union
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from . import config
class ModelFactory:
    """Factory for the chat-model instances used by the application.

    ``create_model`` turns a model configuration into a ``ChatOpenAI`` or
    ``ChatAnthropic`` instance; the ``get_*_model`` helpers apply it to the
    corresponding ``config.model_*`` settings.
    """

    # Temperature applied when the configuration is a bare model-name string.
    _LEGACY_TEMPERATURE = 0.7

    @staticmethod
    def create_model(model_config: Union[Dict[str, Any], str]) -> Union[ChatOpenAI, ChatAnthropic]:
        """Create an LLM model from a provider configuration.

        Args:
            model_config: Either a mapping with the keys ``provider``
                (``'openai'`` or ``'anthropic'``), ``model`` and, optionally,
                ``temperature``; or a plain model-name string (legacy form),
                which is built as an OpenAI model with temperature 0.7.

        Returns:
            The configured ``ChatOpenAI`` or ``ChatAnthropic`` instance.

        Raises:
            TypeError: If ``model_config`` is neither a string nor a
                mapping-like object exposing ``get``.
            ValueError: If the model name is missing/empty or the provider is
                not supported.
        """
        # Legacy form: a bare model name always means OpenAI.
        if isinstance(model_config, str):
            return ChatOpenAI(
                model=model_config,
                temperature=ModelFactory._LEGACY_TEMPERATURE,
            )

        # Fail with a clear message instead of an AttributeError on `.get`
        # (e.g. when a config value was left unset). Duck-typed so that
        # non-dict mapping objects keep working.
        if not hasattr(model_config, 'get'):
            raise TypeError(
                "model_config must be a mapping or a model-name string, "
                f"got {type(model_config).__name__}"
            )

        provider = model_config.get('provider')
        model_name = model_config.get('model')
        temperature = model_config.get('temperature')

        if not model_name:
            raise ValueError("Model name is required in configuration")

        # Only forward the temperature when it is actually configured, so an
        # absent key falls back to the provider's own default rather than an
        # explicit None.
        sampling_kwargs: Dict[str, Any] = {}
        if temperature is not None:
            sampling_kwargs['temperature'] = temperature

        if provider == 'anthropic':
            return ChatAnthropic(
                model_name=model_name,
                timeout=None,
                stop=None,
                **sampling_kwargs,
            )
        if provider == 'openai':
            return ChatOpenAI(
                model=model_name,
                **sampling_kwargs,
            )
        raise ValueError(f"Unsupported provider: {provider}. Supported providers: openai, anthropic")

    @staticmethod
    def get_portfolio_manager_model() -> Union[ChatOpenAI, ChatAnthropic]:
        """Build the model configured in ``config.model_portfolio_manager``."""
        return ModelFactory.create_model(config.model_portfolio_manager)

    @staticmethod
    def get_nlp_features_model() -> Union[ChatOpenAI, ChatAnthropic]:
        """Build the model configured in ``config.model_nlp_features``."""
        return ModelFactory.create_model(config.model_nlp_features)

    @staticmethod
    def get_assess_significance_model() -> Union[ChatOpenAI, ChatAnthropic]:
        """Build the model configured in ``config.model_assess_significance``."""
        return ModelFactory.create_model(config.model_assess_significance)

    @staticmethod
    def get_enhanced_summary_model() -> Union[ChatOpenAI, ChatAnthropic]:
        """Build the model configured in ``config.model_enhanced_summary``."""
        return ModelFactory.create_model(config.model_enhanced_summary)