Rulga commited on
Commit
1ca2b19
·
1 Parent(s): 7581420

unification of the model and output of its choice into the configuration

Browse files
Files changed (3) hide show
  1. app.py +72 -1
  2. config/settings.py +33 -34
  3. src/training/model_manager.py +27 -483
app.py CHANGED
@@ -2,7 +2,7 @@ import gradio as gr
2
  import os
3
  from huggingface_hub import InferenceClient
4
  from config.constants import DEFAULT_SYSTEM_MESSAGE
5
- from config.settings import DEFAULT_MODEL, HF_TOKEN
6
  from src.knowledge_base.vector_store import create_vector_store, load_vector_store
7
  from web.training_interface import (
8
  get_models_df,
@@ -233,6 +233,77 @@ with gr.Blocks() as demo:
233
  build_kb_btn.click(build_kb, None, kb_status)
234
  clear_btn.click(lambda: ([], None), None, [chatbot, conversation_id])
235
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
  with gr.Tab("Model Training"):
237
  gr.Markdown("### Model Training Interface")
238
 
 
2
  import os
3
  from huggingface_hub import InferenceClient
4
  from config.constants import DEFAULT_SYSTEM_MESSAGE
5
+ from config.settings import DEFAULT_MODEL, HF_TOKEN, MODEL_CONFIG
6
  from src.knowledge_base.vector_store import create_vector_store, load_vector_store
7
  from web.training_interface import (
8
  get_models_df,
 
233
  build_kb_btn.click(build_kb, None, kb_status)
234
  clear_btn.click(lambda: ([], None), None, [chatbot, conversation_id])
235
 
236
+ with gr.Tab("Model Settings"):
237
+ gr.Markdown("### Model Configuration")
238
+
239
+ with gr.Row():
240
+ with gr.Column(scale=2):
241
+ # Model Information
242
+ gr.Markdown(f"""
243
+ **Current Model:** {MODEL_CONFIG['name']}
244
+
245
+ **Model ID:** `{MODEL_CONFIG['id']}`
246
+
247
+ **Description:** {MODEL_CONFIG['description']}
248
+
249
+ **Type:** {MODEL_CONFIG['type']}
250
+ """)
251
+
252
+ gr.Markdown("### Model Parameters")
253
+ with gr.Row():
254
+ max_length = gr.Slider(
255
+ minimum=1,
256
+ maximum=4096,
257
+ value=MODEL_CONFIG['parameters']['max_length'],
258
+ step=1,
259
+ label="Maximum Length",
260
+ interactive=False
261
+ )
262
+ temperature = gr.Slider(
263
+ minimum=0.1,
264
+ maximum=2.0,
265
+ value=MODEL_CONFIG['parameters']['temperature'],
266
+ step=0.1,
267
+ label="Temperature",
268
+ interactive=False
269
+ )
270
+ with gr.Row():
271
+ top_p = gr.Slider(
272
+ minimum=0.1,
273
+ maximum=1.0,
274
+ value=MODEL_CONFIG['parameters']['top_p'],
275
+ step=0.1,
276
+ label="Top-p",
277
+ interactive=False
278
+ )
279
+ rep_penalty = gr.Slider(
280
+ minimum=1.0,
281
+ maximum=2.0,
282
+ value=MODEL_CONFIG['parameters']['repetition_penalty'],
283
+ step=0.1,
284
+ label="Repetition Penalty",
285
+ interactive=False
286
+ )
287
+
288
+ with gr.Column(scale=1):
289
+ gr.Markdown("### Training Configuration")
290
+ gr.Markdown(f"""
291
+ **Base Model Path:**
292
+ ```
293
+ {MODEL_CONFIG['training']['base_model_path']}
294
+ ```
295
+
296
+ **Fine-tuned Model Path:**
297
+ ```
298
+ {MODEL_CONFIG['training']['fine_tuned_path']}
299
+ ```
300
+
301
+ **LoRA Configuration:**
302
+ - Rank (r): {MODEL_CONFIG['training']['lora_config']['r']}
303
+ - Alpha: {MODEL_CONFIG['training']['lora_config']['lora_alpha']}
304
+ - Dropout: {MODEL_CONFIG['training']['lora_config']['lora_dropout']}
305
+ """)
306
+
307
  with gr.Tab("Model Training"):
308
  gr.Markdown("### Model Training Interface")
309
 
config/settings.py CHANGED
@@ -1,42 +1,41 @@
1
  import os
2
- from dotenv import load_dotenv
3
-
4
- # Debug information
5
- #print("Current directory:", os.getcwd())
6
- env_path = os.path.join(os.getcwd(), '.env')
7
- #print("Path to .env:", env_path)
8
- #print(".env file exists:", os.path.exists(env_path))
9
-
10
- if os.path.exists(env_path):
11
- with open(env_path, 'r') as f:
12
- print("Contents of .env file:", f.read())
13
-
14
- # Load environment variables
15
- load_dotenv(verbose=True)
16
-
17
- # Directory paths
18
- BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
19
- VECTOR_STORE_PATH = os.path.join(BASE_DIR, "data", "vector_store")
20
-
21
- # Добавляем недостающие пути для обучения моделей
22
- MODEL_PATH = os.path.join(BASE_DIR, "models")
23
- TRAINING_OUTPUT_DIR = os.path.join(BASE_DIR, "models", "trained")
24
- MODELS_REGISTRY_PATH = os.path.join(BASE_DIR, "data", "models_registry.json")
25
-
26
- # Create directories if they don't exist
27
- os.makedirs(VECTOR_STORE_PATH, exist_ok=True)
28
- os.makedirs(MODEL_PATH, exist_ok=True)
29
- os.makedirs(TRAINING_OUTPUT_DIR, exist_ok=True)
30
- os.makedirs(os.path.dirname(MODELS_REGISTRY_PATH), exist_ok=True)
31
-
32
- # Model settings
33
- EMBEDDING_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
34
- DEFAULT_MODEL = "HuggingFaceH4/zephyr-7b-beta"
35
 
36
  # API tokens
37
  HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
38
  if not HF_TOKEN:
39
  raise ValueError("HUGGINGFACE_TOKEN not found in environment variables")
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  # Request settings
42
- USER_AGENT = "Status-Law-Assistant/1.0"
 
1
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  # API tokens
4
  HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
5
  if not HF_TOKEN:
6
  raise ValueError("HUGGINGFACE_TOKEN not found in environment variables")
7
 
8
+ # Paths configuration
9
+ MODEL_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "models")
10
+ TRAINING_OUTPUT_DIR = os.path.join(MODEL_PATH, "fine_tuned")
11
+ VECTOR_STORE_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vector_store")
12
+
13
+ # Model configuration
14
+ MODEL_CONFIG = {
15
+ "id": "HuggingFaceH4/zephyr-7b-beta",
16
+ "name": "Zephyr 7B",
17
+ "description": "A state-of-the-art 7B parameter language model",
18
+ "type": "base", # base/fine-tuned
19
+ "parameters": {
20
+ "max_length": 2048,
21
+ "temperature": 0.7,
22
+ "top_p": 0.9,
23
+ "repetition_penalty": 1.1,
24
+ },
25
+ "training": {
26
+ "base_model_path": os.path.join(MODEL_PATH, "zephyr-7b-beta"),
27
+ "fine_tuned_path": os.path.join(TRAINING_OUTPUT_DIR, "zephyr-7b-beta-tuned"),
28
+ "lora_config": {
29
+ "r": 16,
30
+ "lora_alpha": 32,
31
+ "lora_dropout": 0.05,
32
+ "target_modules": ["q_proj", "v_proj", "k_proj", "o_proj"]
33
+ }
34
+ }
35
+ }
36
+
37
+ # Embedding model for vector store
38
+ EMBEDDING_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
39
+
40
  # Request settings
41
+ USER_AGENT = "Status-Law-Assistant/1.0"
src/training/model_manager.py CHANGED
@@ -1,514 +1,58 @@
1
  """
2
- Модуль для управления моделями и их версиями
3
  """
4
 
5
  import os
6
  import json
7
- import shutil
8
  from datetime import datetime
9
  from typing import List, Dict, Any, Tuple, Optional
10
  import logging
11
- from huggingface_hub import HfApi, snapshot_download, hf_hub_download
12
  from transformers import AutoModelForCausalLM, AutoTokenizer
13
- from config.settings import MODEL_PATH, MODELS_REGISTRY_PATH
14
 
15
- # Настройка логирования
16
  logging.basicConfig(
17
  level=logging.INFO,
18
  format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
19
  )
20
  logger = logging.getLogger(__name__)
21
 
22
- class ModelManager:
23
- def __init__(self, registry_path: Optional[str] = None):
24
- """
25
- Инициализация менеджера моделей
26
-
27
- Args:
28
- registry_path: Путь к реестру моделей
29
- """
30
- self.registry_path = registry_path or MODELS_REGISTRY_PATH
31
- self.models_dir = MODEL_PATH
32
-
33
- # Создаем директории, если их нет
34
- os.makedirs(self.registry_path, exist_ok=True)
35
- os.makedirs(self.models_dir, exist_ok=True)
36
-
37
- # Путь к файлу реестра
38
- self.registry_file = os.path.join(self.registry_path, "models_registry.json")
39
-
40
- # Загружаем реестр или создаем новый
41
- self.load_registry()
42
-
43
- def load_registry(self):
44
- """
45
- Загрузка реестра моделей
46
- """
47
- if os.path.exists(self.registry_file):
48
- try:
49
- with open(self.registry_file, "r", encoding="utf-8") as f:
50
- self.registry = json.load(f)
51
- except Exception as e:
52
- logger.error(f"Ошибка загрузки реестра моделей: {str(e)}")
53
- self.registry = {"models": []}
54
- else:
55
- self.registry = {"models": []}
56
-
57
- def save_registry(self):
58
- """
59
- Сохранение реестра моделей
60
- """
61
- try:
62
- with open(self.registry_file, "w", encoding="utf-8") as f:
63
- json.dump(self.registry, f, ensure_ascii=False, indent=2)
64
- except Exception as e:
65
- logger.error(f"Ошибка сохранения реестра моделей: {str(e)}")
66
-
67
- def register_model(
68
- self,
69
- model_id: str,
70
- version: str,
71
- source: str,
72
- description: str = "",
73
- metrics: Optional[Dict[str, Any]] = None,
74
- is_active: bool = False
75
- ) -> Tuple[bool, str]:
76
- """
77
- Регистрация модели в реестре
78
-
79
- Args:
80
- model_id: Идентификатор модели (например, 'saiga_7b_lora')
81
- version: Версия модели
82
- source: Источник модели (например, URL или локальный путь)
83
- description: Описание модели
84
- metrics: Метрики качества модели
85
- is_active: Флаг активности модели
86
-
87
- Returns:
88
- (успех, сообщение)
89
- """
90
- try:
91
- # Создаем запись о модели
92
- model_entry = {
93
- "model_id": model_id,
94
- "version": version,
95
- "source": source,
96
- "description": description,
97
- "metrics": metrics or {},
98
- "is_active": is_active,
99
- "registration_date": datetime.now().isoformat(),
100
- "local_path": os.path.join(self.models_dir, f"{model_id}_{version}")
101
- }
102
-
103
- # Проверяем, есть ли уже такая модель в реестре
104
- for i, model in enumerate(self.registry["models"]):
105
- if model["model_id"] == model_id and model["version"] == version:
106
- # Обновляем существующую запись
107
- self.registry["models"][i] = model_entry
108
- self.save_registry()
109
- return True, f"Модель {model_id} версии {version} обновлена в реестре"
110
-
111
- # Если модель новая, добавляем ее в реестр
112
- self.registry["models"].append(model_entry)
113
-
114
- # Если модель отмечена как активная, деактивируем все другие модели с тем же model_id
115
- if is_active:
116
- for i, model in enumerate(self.registry["models"]):
117
- if model["model_id"] == model_id and model["version"] != version:
118
- self.registry["models"][i]["is_active"] = False
119
-
120
- self.save_registry()
121
- return True, f"Модель {model_id} версии {version} добавлена в реестр"
122
- except Exception as e:
123
- return False, f"Ошибка при регистрации модели: {str(e)}"
124
-
125
- def download_model(
126
- self,
127
- model_id: str,
128
- version: str,
129
- token: Optional[str] = None
130
- ) -> Tuple[bool, str]:
131
- """
132
- Загрузка модели из Hugging Face Hub
133
-
134
- Args:
135
- model_id: Идентификатор модели
136
- version: Версия модели
137
- token: Токен доступа к Hugging Face Hub
138
-
139
- Returns:
140
- (успех, сообщение)
141
- """
142
- try:
143
- # Находим модель в реестре
144
- model_entry = None
145
- for model in self.registry["models"]:
146
- if model["model_id"] == model_id and model["version"] == version:
147
- model_entry = model
148
- break
149
-
150
- if model_entry is None:
151
- return False, f"Модель {model_id} версии {version} не найдена в реестре"
152
-
153
- # Проверяем, что источник - это репозиторий Hugging Face
154
- if not model_entry["source"].startswith("hf://"):
155
- return False, "Источник модели не является репозиторием Hugging Face"
156
-
157
- # Извлекаем имя репозитория
158
- repo_id = model_entry["source"][5:]
159
-
160
- # Путь для сохранения модели
161
- local_path = model_entry["local_path"]
162
-
163
- # Проверяем, существует ли уже директория с моделью
164
- if os.path.exists(local_path):
165
- # Если директория существует, проверяем наличие файлов модели
166
- if os.path.exists(os.path.join(local_path, "pytorch_model.bin")) or \
167
- os.path.exists(os.path.join(local_path, "adapter_model.bin")):
168
- return True, f"Модель {model_id} версии {version} уже загружена"
169
- else:
170
- # Создаем директорию для модели
171
- os.makedirs(local_path, exist_ok=True)
172
-
173
- # Загружаем модель
174
- logger.info(f"Загрузка модели {repo_id} в {local_path}...")
175
- snapshot_download(
176
- repo_id=repo_id,
177
- local_dir=local_path,
178
- token=token
179
- )
180
-
181
- return True, f"Модель {model_id} версии {version} успешно загружена"
182
- except Exception as e:
183
- return False, f"Ошибка при загрузке модели: {str(e)}"
184
-
185
- def get_active_model(self, model_id: str) -> Optional[Dict[str, Any]]:
186
- """
187
- Получение активной версии модели
188
-
189
- Args:
190
- model_id: Идентификатор модели
191
-
192
- Returns:
193
- Словарь с информацией о модели или None, если модель не найдена
194
- """
195
- for model in self.registry["models"]:
196
- if model["model_id"] == model_id and model.get("is_active", False):
197
- return model
198
- return None
199
-
200
- def set_active_model(self, model_id: str, version: str) -> Tuple[bool, str]:
201
- """
202
- Установка активной версии модели
203
-
204
- Args:
205
- model_id: Идентификатор модели
206
- version: Версия модели
207
-
208
- Returns:
209
- (успех, сообщение)
210
- """
211
- try:
212
- # Проверяем, есть ли модель в реестре
213
- model_found = False
214
- for i, model in enumerate(self.registry["models"]):
215
- if model["model_id"] == model_id:
216
- if model["version"] == version:
217
- model_found = True
218
- self.registry["models"][i]["is_active"] = True
219
- else:
220
- self.registry["models"][i]["is_active"] = False
221
-
222
- if not model_found:
223
- return False, f"Модель {model_id} версии {version} не найдена в реестре"
224
-
225
- self.save_registry()
226
- return True, f"Модель {model_id} версии {version} установлена как активная"
227
- except Exception as e:
228
- return False, f"Ош��бка при установке активной модели: {str(e)}"
229
-
230
- def load_model(
231
- self,
232
- model_id: str,
233
- version: Optional[str] = None,
234
- device: str = "cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu"
235
- ) -> Tuple[bool, Any, Any, str]:
236
- """
237
- Загрузка модели и токенизатора
238
-
239
- Args:
240
- model_id: Идентификатор модели
241
- version: Версия модели (если None, загружается активная версия)
242
- device: Устройство для загрузки модели
243
-
244
- Returns:
245
- (успех, модель, токенизатор, сообщение)
246
- """
247
- try:
248
- # Определяем версию модели
249
- if version is None:
250
- model_entry = self.get_active_model(model_id)
251
- if model_entry is None:
252
- return False, None, None, f"Активная версия модели {model_id} не найдена"
253
- else:
254
- model_entry = None
255
- for model in self.registry["models"]:
256
- if model["model_id"] == model_id and model["version"] == version:
257
- model_entry = model
258
- break
259
-
260
- if model_entry is None:
261
- return False, None, None, f"Модель {model_id} версии {version} не найдена в реестре"
262
-
263
- # Проверяем, загружена ли модель локально
264
- local_path = model_entry["local_path"]
265
- if not os.path.exists(local_path) or \
266
- (not os.path.exists(os.path.join(local_path, "pytorch_model.bin")) and \
267
- not os.path.exists(os.path.join(local_path, "adapter_model.bin"))):
268
- # Если модель не загружена, пытаемся загрузить её
269
- success, message = self.download_model(model_id, model_entry["version"])
270
- if not success:
271
- return False, None, None, message
272
-
273
- # Загружаем токенизатор
274
- logger.info(f"Загрузка токенизатора из {local_path}...")
275
- tokenizer = AutoTokenizer.from_pretrained(
276
- local_path,
277
- trust_remote_code=True
278
- )
279
-
280
- # Загружаем модель
281
- logger.info(f"Загрузка модели из {local_path}...")
282
- model = AutoModelForCausalLM.from_pretrained(
283
- local_path,
284
- trust_remote_code=True,
285
- device_map="auto" if device == "cuda" else None
286
- )
287
-
288
- return True, model, tokenizer, f"Модель {model_id} версии {model_entry['version']} успешно загружена"
289
- except Exception as e:
290
- return False, None, None, f"Ошибка при загрузке модели: {str(e)}"
291
-
292
- def delete_model(self, model_id: str, version: str) -> Tuple[bool, str]:
293
- """
294
- Удаление модели из реестра и локального хранилища
295
-
296
- Args:
297
- model_id: Идентификатор модели
298
- version: Версия модели
299
-
300
- Returns:
301
- (успех, сообщение)
302
- """
303
- try:
304
- # Ищем модель в реестре
305
- model_entry = None
306
- model_index = -1
307
- for i, model in enumerate(self.registry["models"]):
308
- if model["model_id"] == model_id and model["version"] == version:
309
- model_entry = model
310
- model_index = i
311
- break
312
-
313
- if model_entry is None:
314
- return False, f"Модель {model_id} версии {version} не найдена в реестре"
315
-
316
- # Проверяем, активна ли модель
317
- if model_entry.get("is_active", False):
318
- return False, "Нельзя удалить активную модель. Сначала установите другую модель как активную."
319
-
320
- # Удаляем директорию с моделью, если она существует
321
- local_path = model_entry["local_path"]
322
- if os.path.exists(local_path):
323
- shutil.rmtree(local_path)
324
-
325
- # Удаляем модель из реестра
326
- self.registry["models"].pop(model_index)
327
- self.save_registry()
328
-
329
- return True, f"Модель {model_id} версии {version} успешно удалена"
330
- except Exception as e:
331
- return False, f"Ошиб��а при удалении модели: {str(e)}"
332
-
333
- def list_models(self, model_id: Optional[str] = None) -> List[Dict[str, Any]]:
334
- """
335
- Получение списка моделей в реестре
336
-
337
- Args:
338
- model_id: Идентификатор модели для фильтрации (если None, возвращаются все модели)
339
-
340
- Returns:
341
- Список словарей с информацией о моделях
342
- """
343
- if model_id is None:
344
- return self.registry["models"]
345
- else:
346
- return [model for model in self.registry["models"] if model["model_id"] == model_id]
347
-
348
- def import_local_model(
349
- self,
350
- source_path: str,
351
- model_id: str,
352
- version: str,
353
- description: str = "",
354
- is_active: bool = False
355
- ) -> Tuple[bool, str]:
356
- """
357
- Импорт локальной модели в реестр
358
-
359
- Args:
360
- source_path: Путь к директории с моделью
361
- model_id: Идентификатор модели
362
- version: Версия модели
363
- description: Описание модели
364
- is_active: Флаг активности модели
365
-
366
- Returns:
367
- (успех, сообщение)
368
- """
369
- try:
370
- # Проверяем существование исходной директории
371
- if not os.path.exists(source_path):
372
- return False, f"Директория {source_path} не существует"
373
-
374
- # Проверяем, что это директория с моделью
375
- if not os.path.exists(os.path.join(source_path, "config.json")):
376
- return False, f"Директория {source_path} не содержит модель трансформера"
377
-
378
- # Создаем путь для модели в нашем хранилище
379
- target_path = os.path.join(self.models_dir, f"{model_id}_{version}")
380
-
381
- # Если директория уже существует, удаляем ее
382
- if os.path.exists(target_path):
383
- shutil.rmtree(target_path)
384
-
385
- # Копируем файлы модели
386
- shutil.copytree(source_path, target_path)
387
-
388
- # Регистрируем модель в реестре
389
- success, message = self.register_model(
390
- model_id=model_id,
391
- version=version,
392
- source=f"local://{source_path}",
393
- description=description,
394
- is_active=is_active
395
- )
396
-
397
- if not success:
398
- # Если регистрация не удалась, удаляем скопированные файлы
399
- shutil.rmtree(target_path)
400
- return False, message
401
-
402
- return True, f"Модель успешно импортирована: {model_id} версии {version}"
403
- except Exception as e:
404
- return False, f"Ошибка при импорте модели: {str(e)}"
405
-
406
- def export_model_metrics(self, output_file: str) -> Tuple[bool, str]:
407
- """
408
- Экспорт метрик всех моделей в JSON файл
409
-
410
- Args:
411
- output_file: Путь к выходному файлу
412
-
413
- Returns:
414
- (успех, сообщение)
415
- """
416
- try:
417
- # Создаем словарь с метриками для каждой модели
418
- metrics_data = {}
419
-
420
- for model in self.registry["models"]:
421
- model_key = f"{model['model_id']}_{model['version']}"
422
- metrics_data[model_key] = {
423
- "model_id": model["model_id"],
424
- "version": model["version"],
425
- "is_active": model.get("is_active", False),
426
- "registration_date": model.get("registration_date", ""),
427
- "metrics": model.get("metrics", {})
428
- }
429
-
430
- # Сохраняем в файл
431
- with open(output_file, "w", encoding="utf-8") as f:
432
- json.dump(metrics_data, f, ensure_ascii=False, indent=2)
433
-
434
- return True, f"Метрики моделей успешно экспортированы в {output_file}"
435
- except Exception as e:
436
- return False, f"Ошибка при экспорте метрик: {str(e)}"
437
-
438
- def update_model_metrics(
439
- self,
440
- model_id: str,
441
- version: str,
442
- metrics: Dict[str, Any]
443
- ) -> Tuple[bool, str]:
444
- """
445
- Обновление метрик модели
446
-
447
- Args:
448
- model_id: Идентификатор модели
449
- version: Версия модели
450
- metrics: Словарь с метриками
451
-
452
- Returns:
453
- (успех, сообщение)
454
- """
455
- try:
456
- # Ищем модель в реестре
457
- model_found = False
458
- for i, model in enumerate(self.registry["models"]):
459
- if model["model_id"] == model_id and model["version"] == version:
460
- # Обновляем метрики
461
- self.registry["models"][i]["metrics"] = metrics
462
- model_found = True
463
- break
464
-
465
- if not model_found:
466
- return False, f"Модель {model_id} версии {version} не найдена в реестре"
467
-
468
- self.save_registry()
469
- return True, f"Метрики модели {model_id} версии {version} успешно обновлены"
470
- except Exception as e:
471
- return False, f"Ошибка при обновлении метрик: {str(e)}"
472
-
473
  def get_model(
474
- model_id: str = "saiga",
475
  version: Optional[str] = None,
476
  device: str = "cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu"
477
  ) -> Tuple[Any, Any, Dict[str, Any]]:
478
  """
479
- Удобная функция для получения модели и токенизатора
480
 
481
  Args:
482
- model_id: Идентификатор модели
483
- version: Версия модели (если None, загружается активная версия)
484
- device: Устройство для загрузки модели
485
 
486
  Returns:
487
- (модель, токенизатор, информация о модели)
488
  """
489
  manager = ModelManager()
490
- success, model, tokenizer, message = manager.load_model(
491
- model_id=model_id,
492
- version=version,
493
- device=device
494
- )
495
 
496
- if not success:
497
- logger.error(message)
498
- raise ValueError(message)
499
 
500
- # Получаем информацию о загруженной модели
501
- if version is None:
502
- model_info = manager.get_active_model(model_id)
503
- else:
504
- for m in manager.list_models(model_id):
505
- if m["version"] == version:
506
- model_info = m
507
- break
508
- else:
509
- model_info = {}
510
-
511
- return model, tokenizer, model_info
 
 
 
 
 
512
 
513
  if __name__ == "__main__":
514
  # Пример использования
@@ -528,4 +72,4 @@ if __name__ == "__main__":
528
  models = manager.list_models()
529
  print(f"В реестре {len(models)} моделей:")
530
  for model in models:
531
- print(f" - {model['model_id']} v{model['version']}: {model['description']}")
 
1
  """
2
+ Module for managing models and their versions
3
  """
4
 
5
  import os
6
  import json
 
7
  from datetime import datetime
8
  from typing import List, Dict, Any, Tuple, Optional
9
  import logging
10
+ from huggingface_hub import HfApi, snapshot_download
11
  from transformers import AutoModelForCausalLM, AutoTokenizer
12
+ from config.settings import MODEL_PATH, MODELS_REGISTRY_PATH, MODEL_CONFIG
13
 
 
14
  logging.basicConfig(
15
  level=logging.INFO,
16
  format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
17
  )
18
  logger = logging.getLogger(__name__)
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  def get_model(
 
21
  version: Optional[str] = None,
22
  device: str = "cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu"
23
  ) -> Tuple[Any, Any, Dict[str, Any]]:
24
  """
25
+ Convenient function to get model and tokenizer
26
 
27
  Args:
28
+ version: Model version (if None, loads base version)
29
+ device: Device for loading model
 
30
 
31
  Returns:
32
+ (model, tokenizer, model_info)
33
  """
34
  manager = ModelManager()
 
 
 
 
 
35
 
36
+ # Use base model if version is None
37
+ model_path = MODEL_CONFIG["training"]["fine_tuned_path"] if version else MODEL_CONFIG["training"]["base_model_path"]
 
38
 
39
+ try:
40
+ tokenizer = AutoTokenizer.from_pretrained(
41
+ model_path,
42
+ trust_remote_code=True
43
+ )
44
+
45
+ model = AutoModelForCausalLM.from_pretrained(
46
+ model_path,
47
+ trust_remote_code=True,
48
+ device_map="auto" if device == "cuda" else None
49
+ )
50
+
51
+ return model, tokenizer, MODEL_CONFIG
52
+
53
+ except Exception as e:
54
+ logger.error(f"Error loading model: {str(e)}")
55
+ raise ValueError(f"Failed to load model: {str(e)}")
56
 
57
  if __name__ == "__main__":
58
  # Пример использования
 
72
  models = manager.list_models()
73
  print(f"В реестре {len(models)} моделей:")
74
  for model in models:
75
+ print(f" - {model['model_id']} v{model['version']}: {model['description']}")