Spaces:
Running
Running
MinAA
commited on
Commit
·
ce25353
1
Parent(s):
874d30c
init
Browse files
app.py
CHANGED
|
@@ -303,20 +303,9 @@ def audio_classifier(audio, model_name):
|
|
| 303 |
def audio_zero_shot_classifier(audio, candidate_labels, model_name):
|
| 304 |
"""Zero-shot классификация аудио"""
|
| 305 |
try:
|
| 306 |
-
# Используем CLAP для zero-shot классификации аудио
|
| 307 |
-
from transformers import ClapProcessor, ClapModel
|
| 308 |
import soundfile as sf
|
| 309 |
import numpy as np
|
| 310 |
|
| 311 |
-
cache_key = f"audio_zero_shot_{model_name}"
|
| 312 |
-
cached = model_cache.get(cache_key)
|
| 313 |
-
if cached is None:
|
| 314 |
-
processor = ClapProcessor.from_pretrained(model_name)
|
| 315 |
-
model = ClapModel.from_pretrained(model_name)
|
| 316 |
-
cached = (processor, model)
|
| 317 |
-
model_cache.put(cache_key, cached)
|
| 318 |
-
|
| 319 |
-
processor, model = cached
|
| 320 |
labels = [label.strip() for label in candidate_labels.split(",")]
|
| 321 |
|
| 322 |
# Загружаем аудио файл, если передан путь
|
|
@@ -339,16 +328,99 @@ def audio_zero_shot_classifier(audio, candidate_labels, model_name):
|
|
| 339 |
if len(audio_data.shape) > 1:
|
| 340 |
audio_data = audio_data[:, 0] if audio_data.shape[1] > 0 else audio_data.flatten()
|
| 341 |
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 352 |
except Exception as e:
|
| 353 |
return f"Ошибка: {str(e)}"
|
| 354 |
|
|
@@ -1258,7 +1330,11 @@ with gr.Blocks(title="Трансформеры Hugging Face", theme=gr.themes.So
|
|
| 1258 |
value="music, speech, noise"
|
| 1259 |
)
|
| 1260 |
zs_audio_model = gr.Dropdown(
|
| 1261 |
-
choices=[
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1262 |
value="laion/clap-htsat-unfused",
|
| 1263 |
label="Выберите модель"
|
| 1264 |
)
|
|
@@ -1307,10 +1383,11 @@ with gr.Blocks(title="Трансформеры Hugging Face", theme=gr.themes.So
|
|
| 1307 |
)
|
| 1308 |
tts_model = gr.Dropdown(
|
| 1309 |
choices=[
|
| 1310 |
-
"microsoft/speecht5_tts"
|
|
|
|
| 1311 |
],
|
| 1312 |
value="microsoft/speecht5_tts",
|
| 1313 |
-
label="Выберите модель
|
| 1314 |
)
|
| 1315 |
tts_btn = gr.Button("Синтезировать", variant="primary")
|
| 1316 |
with gr.Column():
|
|
|
|
| 303 |
def audio_zero_shot_classifier(audio, candidate_labels, model_name):
|
| 304 |
"""Zero-shot классификация аудио"""
|
| 305 |
try:
|
|
|
|
|
|
|
| 306 |
import soundfile as sf
|
| 307 |
import numpy as np
|
| 308 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 309 |
labels = [label.strip() for label in candidate_labels.split(",")]
|
| 310 |
|
| 311 |
# Загружаем аудио файл, если передан путь
|
|
|
|
| 328 |
if len(audio_data.shape) > 1:
|
| 329 |
audio_data = audio_data[:, 0] if audio_data.shape[1] > 0 else audio_data.flatten()
|
| 330 |
|
| 331 |
+
# Проверяем тип модели
|
| 332 |
+
if "clap" in model_name.lower():
|
| 333 |
+
# Используем CLAP для zero-shot классификации аудио
|
| 334 |
+
from transformers import ClapProcessor, ClapModel
|
| 335 |
+
|
| 336 |
+
cache_key = f"audio_zero_shot_{model_name}"
|
| 337 |
+
cached = model_cache.get(cache_key)
|
| 338 |
+
if cached is None:
|
| 339 |
+
processor = ClapProcessor.from_pretrained(model_name)
|
| 340 |
+
model = ClapModel.from_pretrained(model_name)
|
| 341 |
+
cached = (processor, model)
|
| 342 |
+
model_cache.put(cache_key, cached)
|
| 343 |
+
|
| 344 |
+
processor, model = cached
|
| 345 |
+
inputs = processor(text=labels, audios=audio_data, return_tensors="pt", padding=True)
|
| 346 |
+
with torch.no_grad():
|
| 347 |
+
outputs = model(**inputs)
|
| 348 |
+
logits_per_audio = outputs.logits_per_audio
|
| 349 |
+
probs = logits_per_audio.softmax(dim=1)
|
| 350 |
+
|
| 351 |
+
output = "Результаты классификации:\n"
|
| 352 |
+
for label, prob in zip(labels, probs[0]):
|
| 353 |
+
output += f"{label}: {prob.item():.4f}\n"
|
| 354 |
+
return output
|
| 355 |
+
elif "wav2vec2" in model_name.lower() or "hubert" in model_name.lower():
|
| 356 |
+
# Используем подход с audio embeddings + text embeddings
|
| 357 |
+
# Получаем аудио эмбеддинги через audio model и текстовые эмбеддинги через text model
|
| 358 |
+
from transformers import AutoProcessor, AutoModel
|
| 359 |
+
from sentence_transformers import SentenceTransformer
|
| 360 |
+
|
| 361 |
+
cache_key = f"audio_zero_shot_{model_name}"
|
| 362 |
+
cached = model_cache.get(cache_key)
|
| 363 |
+
if cached is None:
|
| 364 |
+
# Загружаем модель для аудио эмбеддингов
|
| 365 |
+
audio_processor = AutoProcessor.from_pretrained(model_name)
|
| 366 |
+
audio_model = AutoModel.from_pretrained(model_name)
|
| 367 |
+
# Загружаем модель для текстовых эмбеддингов
|
| 368 |
+
text_model = SentenceTransformer('all-MiniLM-L6-v2')
|
| 369 |
+
cached = (audio_processor, audio_model, text_model)
|
| 370 |
+
model_cache.put(cache_key, cached)
|
| 371 |
+
|
| 372 |
+
audio_processor, audio_model, text_model = cached
|
| 373 |
+
|
| 374 |
+
# Получаем аудио эмбеддинги
|
| 375 |
+
# Нормализуем sample rate если нужно
|
| 376 |
+
if sample_rate is None:
|
| 377 |
+
sample_rate = 16000
|
| 378 |
+
inputs = audio_processor(audio_data, sampling_rate=sample_rate, return_tensors="pt")
|
| 379 |
+
with torch.no_grad():
|
| 380 |
+
audio_outputs = audio_model(**inputs)
|
| 381 |
+
# Используем последний скрытый слой как эмбеддинг
|
| 382 |
+
if hasattr(audio_outputs, 'last_hidden_state'):
|
| 383 |
+
audio_embedding = audio_outputs.last_hidden_state.mean(dim=1) # Усредняем по временной оси
|
| 384 |
+
else:
|
| 385 |
+
audio_embedding = audio_outputs[0].mean(dim=1)
|
| 386 |
+
audio_embedding = audio_embedding / audio_embedding.norm(dim=1, keepdim=True)
|
| 387 |
+
|
| 388 |
+
# Получаем текстовые эмбеддинги
|
| 389 |
+
text_embeddings = text_model.encode(labels, convert_to_tensor=True)
|
| 390 |
+
text_embeddings = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)
|
| 391 |
+
|
| 392 |
+
# Вычисляем косинусное сходство
|
| 393 |
+
similarities = cosine_similarity(audio_embedding, text_embeddings).squeeze(0)
|
| 394 |
+
# Применяем softmax для получения вероятностей
|
| 395 |
+
probs = torch.softmax(similarities * 10, dim=0) # Масштабируем для лучшей точности
|
| 396 |
+
|
| 397 |
+
output = "Результаты классификации (через audio + text embeddings):\n"
|
| 398 |
+
for label, prob in zip(labels, probs):
|
| 399 |
+
output += f"{label}: {prob.item():.4f}\n"
|
| 400 |
+
return output
|
| 401 |
+
else:
|
| 402 |
+
# Для других моделей используем CLAP по умолчанию
|
| 403 |
+
from transformers import ClapProcessor, ClapModel
|
| 404 |
+
|
| 405 |
+
cache_key = f"audio_zero_shot_{model_name}"
|
| 406 |
+
cached = model_cache.get(cache_key)
|
| 407 |
+
if cached is None:
|
| 408 |
+
processor = ClapProcessor.from_pretrained(model_name)
|
| 409 |
+
model = ClapModel.from_pretrained(model_name)
|
| 410 |
+
cached = (processor, model)
|
| 411 |
+
model_cache.put(cache_key, cached)
|
| 412 |
+
|
| 413 |
+
processor, model = cached
|
| 414 |
+
inputs = processor(text=labels, audios=audio_data, return_tensors="pt", padding=True)
|
| 415 |
+
with torch.no_grad():
|
| 416 |
+
outputs = model(**inputs)
|
| 417 |
+
logits_per_audio = outputs.logits_per_audio
|
| 418 |
+
probs = logits_per_audio.softmax(dim=1)
|
| 419 |
+
|
| 420 |
+
output = "Результаты классификации:\n"
|
| 421 |
+
for label, prob in zip(labels, probs[0]):
|
| 422 |
+
output += f"{label}: {prob.item():.4f}\n"
|
| 423 |
+
return output
|
| 424 |
except Exception as e:
|
| 425 |
return f"Ошибка: {str(e)}"
|
| 426 |
|
|
|
|
| 1330 |
value="music, speech, noise"
|
| 1331 |
)
|
| 1332 |
zs_audio_model = gr.Dropdown(
|
| 1333 |
+
choices=[
|
| 1334 |
+
"laion/clap-htsat-unfused",
|
| 1335 |
+
"laion/clap-htsat-fused",
|
| 1336 |
+
"facebook/wav2vec2-base-960h"
|
| 1337 |
+
],
|
| 1338 |
value="laion/clap-htsat-unfused",
|
| 1339 |
label="Выберите модель"
|
| 1340 |
)
|
|
|
|
| 1383 |
)
|
| 1384 |
tts_model = gr.Dropdown(
|
| 1385 |
choices=[
|
| 1386 |
+
"microsoft/speecht5_tts",
|
| 1387 |
+
"facebook/mms-tts-eng"
|
| 1388 |
],
|
| 1389 |
value="microsoft/speecht5_tts",
|
| 1390 |
+
label="Выберите модель"
|
| 1391 |
)
|
| 1392 |
tts_btn = gr.Button("Синтезировать", variant="primary")
|
| 1393 |
with gr.Column():
|