MinAA committed on
Commit
859c222
·
1 Parent(s): 3832d6e
Files changed (1) hide show
  1. app.py +110 -29
app.py CHANGED
@@ -11,24 +11,90 @@ import warnings
11
  import time
12
  import inspect
13
  from datetime import datetime
 
14
  warnings.filterwarnings("ignore")
15
 
16
- # Кэш для хранения загруженных моделей
17
- model_cache = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  # История выполнения моделей
20
  history = []
21
  MAX_HISTORY_SIZE = 50
22
 
23
  def get_pipeline(task, model_name, **kwargs):
24
- """Загрузка pipeline с кэшированием"""
25
  cache_key = f"{task}_{model_name}"
26
- if cache_key not in model_cache:
 
27
  try:
28
- model_cache[cache_key] = pipeline(task, model=model_name, **kwargs)
 
29
  except Exception as e:
30
  raise Exception(f"Ошибка загрузки модели: {str(e)}")
31
- return model_cache[cache_key]
32
 
33
  def measure_time_and_save(task_name):
34
  """Декоратор для измерения времени выполнения и сохранения в историю"""
@@ -214,8 +280,9 @@ def audio_classifier(audio, model_name):
214
  try:
215
  classifier = get_pipeline("audio-classification", model_name)
216
  result = classifier(audio)
217
- if isinstance(result, list):
218
- result = result[0]
 
219
  output = "Результаты классификации:\n"
220
  for item in result[:5]:
221
  output += f"{item['label']}: {item['score']:.4f}\n"
@@ -230,12 +297,14 @@ def audio_zero_shot_classifier(audio, candidate_labels, model_name):
230
  # Используем CLAP для zero-shot классификации аудио
231
  from transformers import ClapProcessor, ClapModel
232
  cache_key = f"audio_zero_shot_{model_name}"
233
- if cache_key not in model_cache:
 
234
  processor = ClapProcessor.from_pretrained(model_name)
235
  model = ClapModel.from_pretrained(model_name)
236
- model_cache[cache_key] = (processor, model)
 
237
 
238
- processor, model = model_cache[cache_key]
239
  labels = [label.strip() for label in candidate_labels.split(",")]
240
 
241
  inputs = processor(text=labels, audios=audio, return_tensors="pt", padding=True)
@@ -273,15 +342,17 @@ def speech_synthesis(text, model_name):
273
  from datasets import load_dataset
274
 
275
  cache_key = f"tts_{model_name}"
276
- if cache_key not in model_cache:
 
277
  processor = SpeechT5Processor.from_pretrained(model_name)
278
  model = SpeechT5ForTextToSpeech.from_pretrained(model_name)
279
  vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
280
  embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
281
  speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
282
- model_cache[cache_key] = (processor, model, vocoder, speaker_embeddings)
 
283
 
284
- processor, model, vocoder, speaker_embeddings = model_cache[cache_key]
285
  inputs = processor(text=text, return_tensors="pt")
286
  with torch.no_grad():
287
  speech = model.generate_speech(inputs["input_ids"], speaker_embeddings, vocoder=vocoder)
@@ -331,12 +402,14 @@ def image_text_matching(image, text, model_name):
331
  """Сопоставление изображения и текста"""
332
  try:
333
  cache_key = f"clip_{model_name}"
334
- if cache_key not in model_cache:
 
335
  processor = CLIPProcessor.from_pretrained(model_name)
336
  model = CLIPModel.from_pretrained(model_name)
337
- model_cache[cache_key] = (processor, model)
 
338
 
339
- processor, model = model_cache[cache_key]
340
  inputs = processor(text=[text], images=image, return_tensors="pt", padding=True)
341
 
342
  with torch.no_grad():
@@ -355,12 +428,14 @@ def image_captioning(image, model_name):
355
  try:
356
  if "blip" in model_name.lower():
357
  cache_key = f"caption_blip_{model_name}"
358
- if cache_key not in model_cache:
 
359
  processor = BlipProcessor.from_pretrained(model_name)
360
  model = BlipForConditionalGeneration.from_pretrained(model_name)
361
- model_cache[cache_key] = (processor, model)
 
362
 
363
- processor, model = model_cache[cache_key]
364
  inputs = processor(image, return_tensors="pt")
365
  out = model.generate(**inputs, max_length=50)
366
  caption = processor.decode(out[0], skip_special_tokens=True)
@@ -380,12 +455,14 @@ def visual_qa(image, question, model_name):
380
  try:
381
  if "vilt" in model_name.lower():
382
  cache_key = f"vqa_vilt_{model_name}"
383
- if cache_key not in model_cache:
 
384
  processor = ViltProcessor.from_pretrained(model_name)
385
  model = ViltForQuestionAnswering.from_pretrained(model_name)
386
- model_cache[cache_key] = (processor, model)
 
387
 
388
- processor, model = model_cache[cache_key]
389
  inputs = processor(image, question, return_tensors="pt")
390
  outputs = model(**inputs)
391
  logits = outputs.logits
@@ -394,12 +471,14 @@ def visual_qa(image, question, model_name):
394
  return f"Ответ: {answer}"
395
  elif "blip" in model_name.lower():
396
  cache_key = f"vqa_blip_{model_name}"
397
- if cache_key not in model_cache:
 
398
  processor = BlipProcessor.from_pretrained(model_name)
399
  model = BlipForConditionalGeneration.from_pretrained(model_name)
400
- model_cache[cache_key] = (processor, model)
 
401
 
402
- processor, model = model_cache[cache_key]
403
  inputs = processor(image, question, return_tensors="pt")
404
  out = model.generate(**inputs, max_length=50)
405
  answer = processor.decode(out[0], skip_special_tokens=True)
@@ -418,12 +497,14 @@ def image_zero_shot_classification(image, candidate_labels, model_name):
418
  """Zero-shot классификация изображений"""
419
  try:
420
  cache_key = f"clip_zs_{model_name}"
421
- if cache_key not in model_cache:
 
422
  processor = CLIPProcessor.from_pretrained(model_name)
423
  model = CLIPModel.from_pretrained(model_name)
424
- model_cache[cache_key] = (processor, model)
 
425
 
426
- processor, model = model_cache[cache_key]
427
  labels = [label.strip() for label in candidate_labels.split(",")]
428
  inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
429
 
 
11
  import time
12
  import inspect
13
  from datetime import datetime
14
+ from collections import OrderedDict
15
  warnings.filterwarnings("ignore")
16
 
17
+ # LRU кэш для хранения загруженных моделей
18
+ class LRUCache:
19
+ """LRU (Least Recently Used) кэш для ограничения использования памяти"""
20
+ def __init__(self, maxsize=5):
21
+ """
22
+ Args:
23
+ maxsize: Максимальное количество моделей в кэше
24
+ """
25
+ self.cache = OrderedDict()
26
+ self.maxsize = maxsize
27
+
28
+ def get(self, key):
29
+ """Получить модель из кэша"""
30
+ if key not in self.cache:
31
+ return None
32
+ # Перемещаем элемент в конец (как недавно использованный)
33
+ self.cache.move_to_end(key)
34
+ return self.cache[key]
35
+
36
+ def put(self, key, value):
37
+ """Добавить модель в кэш"""
38
+ if key in self.cache:
39
+ # Если ключ уже есть, обновляем и перемещаем в конец
40
+ self.cache.move_to_end(key)
41
+ self.cache[key] = value
42
+ else:
43
+ # Если кэш полон, удаляем самый старый элемент (первый в OrderedDict)
44
+ if len(self.cache) >= self.maxsize:
45
+ oldest_key = next(iter(self.cache))
46
+ # Освобождаем память от модели
47
+ old_value = self.cache.pop(oldest_key)
48
+ del old_value
49
+ # Также очищаем кэш CUDA если используется GPU
50
+ if torch.cuda.is_available():
51
+ torch.cuda.empty_cache()
52
+ self.cache[key] = value
53
+
54
+ def __contains__(self, key):
55
+ """Проверка наличия ключа в кэше"""
56
+ return key in self.cache
57
+
58
+ def __getitem__(self, key):
59
+ """Получить элемент через []"""
60
+ value = self.get(key)
61
+ if value is None:
62
+ raise KeyError(key)
63
+ return value
64
+
65
+ def __setitem__(self, key, value):
66
+ """Установить элемент через []"""
67
+ self.put(key, value)
68
+
69
+ def clear(self):
70
+ """Очистить кэш"""
71
+ self.cache.clear()
72
+ if torch.cuda.is_available():
73
+ torch.cuda.empty_cache()
74
+
75
+ def size(self):
76
+ """Текущий размер кэша"""
77
+ return len(self.cache)
78
+
79
+ # Создаем LRU кэш с максимальным размером 5 моделей
80
+ # Можно изменить это значение в зависимости от доступной памяти
81
+ model_cache = LRUCache(maxsize=5)
82
 
83
  # История выполнения моделей
84
  history = []
85
  MAX_HISTORY_SIZE = 50
86
 
87
  def get_pipeline(task, model_name, **kwargs):
88
+ """Загрузка pipeline с LRU кэшированием"""
89
  cache_key = f"{task}_{model_name}"
90
+ cached_model = model_cache.get(cache_key)
91
+ if cached_model is None:
92
  try:
93
+ cached_model = pipeline(task, model=model_name, **kwargs)
94
+ model_cache.put(cache_key, cached_model)
95
  except Exception as e:
96
  raise Exception(f"Ошибка загрузки модели: {str(e)}")
97
+ return cached_model
98
 
99
  def measure_time_and_save(task_name):
100
  """Декоратор для измерения времени выполнения и сохранения в историю"""
 
280
  try:
281
  classifier = get_pipeline("audio-classification", model_name)
282
  result = classifier(audio)
283
+ # audio-classification pipeline возвращает список словарей
284
+ if not isinstance(result, list):
285
+ result = [result]
286
  output = "Результаты классификации:\n"
287
  for item in result[:5]:
288
  output += f"{item['label']}: {item['score']:.4f}\n"
 
297
  # Используем CLAP для zero-shot классификации аудио
298
  from transformers import ClapProcessor, ClapModel
299
  cache_key = f"audio_zero_shot_{model_name}"
300
+ cached = model_cache.get(cache_key)
301
+ if cached is None:
302
  processor = ClapProcessor.from_pretrained(model_name)
303
  model = ClapModel.from_pretrained(model_name)
304
+ cached = (processor, model)
305
+ model_cache.put(cache_key, cached)
306
 
307
+ processor, model = cached
308
  labels = [label.strip() for label in candidate_labels.split(",")]
309
 
310
  inputs = processor(text=labels, audios=audio, return_tensors="pt", padding=True)
 
342
  from datasets import load_dataset
343
 
344
  cache_key = f"tts_{model_name}"
345
+ cached = model_cache.get(cache_key)
346
+ if cached is None:
347
  processor = SpeechT5Processor.from_pretrained(model_name)
348
  model = SpeechT5ForTextToSpeech.from_pretrained(model_name)
349
  vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
350
  embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
351
  speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
352
+ cached = (processor, model, vocoder, speaker_embeddings)
353
+ model_cache.put(cache_key, cached)
354
 
355
+ processor, model, vocoder, speaker_embeddings = cached
356
  inputs = processor(text=text, return_tensors="pt")
357
  with torch.no_grad():
358
  speech = model.generate_speech(inputs["input_ids"], speaker_embeddings, vocoder=vocoder)
 
402
  """Сопоставление изображения и текста"""
403
  try:
404
  cache_key = f"clip_{model_name}"
405
+ cached = model_cache.get(cache_key)
406
+ if cached is None:
407
  processor = CLIPProcessor.from_pretrained(model_name)
408
  model = CLIPModel.from_pretrained(model_name)
409
+ cached = (processor, model)
410
+ model_cache.put(cache_key, cached)
411
 
412
+ processor, model = cached
413
  inputs = processor(text=[text], images=image, return_tensors="pt", padding=True)
414
 
415
  with torch.no_grad():
 
428
  try:
429
  if "blip" in model_name.lower():
430
  cache_key = f"caption_blip_{model_name}"
431
+ cached = model_cache.get(cache_key)
432
+ if cached is None:
433
  processor = BlipProcessor.from_pretrained(model_name)
434
  model = BlipForConditionalGeneration.from_pretrained(model_name)
435
+ cached = (processor, model)
436
+ model_cache.put(cache_key, cached)
437
 
438
+ processor, model = cached
439
  inputs = processor(image, return_tensors="pt")
440
  out = model.generate(**inputs, max_length=50)
441
  caption = processor.decode(out[0], skip_special_tokens=True)
 
455
  try:
456
  if "vilt" in model_name.lower():
457
  cache_key = f"vqa_vilt_{model_name}"
458
+ cached = model_cache.get(cache_key)
459
+ if cached is None:
460
  processor = ViltProcessor.from_pretrained(model_name)
461
  model = ViltForQuestionAnswering.from_pretrained(model_name)
462
+ cached = (processor, model)
463
+ model_cache.put(cache_key, cached)
464
 
465
+ processor, model = cached
466
  inputs = processor(image, question, return_tensors="pt")
467
  outputs = model(**inputs)
468
  logits = outputs.logits
 
471
  return f"Ответ: {answer}"
472
  elif "blip" in model_name.lower():
473
  cache_key = f"vqa_blip_{model_name}"
474
+ cached = model_cache.get(cache_key)
475
+ if cached is None:
476
  processor = BlipProcessor.from_pretrained(model_name)
477
  model = BlipForConditionalGeneration.from_pretrained(model_name)
478
+ cached = (processor, model)
479
+ model_cache.put(cache_key, cached)
480
 
481
+ processor, model = cached
482
  inputs = processor(image, question, return_tensors="pt")
483
  out = model.generate(**inputs, max_length=50)
484
  answer = processor.decode(out[0], skip_special_tokens=True)
 
497
  """Zero-shot классификация изображений"""
498
  try:
499
  cache_key = f"clip_zs_{model_name}"
500
+ cached = model_cache.get(cache_key)
501
+ if cached is None:
502
  processor = CLIPProcessor.from_pretrained(model_name)
503
  model = CLIPModel.from_pretrained(model_name)
504
+ cached = (processor, model)
505
+ model_cache.put(cache_key, cached)
506
 
507
+ processor, model = cached
508
  labels = [label.strip() for label in candidate_labels.split(",")]
509
  inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
510