File size: 6,508 Bytes
461adca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
492b465
 
a765e3e
 
492b465
461adca
 
 
 
 
a765e3e
 
 
 
461adca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
"""
Dynamic model enums generated from YAML configuration.
This module provides backward compatibility while using YAML as single source of truth.
"""

from enum import Enum
from typing import Dict, List, Optional
from .loader import ConfigLoader


class ModelRegistry:
    """Singleton registry of model names read from the YAML configuration.

    On first instantiation the registry loads ``settings.models.generation``
    and ``settings.models.analysis`` through ``ConfigLoader``. For each
    section it builds an ``{ENUM_KEY: model_name}`` mapping and records the
    name of the model flagged ``default``. Later instantiations return the
    same object.
    """

    _instance = None
    # Class-level fallbacks only: _load_models() binds fresh per-instance
    # containers, so these shared dicts are never mutated and a reload cannot
    # leave stale entries behind.
    _generation_models: Dict[str, str] = {}
    _analysis_models: Dict[str, str] = {}
    _default_generation_model: Optional[str] = None
    _default_analysis_model: Optional[str] = None

    # Provider attributes scanned in each models section, in this order.
    _PROVIDERS = ('openai', 'anthropic', 'gemini', 'deepseek')

    # Hand-picked enum keys for well-known model names; anything not listed
    # here goes through the generic fallback in _create_enum_key().
    _KNOWN_KEYS: Dict[str, str] = {
        'gpt-5.2': 'GPT5_2',
        'gpt-5-mini': 'GPT5_MINI',
        'gpt-4.1': 'GPT4_1',
        'gpt-4o': 'GPT4o',
        'gpt-4o-mini': 'GPT4o_MINI',
        'claude-opus-4-6': 'CLAUDE_OPUS_4_6',
        'claude-sonnet-4-6': 'CLAUDE_SONNET_4_6',
        'claude-haiku-4-5-20251001': 'CLAUDE_HAIKU_4_5',
        'gemini-3-flash-preview': 'GEMINI_3_FLASH',
        'gemini-3-pro-preview': 'GEMINI_3_PRO',
        'deepseek-chat': 'DEEPSEEK_CHAT',
        'deepseek-reasoner': 'DEEPSEEK_REASONER',
    }

    def __new__(cls):
        if cls._instance is None:
            instance = super().__new__(cls)
            # Publish the singleton only after loading succeeds; otherwise a
            # failed config load would be cached as an empty registry forever.
            instance._load_models()
            cls._instance = instance
        return cls._instance

    def _load_models(self):
        """Load generation and analysis models from the YAML configuration."""
        loader = ConfigLoader()
        settings = loader.load_config(validate_api_keys=False)

        self._generation_models, self._default_generation_model = (
            self._collect_models(settings.models.generation)
        )
        self._analysis_models, self._default_analysis_model = (
            self._collect_models(settings.models.analysis)
        )

    @classmethod
    def _collect_models(cls, section):
        """Build the ``{enum_key: model_name}`` map for one models section.

        Args:
            section: Config object with one attribute per provider in
                ``_PROVIDERS``. Each attribute is a list of entries exposing
                ``.name`` and ``.default``. A missing or ``None`` provider
                attribute is treated as an empty list.

        Returns:
            A tuple ``(models, default_name)``. ``default_name`` is the last
            model flagged ``default``, or ``None`` if no model is flagged.
        """
        models: Dict[str, str] = {}
        default_name: Optional[str] = None
        for provider in cls._PROVIDERS:
            for model in getattr(section, provider, None) or []:
                model_name = model.name
                base_key = cls._create_enum_key(model_name, provider)
                # Disambiguate instead of overwriting, so two models that map
                # to the same key (e.g. several generic fine-tunes) both survive.
                models[cls._unique_key(models, base_key, model_name)] = model_name
                if getattr(model, 'default', False):
                    default_name = model_name
        return models, default_name

    @staticmethod
    def _unique_key(existing: Dict[str, str], key: str, model_name: str) -> str:
        """Return *key*, or *key* suffixed with ``_2``, ``_3``, ... until it is
        unused or already maps to *model_name* in *existing*."""
        candidate = key
        suffix = 2
        while candidate in existing and existing[candidate] != model_name:
            candidate = f"{key}_{suffix}"
            suffix += 1
        return candidate

    @staticmethod
    def _create_enum_key(model_name: str, provider: str) -> str:
        """Create an enum-friendly key from a model name.

        ``provider`` is accepted for API compatibility but is not used.
        Fine-tuned (``ft:``) models map to fixed keys. Known names come from
        ``_KNOWN_KEYS``. Any other name is upper-cased, with ``-`` and ``.``
        replaced by ``_``.
        """
        if model_name.startswith('ft:'):
            if 'lp-1700-part-cd-120' in model_name:
                return 'GPT4o_MINI_LP'
            if 'legal-position-1700' in model_name:
                return 'GPT4o_LP'
            # Generic fine-tuned model
            return 'GPT4o_FT'

        known = ModelRegistry._KNOWN_KEYS.get(model_name)
        if known is not None:
            return known
        # Fallback: convert to uppercase and replace hyphens/dots
        return model_name.upper().replace('-', '_').replace('.', '_')

    def get_generation_models(self) -> Dict[str, str]:
        """Return a copy of the ``{enum_key: model_name}`` generation map."""
        return self._generation_models.copy()

    def get_analysis_models(self) -> Dict[str, str]:
        """Return a copy of the ``{enum_key: model_name}`` analysis map."""
        return self._analysis_models.copy()

    def get_default_generation_model(self) -> Optional[str]:
        """Return the name of the default generation model, if any."""
        return self._default_generation_model

    def get_default_analysis_model(self) -> Optional[str]:
        """Return the name of the default analysis model, if any."""
        return self._default_analysis_model

    def get_models_by_provider(self, provider: str, model_type: str = 'generation') -> List[str]:
        """Return the model names configured for *provider*.

        The configuration is re-read on every call. ``model_type`` selects the
        ``generation`` section; any other value selects ``analysis``. An
        unknown or ``None`` provider yields an empty list.
        """
        loader = ConfigLoader()
        settings = loader.load_config(validate_api_keys=False)

        if model_type == 'generation':
            section = settings.models.generation
        else:
            section = settings.models.analysis

        return [model.name for model in getattr(section, provider, None) or []]


# Build the registry once at import time; it reads the YAML configuration.
_registry = ModelRegistry()

# str-valued enum of generation models, keyed by enum-friendly names.
GenerationModelName = Enum(
    'GenerationModelName',
    _registry.get_generation_models(),
    type=str
)

# str-valued enum of analysis models, built the same way.
AnalysisModelName = Enum(
    'AnalysisModelName',
    _registry.get_analysis_models(),
    type=str
)

# Names of the models flagged as default in the configuration (or None).
_default_gen = _registry.get_default_generation_model()
_default_ana = _registry.get_default_analysis_model()

# Resolve each default name to its enum member. The result stays None when no
# default is configured or no member carries that value.
DEFAULT_GENERATION_MODEL = (
    next((m for m in GenerationModelName if m.value == _default_gen), None)
    if _default_gen
    else None
)
DEFAULT_ANALYSIS_MODEL = (
    next((m for m in AnalysisModelName if m.value == _default_ana), None)
    if _default_ana
    else None
)


# Helper functions for backward compatibility
def get_generation_models_by_provider(provider: str) -> List[str]:
    """Return the generation model names configured for *provider*.

    Kept as a module-level function for backward compatibility; it delegates
    to the shared registry.
    """
    registry = _registry
    return registry.get_models_by_provider(provider, model_type='generation')


def get_analysis_models_by_provider(provider: str) -> List[str]:
    """Return the analysis model names configured for *provider*.

    Kept as a module-level function for backward compatibility; it delegates
    to the shared registry.
    """
    registry = _registry
    return registry.get_models_by_provider(provider, model_type='analysis')


# Public names exported by ``from <module> import *``.
__all__ = [
    # Dynamically built enums and their configured defaults
    'GenerationModelName', 'AnalysisModelName',
    'DEFAULT_GENERATION_MODEL', 'DEFAULT_ANALYSIS_MODEL',
    # Registry class and backward-compatible helpers
    'ModelRegistry',
    'get_generation_models_by_provider', 'get_analysis_models_by_provider',
]