userIdc2024 commited on
Commit
846ac2c
·
verified ·
1 Parent(s): af32596

Create multimodel_services/model_manager.py

Browse files
multimodel_services/model_manager.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from pathlib import Path
4
+ from typing import Dict, Any, Optional
5
+
6
+ def load_model_configs() -> Dict[str, Any]:
7
+ """Load all model configurations from JSON file"""
8
+ config_path = Path(__file__).parent.parent / "multimodel_configs" / "models.json"
9
+ with open(config_path, 'r') as f:
10
+ return json.load(f)
11
+
12
+ def get_image_to_image_models() -> Dict[str, Any]:
13
+ """Get only models that support image input"""
14
+ models = load_model_configs()
15
+ return {k: v for k, v in models.items() if v.get('supports_image_input', False)}
16
+
17
+ def get_ui_parameters(model_name: str) -> Dict[str, Any]:
18
+ """Get only parameters that should be shown in UI"""
19
+ models = load_model_configs()
20
+ if model_name not in models:
21
+ return {}
22
+
23
+ ui_params = {}
24
+ for param_name, param_config in models[model_name].get('parameters', {}).items():
25
+ if param_config.get('show_in_ui', True):
26
+ ui_params[param_name] = param_config
27
+
28
+ return ui_params
29
+
30
+ def get_all_parameters(model_name: str, ui_params: Optional[Dict] = None) -> Dict[str, Any]:
31
+ """Get all parameters with defaults for API call"""
32
+ models = load_model_configs()
33
+ if model_name not in models:
34
+ return {}
35
+
36
+ all_params = {}
37
+ for param_name, param_config in models[model_name].get('parameters', {}).items():
38
+ if ui_params and param_name in ui_params:
39
+ # Use user-selected value
40
+ all_params[param_name] = ui_params[param_name]
41
+ else:
42
+ # Use default value
43
+ all_params[param_name] = param_config.get('default')
44
+
45
+ return all_params
46
+
47
+ def get_model_display_name(model_name: str) -> str:
48
+ """Get display name for model"""
49
+ models = load_model_configs()
50
+ return models.get(model_name, {}).get('display_name', model_name)
51
+
52
+ def is_gpt_model(model_name: str) -> bool:
53
+ """Check if model is GPT default"""
54
+ return model_name == "gpt_default"