Rulga committed on
Commit
cd4842a
·
1 Parent(s): 5e4f0dd

Implement ModelManager class for model registration, listing, activation, and deletion

Browse files
Files changed (1) hide show
  1. src/training/model_manager.py +154 -5
src/training/model_manager.py CHANGED
@@ -17,6 +17,158 @@ logging.basicConfig(
17
  )
18
  logger = logging.getLogger(__name__)
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  def get_model(
21
  version: Optional[str] = None,
22
  device: str = "cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu"
@@ -31,12 +183,9 @@ def get_model(
31
  Returns:
32
  (model, tokenizer, model_info)
33
  """
34
- manager = ModelManager()
35
-
36
- # Use base model if version is None
37
- model_path = MODEL_CONFIG["training"]["fine_tuned_path"] if version else MODEL_CONFIG["training"]["base_model_path"]
38
-
39
  try:
 
 
40
  tokenizer = AutoTokenizer.from_pretrained(
41
  model_path,
42
  trust_remote_code=True
 
17
  )
18
  logger = logging.getLogger(__name__)
19
 
20
+ class ModelManager:
21
+ def __init__(self):
22
+ """Initialize model manager"""
23
+ self.registry_path = MODELS_REGISTRY_PATH
24
+ os.makedirs(os.path.dirname(self.registry_path), exist_ok=True)
25
+
26
+ # Create empty registry if it doesn't exist
27
+ if not os.path.exists(self.registry_path):
28
+ self._save_registry([])
29
+
30
+ def _load_registry(self) -> List[Dict[str, Any]]:
31
+ """Load models registry"""
32
+ try:
33
+ if os.path.exists(self.registry_path):
34
+ with open(self.registry_path, 'r') as f:
35
+ return json.load(f)
36
+ return []
37
+ except Exception as e:
38
+ logger.error(f"Error loading registry: {str(e)}")
39
+ return []
40
+
41
+ def _save_registry(self, registry: List[Dict[str, Any]]) -> bool:
42
+ """Save models registry"""
43
+ try:
44
+ with open(self.registry_path, 'w') as f:
45
+ json.dump(registry, f, indent=2)
46
+ return True
47
+ except Exception as e:
48
+ logger.error(f"Error saving registry: {str(e)}")
49
+ return False
50
+
51
+ def register_model(
52
+ self,
53
+ model_id: str,
54
+ version: str,
55
+ source: str,
56
+ description: str = "",
57
+ is_active: bool = False
58
+ ) -> Tuple[bool, str]:
59
+ """
60
+ Register new model version
61
+
62
+ Args:
63
+ model_id: Model identifier
64
+ version: Model version
65
+ source: Model source (local path or Hugging Face model id)
66
+ description: Model description
67
+ is_active: Set as active model
68
+
69
+ Returns:
70
+ (success, message)
71
+ """
72
+ try:
73
+ registry = self._load_registry()
74
+
75
+ # Check if model version already exists
76
+ for model in registry:
77
+ if model["model_id"] == model_id and model["version"] == version:
78
+ return False, f"Model {model_id} version {version} already exists"
79
+
80
+ # Add new model
81
+ registry.append({
82
+ "model_id": model_id,
83
+ "version": version,
84
+ "source": source,
85
+ "description": description,
86
+ "is_active": is_active,
87
+ "registration_date": datetime.now().isoformat()
88
+ })
89
+
90
+ # If this model is set as active, deactivate others
91
+ if is_active:
92
+ for model in registry[:-1]: # Skip the last one (just added)
93
+ if model["model_id"] == model_id:
94
+ model["is_active"] = False
95
+
96
+ # Save registry
97
+ if self._save_registry(registry):
98
+ return True, f"Model {model_id} version {version} registered successfully"
99
+ return False, "Failed to save registry"
100
+
101
+ except Exception as e:
102
+ return False, f"Error registering model: {str(e)}"
103
+
104
+ def list_models(self) -> List[Dict[str, Any]]:
105
+ """Get list of registered models"""
106
+ return self._load_registry()
107
+
108
+ def set_active_model(self, model_id: str, version: str) -> Tuple[bool, str]:
109
+ """
110
+ Set model version as active
111
+
112
+ Args:
113
+ model_id: Model identifier
114
+ version: Model version
115
+
116
+ Returns:
117
+ (success, message)
118
+ """
119
+ try:
120
+ registry = self._load_registry()
121
+ model_found = False
122
+
123
+ # Update active status
124
+ for model in registry:
125
+ if model["model_id"] == model_id:
126
+ model["is_active"] = (model["version"] == version)
127
+ if model["version"] == version:
128
+ model_found = True
129
+
130
+ if not model_found:
131
+ return False, f"Model {model_id} version {version} not found"
132
+
133
+ # Save registry
134
+ if self._save_registry(registry):
135
+ return True, f"Model {model_id} version {version} set as active"
136
+ return False, "Failed to save registry"
137
+
138
+ except Exception as e:
139
+ return False, f"Error setting active model: {str(e)}"
140
+
141
+ def delete_model(self, model_id: str, version: str) -> Tuple[bool, str]:
142
+ """
143
+ Delete model version
144
+
145
+ Args:
146
+ model_id: Model identifier
147
+ version: Model version
148
+
149
+ Returns:
150
+ (success, message)
151
+ """
152
+ try:
153
+ registry = self._load_registry()
154
+
155
+ # Find and remove model
156
+ for i, model in enumerate(registry):
157
+ if model["model_id"] == model_id and model["version"] == version:
158
+ if model["is_active"]:
159
+ return False, "Cannot delete active model"
160
+ registry.pop(i)
161
+
162
+ # Save registry
163
+ if self._save_registry(registry):
164
+ return True, f"Model {model_id} version {version} deleted"
165
+ return False, "Failed to save registry"
166
+
167
+ return False, f"Model {model_id} version {version} not found"
168
+
169
+ except Exception as e:
170
+ return False, f"Error deleting model: {str(e)}"
171
+
172
  def get_model(
173
  version: Optional[str] = None,
174
  device: str = "cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu"
 
183
  Returns:
184
  (model, tokenizer, model_info)
185
  """
 
 
 
 
 
186
  try:
187
+ model_path = MODEL_CONFIG["training"]["fine_tuned_path"] if version else MODEL_CONFIG["training"]["base_model_path"]
188
+
189
  tokenizer = AutoTokenizer.from_pretrained(
190
  model_path,
191
  trust_remote_code=True