MinAA committed on
Commit
0c72405
·
1 Parent(s): 5a1bdd1
Files changed (2) hide show
  1. app.py +167 -24
  2. requirements.txt +3 -0
app.py CHANGED
@@ -5,6 +5,7 @@ from transformers import (
5
  CLIPProcessor, CLIPModel, ViltProcessor, ViltForQuestionAnswering
6
  )
7
  import torch
 
8
  from PIL import Image
9
  import functools
10
  import warnings
@@ -296,6 +297,9 @@ def audio_zero_shot_classifier(audio, candidate_labels, model_name):
296
  try:
297
  # Используем CLAP для zero-shot классификации аудио
298
  from transformers import ClapProcessor, ClapModel
 
 
 
299
  cache_key = f"audio_zero_shot_{model_name}"
300
  cached = model_cache.get(cache_key)
301
  if cached is None:
@@ -307,7 +311,27 @@ def audio_zero_shot_classifier(audio, candidate_labels, model_name):
307
  processor, model = cached
308
  labels = [label.strip() for label in candidate_labels.split(",")]
309
 
310
- inputs = processor(text=labels, audios=audio, return_tensors="pt", padding=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
311
  with torch.no_grad():
312
  outputs = model(**inputs)
313
  logits_per_audio = outputs.logits_per_audio
@@ -404,20 +428,75 @@ def image_text_matching(image, text, model_name):
404
  cache_key = f"clip_{model_name}"
405
  cached = model_cache.get(cache_key)
406
  if cached is None:
407
- processor = CLIPProcessor.from_pretrained(model_name)
408
- model = CLIPModel.from_pretrained(model_name)
409
- cached = (processor, model)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
410
  model_cache.put(cache_key, cached)
411
 
412
- processor, model = cached
413
- inputs = processor(text=[text], images=image, return_tensors="pt", padding=True)
414
-
415
- with torch.no_grad():
416
- outputs = model(**inputs)
417
- logits_per_image = outputs.logits_per_image
418
- probs = logits_per_image.softmax(dim=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
419
 
420
- score = probs[0][0].item()
421
  return f"Совпадение изображения и текста: {score:.4f}"
422
  except Exception as e:
423
  return f"Ошибка: {str(e)}"
@@ -499,23 +578,87 @@ def image_zero_shot_classification(image, candidate_labels, model_name):
499
  cache_key = f"clip_zs_{model_name}"
500
  cached = model_cache.get(cache_key)
501
  if cached is None:
502
- processor = CLIPProcessor.from_pretrained(model_name)
503
- model = CLIPModel.from_pretrained(model_name)
504
- cached = (processor, model)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
505
  model_cache.put(cache_key, cached)
506
 
507
- processor, model = cached
508
  labels = [label.strip() for label in candidate_labels.split(",")]
509
- inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
510
 
511
- with torch.no_grad():
512
- outputs = model(**inputs)
513
- logits_per_image = outputs.logits_per_image
514
- probs = logits_per_image.softmax(dim=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
515
 
516
- output = "Результаты классификации:\n"
517
- for label, prob in zip(labels, probs[0]):
518
- output += f"{label}: {prob.item():.4f}\n"
519
  return output
520
  except Exception as e:
521
  return f"Ошибка: {str(e)}"
 
5
  CLIPProcessor, CLIPModel, ViltProcessor, ViltForQuestionAnswering
6
  )
7
  import torch
8
+ from torch.nn.functional import cosine_similarity
9
  from PIL import Image
10
  import functools
11
  import warnings
 
297
  try:
298
  # Используем CLAP для zero-shot классификации аудио
299
  from transformers import ClapProcessor, ClapModel
300
+ import soundfile as sf
301
+ import numpy as np
302
+
303
  cache_key = f"audio_zero_shot_{model_name}"
304
  cached = model_cache.get(cache_key)
305
  if cached is None:
 
311
  processor, model = cached
312
  labels = [label.strip() for label in candidate_labels.split(",")]
313
 
314
+ # Загружаем аудио файл, если передан путь
315
+ if isinstance(audio, str):
316
+ # audio - это путь к файлу
317
+ audio_data, sample_rate = sf.read(audio)
318
+ elif isinstance(audio, tuple):
319
+ # audio - это кортеж (sample_rate, audio_data) от Gradio
320
+ sample_rate, audio_data = audio
321
+ else:
322
+ # audio уже является массивом numpy
323
+ audio_data = audio
324
+ sample_rate = None
325
+
326
+ # Преобразуем в numpy array, если нужно
327
+ if not isinstance(audio_data, np.ndarray):
328
+ audio_data = np.array(audio_data)
329
+
330
+ # Если аудио моно, убеждаемся что это 1D массив
331
+ if len(audio_data.shape) > 1:
332
+ audio_data = audio_data[:, 0] if audio_data.shape[1] > 0 else audio_data.flatten()
333
+
334
+ inputs = processor(text=labels, audios=audio_data, return_tensors="pt", padding=True)
335
  with torch.no_grad():
336
  outputs = model(**inputs)
337
  logits_per_audio = outputs.logits_per_audio
 
428
  cache_key = f"clip_{model_name}"
429
  cached = model_cache.get(cache_key)
430
  if cached is None:
431
+ # Проверяем, является ли модель из sentence-transformers
432
+ if "sentence-transformers" in model_name:
433
+ from sentence_transformers import SentenceTransformer
434
+ model = SentenceTransformer(model_name)
435
+ cached = ("sentence_transformers", model)
436
+ # Проверяем, является ли модель LAION (требует OpenCLIP)
437
+ elif "laion/" in model_name.lower() or "laion5b" in model_name.lower():
438
+ import open_clip
439
+ # Определяем имя модели и веса для OpenCLIP
440
+ if "xlm-roberta-base-ViT-B-32" in model_name or "xlm-roberta-base" in model_name:
441
+ clip_model_name = "xlm-roberta-base-ViT-B-32"
442
+ pretrained = "laion5b_s13b_b90k"
443
+ else:
444
+ # Пытаемся извлечь информацию из имени модели
445
+ clip_model_name = "xlm-roberta-base-ViT-B-32"
446
+ pretrained = "laion5b_s13b_b90k"
447
+
448
+ model, _, preprocess = open_clip.create_model_and_transforms(
449
+ clip_model_name,
450
+ pretrained=pretrained
451
+ )
452
+ tokenizer = open_clip.get_tokenizer(clip_model_name)
453
+ model.eval()
454
+ cached = ("openclip", model, preprocess, tokenizer)
455
+ else:
456
+ processor = CLIPProcessor.from_pretrained(model_name)
457
+ model = CLIPModel.from_pretrained(model_name)
458
+ cached = ("transformers", processor, model)
459
  model_cache.put(cache_key, cached)
460
 
461
+ if cached[0] == "sentence_transformers":
462
+ # Используем sentence-transformers
463
+ model = cached[1]
464
+ # Вычисляем эмбеддинги изображения и текста
465
+ image_embedding = model.encode(image, convert_to_tensor=True)
466
+ text_embedding = model.encode(text, convert_to_tensor=True)
467
+ # Вычисляем косинусное сходство
468
+ score = cosine_similarity(image_embedding.unsqueeze(0), text_embedding.unsqueeze(0)).item()
469
+ # Нормализуем в диапазон [0, 1] для лучшей интерпретации
470
+ score = (score + 1) / 2
471
+ elif cached[0] == "openclip":
472
+ # Используем OpenCLIP
473
+ model, preprocess, tokenizer = cached[1], cached[2], cached[3]
474
+ # Обрабатываем изображение и текст
475
+ image_tensor = preprocess(image).unsqueeze(0)
476
+ text_tokens = tokenizer([text])
477
+
478
+ with torch.no_grad():
479
+ image_features = model.encode_image(image_tensor)
480
+ text_features = model.encode_text(text_tokens)
481
+ # Нормализуем признаки
482
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
483
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
484
+ # Вычисляем косинусное сходство
485
+ score = (image_features @ text_features.T).item()
486
+ # Нормализуем в диапазон [0, 1] для лучшей интерпретации
487
+ score = (score + 1) / 2
488
+ else:
489
+ # Используем стандартный CLIP из transformers
490
+ processor, model = cached[1], cached[2]
491
+ inputs = processor(text=[text], images=image, return_tensors="pt", padding=True)
492
+
493
+ with torch.no_grad():
494
+ outputs = model(**inputs)
495
+ logits_per_image = outputs.logits_per_image
496
+ probs = logits_per_image.softmax(dim=1)
497
+
498
+ score = probs[0][0].item()
499
 
 
500
  return f"Совпадение изображения и текста: {score:.4f}"
501
  except Exception as e:
502
  return f"Ошибка: {str(e)}"
 
578
  cache_key = f"clip_zs_{model_name}"
579
  cached = model_cache.get(cache_key)
580
  if cached is None:
581
+ # Проверяем, является ли модель из sentence-transformers
582
+ if "sentence-transformers" in model_name:
583
+ from sentence_transformers import SentenceTransformer
584
+ model = SentenceTransformer(model_name)
585
+ cached = ("sentence_transformers", model)
586
+ # Проверяем, является ли модель LAION (требует OpenCLIP)
587
+ elif "laion/" in model_name.lower() or "laion5b" in model_name.lower():
588
+ import open_clip
589
+ # Определяем имя модели и веса для OpenCLIP
590
+ if "xlm-roberta-base-ViT-B-32" in model_name or "xlm-roberta-base" in model_name:
591
+ clip_model_name = "xlm-roberta-base-ViT-B-32"
592
+ pretrained = "laion5b_s13b_b90k"
593
+ else:
594
+ # Пытаемся извлечь информацию из имени модели
595
+ clip_model_name = "xlm-roberta-base-ViT-B-32"
596
+ pretrained = "laion5b_s13b_b90k"
597
+
598
+ model, _, preprocess = open_clip.create_model_and_transforms(
599
+ clip_model_name,
600
+ pretrained=pretrained
601
+ )
602
+ tokenizer = open_clip.get_tokenizer(clip_model_name)
603
+ model.eval()
604
+ cached = ("openclip", model, preprocess, tokenizer)
605
+ else:
606
+ processor = CLIPProcessor.from_pretrained(model_name)
607
+ model = CLIPModel.from_pretrained(model_name)
608
+ cached = ("transformers", processor, model)
609
  model_cache.put(cache_key, cached)
610
 
 
611
  labels = [label.strip() for label in candidate_labels.split(",")]
 
612
 
613
+ if cached[0] == "sentence_transformers":
614
+ # Используем sentence-transformers
615
+ model = cached[1]
616
+ # Вычисляем эмбеддинги изображения и текстов
617
+ image_embedding = model.encode(image, convert_to_tensor=True)
618
+ text_embeddings = model.encode(labels, convert_to_tensor=True)
619
+ # Вычисляем косинусное сходство
620
+ similarities = cosine_similarity(image_embedding.unsqueeze(0), text_embeddings).squeeze(0)
621
+ # Нормализуем в диапазон [0, 1] и применяем softmax для вероятностей
622
+ similarities = (similarities + 1) / 2
623
+ probs = torch.softmax(similarities, dim=0)
624
+
625
+ output = "Результаты классификации:\n"
626
+ for label, prob in zip(labels, probs):
627
+ output += f"{label}: {prob.item():.4f}\n"
628
+ elif cached[0] == "openclip":
629
+ # Используем OpenCLIP
630
+ model, preprocess, tokenizer = cached[1], cached[2], cached[3]
631
+ # Обрабатываем изображение и тексты
632
+ image_tensor = preprocess(image).unsqueeze(0)
633
+ text_tokens = tokenizer(labels)
634
+
635
+ with torch.no_grad():
636
+ image_features = model.encode_image(image_tensor)
637
+ text_features = model.encode_text(text_tokens)
638
+ # Нормализуем признаки
639
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
640
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
641
+ # Вычисляем косинусное сходство (логиты)
642
+ logits_per_image = (image_features @ text_features.T) * 100 # Масштабируем для лучшей точности
643
+ probs = logits_per_image.softmax(dim=1)
644
+
645
+ output = "Результаты классификации:\n"
646
+ for label, prob in zip(labels, probs[0]):
647
+ output += f"{label}: {prob.item():.4f}\n"
648
+ else:
649
+ # Используем стандартный CLIP из transformers
650
+ processor, model = cached[1], cached[2]
651
+ inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
652
+
653
+ with torch.no_grad():
654
+ outputs = model(**inputs)
655
+ logits_per_image = outputs.logits_per_image
656
+ probs = logits_per_image.softmax(dim=1)
657
+
658
+ output = "Результаты классификации:\n"
659
+ for label, prob in zip(labels, probs[0]):
660
+ output += f"{label}: {prob.item():.4f}\n"
661
 
 
 
 
662
  return output
663
  except Exception as e:
664
  return f"Ошибка: {str(e)}"
requirements.txt CHANGED
@@ -8,3 +8,6 @@ soundfile>=0.12.0
8
  accelerate>=0.20.0
9
  sentencepiece>=0.1.99
10
  datasets>=2.14.0
 
 
 
 
8
  accelerate>=0.20.0
9
  sentencepiece>=0.1.99
10
  datasets>=2.14.0
11
+ timm>=0.9.0
12
+ sentence-transformers>=2.2.0
13
+ open-clip-torch>=2.20.0