MinAA commited on
Commit
208683f
·
1 Parent(s): 5c7b4ce
Files changed (1) hide show
  1. app.py +137 -3
app.py CHANGED
@@ -365,8 +365,9 @@ def audio_zero_shot_classifier(audio, candidate_labels, model_name):
365
  # Загружаем модель для аудио эмбеддингов
366
  audio_processor = AutoProcessor.from_pretrained(model_name)
367
  audio_model = AutoModel.from_pretrained(model_name)
368
- # Загружаем модель для текстовых эмбеддингов
369
- text_model = SentenceTransformer('all-MiniLM-L6-v2')
 
370
  cached = (audio_processor, audio_model, text_model)
371
  model_cache.put(cache_key, cached)
372
 
@@ -396,6 +397,23 @@ def audio_zero_shot_classifier(audio, candidate_labels, model_name):
396
  text_embeddings = text_model.encode(labels, convert_to_tensor=True)
397
  text_embeddings = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)
398
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
399
  # Вычисляем косинусное сходство
400
  similarities = cosine_similarity(audio_embedding, text_embeddings).squeeze(0)
401
  # Применяем softmax для получения вероятностей
@@ -503,6 +521,24 @@ def speech_synthesis(text, model_name):
503
 
504
  # Для других моделей используем стандартный pipeline
505
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
506
  tts = get_pipeline("text-to-speech", model_name)
507
  result = tts(text)
508
  except Exception as e:
@@ -513,6 +549,12 @@ def speech_synthesis(text, model_name):
513
  f"Эта модель может требовать другую библиотеку (например, Fairseq или ESPnet). "
514
  f"Попробуйте использовать модель microsoft/speecht5_tts, которая полностью поддерживается."
515
  ) from e
 
 
 
 
 
 
516
  raise
517
 
518
  # Pipeline может возвращать словарь или кортеж
@@ -686,12 +728,16 @@ def image_segmentation(image, model_name):
686
  overlay = Image.new("RGBA", image.size, (0, 0, 0, 0))
687
  draw = ImageDraw.Draw(overlay)
688
 
 
 
 
689
  for i, item in enumerate(result):
690
  label = item['label']
691
  score = item['score']
692
 
693
  # Генерируем полупрозрачный цвет для сегмента
694
  color = tuple(np.random.randint(0, 255, 3)) + (128,) # RGBA с прозрачностью
 
695
 
696
  # Проверяем наличие маски
697
  if 'mask' in item:
@@ -712,6 +758,38 @@ def image_segmentation(image, model_name):
712
  else:
713
  mask_array = mask_array.astype(np.uint8)
714
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
715
  # Создаем цветную маску
716
  if len(mask_array.shape) == 2: # Grayscale mask
717
  # Создаем RGBA маску
@@ -751,6 +829,61 @@ def image_segmentation(image, model_name):
751
  if overlay.size == img_with_segments.size:
752
  img_with_segments = Image.alpha_composite(img_with_segments, overlay)
753
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
754
  # Конвертируем обратно в RGB для отображения
755
  img_with_segments = img_with_segments.convert("RGB")
756
 
@@ -1391,7 +1524,8 @@ with gr.Blocks(title="Трансформеры Hugging Face", theme=gr.themes.So
1391
  tts_model = gr.Dropdown(
1392
  choices=[
1393
  "microsoft/speecht5_tts",
1394
- "facebook/mms-tts-eng"
 
1395
  ],
1396
  value="microsoft/speecht5_tts",
1397
  label="Выберите модель"
 
365
  # Загружаем модель для аудио эмбеддингов
366
  audio_processor = AutoProcessor.from_pretrained(model_name)
367
  audio_model = AutoModel.from_pretrained(model_name)
368
+ # Загружаем модель для текстовых эмбеддингов с размерностью 768
369
+ # Используем модель с размерностью 768 для совместимости с Wav2Vec2
370
+ text_model = SentenceTransformer('all-mpnet-base-v2')
371
  cached = (audio_processor, audio_model, text_model)
372
  model_cache.put(cache_key, cached)
373
 
 
397
  text_embeddings = text_model.encode(labels, convert_to_tensor=True)
398
  text_embeddings = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)
399
 
400
+ # Проверяем размерности и проецируем если нужно
401
+ audio_dim = audio_embedding.shape[1]
402
+ text_dim = text_embeddings.shape[1]
403
+
404
+ if audio_dim != text_dim:
405
+ # Если размерности не совпадают, проецируем меньший эмбеддинг в большее пространство
406
+ if audio_dim > text_dim:
407
+ # Проецируем текстовые эмбеддинги в пространство аудио
408
+ projection = torch.nn.Linear(text_dim, audio_dim).to(text_embeddings.device)
409
+ text_embeddings = projection(text_embeddings)
410
+ text_embeddings = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)
411
+ else:
412
+ # Проецируем аудио эмбеддинги в пространство текста
413
+ projection = torch.nn.Linear(audio_dim, text_dim).to(audio_embedding.device)
414
+ audio_embedding = projection(audio_embedding)
415
+ audio_embedding = audio_embedding / audio_embedding.norm(dim=1, keepdim=True)
416
+
417
  # Вычисляем косинусное сходство
418
  similarities = cosine_similarity(audio_embedding, text_embeddings).squeeze(0)
419
  # Применяем softmax для получения вероятностей
 
521
 
522
  # Для других моделей используем стандартный pipeline
523
  try:
524
+ # Проверяем, что текст не пустой
525
+ if not text or not text.strip():
526
+ raise ValueError("Текст для синтеза не может быть пустым")
527
+
528
+ # Для MMS TTS моделей проверяем язык
529
+ if "mms-tts" in model_name.lower():
530
+ # MMS TTS модели обычно поддерживают только один язык
531
+ # eng - английский, rus - русский и т.д.
532
+ if "mms-tts-eng" in model_name.lower():
533
+ # Проверяем, что текст на английском (простая проверка)
534
+ # Если текст содержит кириллицу, это может быть проблемой
535
+ has_cyrillic = any('\u0400' <= char <= '\u04FF' for char in text)
536
+ if has_cyrillic:
537
+ raise ValueError(
538
+ f"Модель '{model_name}' поддерживает только английский язык. "
539
+ f"Для русского текста используйте модель 'facebook/mms-tts-rus' или 'microsoft/speecht5_tts'."
540
+ )
541
+
542
  tts = get_pipeline("text-to-speech", model_name)
543
  result = tts(text)
544
  except Exception as e:
 
549
  f"Эта модель может требовать другую библиотеку (например, Fairseq или ESPnet). "
550
  f"Попробуйте использовать модель microsoft/speecht5_tts, которая полностью поддерживается."
551
  ) from e
552
+ elif "negative output size" in error_msg.lower() or "input size 0" in error_msg.lower():
553
+ raise ValueError(
554
+ f"Ошибка обработки текста моделью '{model_name}'. "
555
+ f"Возможные причины: неподдерживаемый язык, пустой текст после обработки, или проблема с токенизацией. "
556
+ f"Попробуйте использовать другую модель или проверьте язык текста."
557
+ ) from e
558
  raise
559
 
560
  # Pipeline может возвращать словарь или кортеж
 
728
  overlay = Image.new("RGBA", image.size, (0, 0, 0, 0))
729
  draw = ImageDraw.Draw(overlay)
730
 
731
+ # Список для хранения информации о сегментах (для добавления текста)
732
+ segments_info = []
733
+
734
  for i, item in enumerate(result):
735
  label = item['label']
736
  score = item['score']
737
 
738
  # Генерируем полупрозрачный цвет для сегмента
739
  color = tuple(np.random.randint(0, 255, 3)) + (128,) # RGBA с прозрачностью
740
+ color_rgb = color[:3] # RGB цвет для текста
741
 
742
  # Проверяем наличие маски
743
  if 'mask' in item:
 
758
  else:
759
  mask_array = mask_array.astype(np.uint8)
760
 
761
+ # Находим центр маски для размещения текста
762
+ if len(mask_array.shape) == 2: # Grayscale mask
763
+ mask_bool = mask_array > 0
764
+ elif len(mask_array.shape) == 3 and mask_array.shape[2] == 1:
765
+ mask_bool = mask_array[:, :, 0] > 0
766
+ else:
767
+ if mask_array.shape[2] >= 1:
768
+ mask_bool = mask_array[:, :, 0] > 0
769
+ else:
770
+ mask_bool = np.zeros(mask_array.shape[:2], dtype=bool)
771
+
772
+ # Вычисляем центр маски
773
+ if np.any(mask_bool):
774
+ y_coords, x_coords = np.where(mask_bool)
775
+ if len(y_coords) > 0 and len(x_coords) > 0:
776
+ center_y = int(np.mean(y_coords))
777
+ center_x = int(np.mean(x_coords))
778
+
779
+ # Масштабируем координаты, если маска другого размера
780
+ if mask_array.shape[:2] != image.size[::-1]:
781
+ scale_y = image.size[1] / mask_array.shape[0]
782
+ scale_x = image.size[0] / mask_array.shape[1]
783
+ center_y = int(center_y * scale_y)
784
+ center_x = int(center_x * scale_x)
785
+
786
+ segments_info.append({
787
+ 'label': label,
788
+ 'score': score,
789
+ 'center': (center_x, center_y),
790
+ 'color': color_rgb
791
+ })
792
+
793
  # Создаем цветную маску
794
  if len(mask_array.shape) == 2: # Grayscale mask
795
  # Создаем RGBA маску
 
829
  if overlay.size == img_with_segments.size:
830
  img_with_segments = Image.alpha_composite(img_with_segments, overlay)
831
 
832
+ # Добавляем текстовые метки с цветами на изображение
833
+ draw_final = ImageDraw.Draw(img_with_segments)
834
+
835
+ # Загружаем шрифт
836
+ try:
837
+ font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", 18)
838
+ except:
839
+ try:
840
+ font = ImageFont.load_default()
841
+ except:
842
+ font = None
843
+
844
+ for seg_info in segments_info:
845
+ label = seg_info['label']
846
+ score = seg_info['score']
847
+ center_x, center_y = seg_info['center']
848
+ color_rgb = seg_info['color']
849
+
850
+ # Формируем текст метки
851
+ text = f"{label}: {score:.2f}"
852
+
853
+ # Получаем размер текста
854
+ if font:
855
+ bbox = draw_final.textbbox((0, 0), text, font=font)
856
+ text_width = bbox[2] - bbox[0]
857
+ text_height = bbox[3] - bbox[1]
858
+ else:
859
+ text_width = len(text) * 7
860
+ text_height = 14
861
+
862
+ # Вычисляем позицию текста (центрируем относительно центра сегмента)
863
+ text_x = center_x - text_width // 2
864
+ text_y = center_y - text_height // 2
865
+
866
+ # Ограничиваем координаты границами изображения
867
+ img_width, img_height = img_with_segments.size
868
+ text_x = max(2, min(text_x, img_width - text_width - 2))
869
+ text_y = max(2, min(text_y, img_height - text_height - 2))
870
+
871
+ # Рисуем фон для текста (полупрозрачный черный для читаемости)
872
+ padding = 4
873
+ draw_final.rectangle(
874
+ [text_x - padding, text_y - padding,
875
+ text_x + text_width + padding, text_y + text_height + padding],
876
+ fill=(0, 0, 0, 180) # Полупрозрачный черный фон
877
+ )
878
+
879
+ # Рисуем текст цветом сегмента
880
+ draw_final.text(
881
+ (text_x, text_y),
882
+ text,
883
+ fill=color_rgb + (255,), # RGB + альфа для RGBA
884
+ font=font
885
+ )
886
+
887
  # Конвертируем обратно в RGB для отображения
888
  img_with_segments = img_with_segments.convert("RGB")
889
 
 
1524
  tts_model = gr.Dropdown(
1525
  choices=[
1526
  "microsoft/speecht5_tts",
1527
+ "facebook/mms-tts-eng",
1528
+ "facebook/mms-tts-rus"
1529
  ],
1530
  value="microsoft/speecht5_tts",
1531
  label="Выберите модель"