MinAA commited on
Commit
ce25353
·
1 Parent(s): 874d30c
Files changed (1) hide show
  1. app.py +101 -24
app.py CHANGED
@@ -303,20 +303,9 @@ def audio_classifier(audio, model_name):
303
  def audio_zero_shot_classifier(audio, candidate_labels, model_name):
304
  """Zero-shot классификация аудио"""
305
  try:
306
- # Используем CLAP для zero-shot классификации аудио
307
- from transformers import ClapProcessor, ClapModel
308
  import soundfile as sf
309
  import numpy as np
310
 
311
- cache_key = f"audio_zero_shot_{model_name}"
312
- cached = model_cache.get(cache_key)
313
- if cached is None:
314
- processor = ClapProcessor.from_pretrained(model_name)
315
- model = ClapModel.from_pretrained(model_name)
316
- cached = (processor, model)
317
- model_cache.put(cache_key, cached)
318
-
319
- processor, model = cached
320
  labels = [label.strip() for label in candidate_labels.split(",")]
321
 
322
  # Загружаем аудио файл, если передан путь
@@ -339,16 +328,99 @@ def audio_zero_shot_classifier(audio, candidate_labels, model_name):
339
  if len(audio_data.shape) > 1:
340
  audio_data = audio_data[:, 0] if audio_data.shape[1] > 0 else audio_data.flatten()
341
 
342
- inputs = processor(text=labels, audios=audio_data, return_tensors="pt", padding=True)
343
- with torch.no_grad():
344
- outputs = model(**inputs)
345
- logits_per_audio = outputs.logits_per_audio
346
- probs = logits_per_audio.softmax(dim=1)
347
-
348
- output = "Результаты классификации:\n"
349
- for label, prob in zip(labels, probs[0]):
350
- output += f"{label}: {prob.item():.4f}\n"
351
- return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
352
  except Exception as e:
353
  return f"Ошибка: {str(e)}"
354
 
@@ -1258,7 +1330,11 @@ with gr.Blocks(title="Трансформеры Hugging Face", theme=gr.themes.So
1258
  value="music, speech, noise"
1259
  )
1260
  zs_audio_model = gr.Dropdown(
1261
- choices=["laion/clap-htsat-unfused"],
 
 
 
 
1262
  value="laion/clap-htsat-unfused",
1263
  label="Выберите модель"
1264
  )
@@ -1307,10 +1383,11 @@ with gr.Blocks(title="Трансформеры Hugging Face", theme=gr.themes.So
1307
  )
1308
  tts_model = gr.Dropdown(
1309
  choices=[
1310
- "microsoft/speecht5_tts"
 
1311
  ],
1312
  value="microsoft/speecht5_tts",
1313
- label="Выберите модель (поддерживаются только модели SpeechT5 из transformers)"
1314
  )
1315
  tts_btn = gr.Button("Синтезировать", variant="primary")
1316
  with gr.Column():
 
303
  def audio_zero_shot_classifier(audio, candidate_labels, model_name):
304
  """Zero-shot классификация аудио"""
305
  try:
 
 
306
  import soundfile as sf
307
  import numpy as np
308
 
 
 
 
 
 
 
 
 
 
309
  labels = [label.strip() for label in candidate_labels.split(",")]
310
 
311
  # Загружаем аудио файл, если передан путь
 
328
  if len(audio_data.shape) > 1:
329
  audio_data = audio_data[:, 0] if audio_data.shape[1] > 0 else audio_data.flatten()
330
 
331
+ # Проверяем тип модели
332
+ if "clap" in model_name.lower():
333
+ # Используем CLAP для zero-shot классификации аудио
334
+ from transformers import ClapProcessor, ClapModel
335
+
336
+ cache_key = f"audio_zero_shot_{model_name}"
337
+ cached = model_cache.get(cache_key)
338
+ if cached is None:
339
+ processor = ClapProcessor.from_pretrained(model_name)
340
+ model = ClapModel.from_pretrained(model_name)
341
+ cached = (processor, model)
342
+ model_cache.put(cache_key, cached)
343
+
344
+ processor, model = cached
345
+ inputs = processor(text=labels, audios=audio_data, return_tensors="pt", padding=True)
346
+ with torch.no_grad():
347
+ outputs = model(**inputs)
348
+ logits_per_audio = outputs.logits_per_audio
349
+ probs = logits_per_audio.softmax(dim=1)
350
+
351
+ output = "Результаты классификации:\n"
352
+ for label, prob in zip(labels, probs[0]):
353
+ output += f"{label}: {prob.item():.4f}\n"
354
+ return output
355
+ elif "wav2vec2" in model_name.lower() or "hubert" in model_name.lower():
356
+ # Используем подход с audio embeddings + text embeddings
357
+ # Получаем аудио эмбеддинги через audio model и текстовые эмбеддинги через text model
358
+ from transformers import AutoProcessor, AutoModel
359
+ from sentence_transformers import SentenceTransformer
360
+
361
+ cache_key = f"audio_zero_shot_{model_name}"
362
+ cached = model_cache.get(cache_key)
363
+ if cached is None:
364
+ # Загружаем модель для аудио эмбеддингов
365
+ audio_processor = AutoProcessor.from_pretrained(model_name)
366
+ audio_model = AutoModel.from_pretrained(model_name)
367
+ # Загружаем модель для текстовых эмбеддингов
368
+ text_model = SentenceTransformer('all-MiniLM-L6-v2')
369
+ cached = (audio_processor, audio_model, text_model)
370
+ model_cache.put(cache_key, cached)
371
+
372
+ audio_processor, audio_model, text_model = cached
373
+
374
+ # Получаем аудио эмбеддинги
375
+ # Нормализуем sample rate если нужно
376
+ if sample_rate is None:
377
+ sample_rate = 16000
378
+ inputs = audio_processor(audio_data, sampling_rate=sample_rate, return_tensors="pt")
379
+ with torch.no_grad():
380
+ audio_outputs = audio_model(**inputs)
381
+ # Используем последний скрытый слой как эмбеддинг
382
+ if hasattr(audio_outputs, 'last_hidden_state'):
383
+ audio_embedding = audio_outputs.last_hidden_state.mean(dim=1) # Усредняем по временной оси
384
+ else:
385
+ audio_embedding = audio_outputs[0].mean(dim=1)
386
+ audio_embedding = audio_embedding / audio_embedding.norm(dim=1, keepdim=True)
387
+
388
+ # Получаем текстовые эмбеддинги
389
+ text_embeddings = text_model.encode(labels, convert_to_tensor=True)
390
+ text_embeddings = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)
391
+
392
+ # Вычисляем косинусное сходство
393
+ similarities = cosine_similarity(audio_embedding, text_embeddings).squeeze(0)
394
+ # Применяем softmax для получения вероятностей
395
+ probs = torch.softmax(similarities * 10, dim=0) # Масштабируем для лучшей точности
396
+
397
+ output = "Результаты классификации (через audio + text embeddings):\n"
398
+ for label, prob in zip(labels, probs):
399
+ output += f"{label}: {prob.item():.4f}\n"
400
+ return output
401
+ else:
402
+ # Для других моделей используем CLAP по умолчанию
403
+ from transformers import ClapProcessor, ClapModel
404
+
405
+ cache_key = f"audio_zero_shot_{model_name}"
406
+ cached = model_cache.get(cache_key)
407
+ if cached is None:
408
+ processor = ClapProcessor.from_pretrained(model_name)
409
+ model = ClapModel.from_pretrained(model_name)
410
+ cached = (processor, model)
411
+ model_cache.put(cache_key, cached)
412
+
413
+ processor, model = cached
414
+ inputs = processor(text=labels, audios=audio_data, return_tensors="pt", padding=True)
415
+ with torch.no_grad():
416
+ outputs = model(**inputs)
417
+ logits_per_audio = outputs.logits_per_audio
418
+ probs = logits_per_audio.softmax(dim=1)
419
+
420
+ output = "Результаты классификации:\n"
421
+ for label, prob in zip(labels, probs[0]):
422
+ output += f"{label}: {prob.item():.4f}\n"
423
+ return output
424
  except Exception as e:
425
  return f"Ошибка: {str(e)}"
426
 
 
1330
  value="music, speech, noise"
1331
  )
1332
  zs_audio_model = gr.Dropdown(
1333
+ choices=[
1334
+ "laion/clap-htsat-unfused",
1335
+ "laion/clap-htsat-fused",
1336
+ "facebook/wav2vec2-base-960h"
1337
+ ],
1338
  value="laion/clap-htsat-unfused",
1339
  label="Выберите модель"
1340
  )
 
1383
  )
1384
  tts_model = gr.Dropdown(
1385
  choices=[
1386
+ "microsoft/speecht5_tts",
1387
+ "facebook/mms-tts-eng"
1388
  ],
1389
  value="microsoft/speecht5_tts",
1390
+ label="Выберите модель"
1391
  )
1392
  tts_btn = gr.Button("Синтезировать", variant="primary")
1393
  with gr.Column():