AlsuGibadullina committed on
Commit c426f40 · verified · 1 Parent(s): 99d0a4b

Update app.py

Files changed (1)
  1. app.py +48 -360
app.py CHANGED
@@ -2,10 +2,8 @@ import tempfile
2
  from typing import List, Tuple, Any
3
 
4
  import gradio as gr
5
- import soundfile as sf
6
  import torch
7
  import torch.nn.functional as torch_functional
8
- from gtts import gTTS
9
  from PIL import Image, ImageDraw
10
  from transformers import (
11
  AutoTokenizer,
@@ -13,26 +11,22 @@ from transformers import (
13
  CLIPProcessor,
14
  SamModel,
15
  SamProcessor,
16
- VitsModel,
17
  pipeline,
18
  BlipForQuestionAnswering,
19
  BlipProcessor,
20
  )
21
 
22
-
23
  MODEL_STORE = {}
24
 
 
25
  def _normalize_gallery_images(gallery_value: Any) -> List[Image.Image]:
26
  if not gallery_value:
27
  return []
28
-
29
  normalized_images: List[Image.Image] = []
30
-
31
  for item in gallery_value:
32
  if isinstance(item, Image.Image):
33
  normalized_images.append(item)
34
  continue
35
-
36
  if isinstance(item, str):
37
  try:
38
  image_object = Image.open(item).convert("RGB")
@@ -40,61 +34,18 @@ def _normalize_gallery_images(gallery_value: Any) -> List[Image.Image]:
40
  except Exception:
41
  continue
42
  continue
43
-
44
  if isinstance(item, (list, tuple)) and item:
45
  candidate = item[0]
46
  if isinstance(candidate, Image.Image):
47
  normalized_images.append(candidate)
48
  continue
49
-
50
  if isinstance(item, dict):
51
  candidate = item.get("image") or item.get("value")
52
  if isinstance(candidate, Image.Image):
53
  normalized_images.append(candidate)
54
  continue
55
-
56
  return normalized_images
57
 
58
- def get_audio_pipeline(model_key: str):
59
- if model_key in MODEL_STORE:
60
- return MODEL_STORE[model_key]
61
-
62
- if model_key == "whisper":
63
- audio_pipeline = pipeline(
64
- task="automatic-speech-recognition",
65
- model="distil-whisper/distil-small.en",
66
- )
67
- elif model_key == "wav2vec2":
68
- audio_pipeline = pipeline(
69
- task="automatic-speech-recognition",
70
- model="openai/whisper-small",
71
- )
72
- elif model_key == "audio_classifier":
73
- audio_pipeline = pipeline(
74
- task="audio-classification",
75
- model="MIT/ast-finetuned-audioset-10-10-0.4593",
76
- )
77
- elif model_key == "emotion_classifier":
78
- audio_pipeline = pipeline(
79
- task="audio-classification",
80
- model="superb/hubert-large-superb-er",
81
- )
82
- else:
83
- raise ValueError(f"Неизвестный тип аудио модели: {model_key}")
84
-
85
- MODEL_STORE[model_key] = audio_pipeline
86
- return audio_pipeline
87
-
88
-
89
- def get_zero_shot_audio_pipeline():
90
- if "audio_zero_shot_clap" not in MODEL_STORE:
91
- zero_shot_pipeline = pipeline(
92
- task="zero-shot-audio-classification",
93
- model="laion/clap-htsat-unfused",
94
- )
95
- MODEL_STORE["audio_zero_shot_clap"] = zero_shot_pipeline
96
- return MODEL_STORE["audio_zero_shot_clap"]
97
-
98
 
99
  def get_blip_vqa_components() -> Tuple[BlipForQuestionAnswering, BlipProcessor]:
100
  if "blip_vqa_model" not in MODEL_STORE or "blip_vqa_processor" not in MODEL_STORE:
@@ -102,11 +53,11 @@ def get_blip_vqa_components() -> Tuple[BlipForQuestionAnswering, BlipProcessor]:
102
  blip_model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
103
  MODEL_STORE["blip_vqa_model"] = blip_model
104
  MODEL_STORE["blip_vqa_processor"] = blip_processor
105
-
106
  blip_model = MODEL_STORE["blip_vqa_model"]
107
  blip_processor = MODEL_STORE["blip_vqa_processor"]
108
  return blip_model, blip_processor
109
 
 
110
  def get_vision_pipeline(model_key: str):
111
  if model_key in MODEL_STORE:
112
  return MODEL_STORE[model_key]
@@ -121,19 +72,16 @@ def get_vision_pipeline(model_key: str):
121
  task="object-detection",
122
  model="hustvl/yolos-small",
123
  )
124
-
125
  elif model_key == "segmentation":
126
  vision_pipeline = pipeline(
127
  task="image-segmentation",
128
  model="nvidia/segformer-b0-finetuned-ade-512-512",
129
  )
130
-
131
  elif model_key == "depth_estimation":
132
  vision_pipeline = pipeline(
133
  task="depth-estimation",
134
  model="Intel/dpt-hybrid-midas",
135
  )
136
-
137
  elif model_key == "captioning_blip_base":
138
  vision_pipeline = pipeline(
139
  task="image-to-text",
@@ -144,7 +92,6 @@ def get_vision_pipeline(model_key: str):
144
  task="image-to-text",
145
  model="Salesforce/blip-image-captioning-large",
146
  )
147
-
148
  elif model_key == "vqa_blip_base":
149
  vision_pipeline = pipeline(
150
  task="visual-question-answering",
@@ -155,7 +102,6 @@ def get_vision_pipeline(model_key: str):
155
  task="visual-question-answering",
156
  model="dandelin/vilt-b32-finetuned-vqa",
157
  )
158
-
159
  else:
160
  raise ValueError(f"Неизвестный тип визуальной модели: {model_key}")
161
 
@@ -177,7 +123,6 @@ def get_clip_components(clip_key: str) -> Tuple[CLIPModel, CLIPProcessor]:
177
 
178
  clip_model = CLIPModel.from_pretrained(clip_name)
179
  clip_processor = CLIPProcessor.from_pretrained(clip_name)
180
-
181
  MODEL_STORE[model_store_key_model] = clip_model
182
  MODEL_STORE[model_store_key_processor] = clip_processor
183
 
@@ -186,125 +131,26 @@ def get_clip_components(clip_key: str) -> Tuple[CLIPModel, CLIPProcessor]:
186
  return clip_model, clip_processor
187
 
188
 
189
- def get_silero_tts_model():
190
- if "silero_tts_model" not in MODEL_STORE:
191
- silero_model, _ = torch.hub.load(
192
- repo_or_dir="snakers4/silero-models",
193
- model="silero_tts",
194
- language="ru",
195
- speaker="ru_v3",
196
- )
197
- MODEL_STORE["silero_tts_model"] = silero_model
198
- return MODEL_STORE["silero_tts_model"]
199
-
200
-
201
- def get_mms_tts_components():
202
- if "mms_tts_pipeline" not in MODEL_STORE:
203
- tts_pipeline = pipeline(
204
- task="text-to-speech",
205
- model="facebook/mms-tts-rus",
206
- )
207
- MODEL_STORE["mms_tts_pipeline"] = tts_pipeline
208
-
209
- return MODEL_STORE["mms_tts_pipeline"]
210
-
211
-
212
  def get_sam_components() -> Tuple[SamModel, SamProcessor]:
213
  if "sam_model" not in MODEL_STORE or "sam_processor" not in MODEL_STORE:
214
  sam_model = SamModel.from_pretrained("Zigeng/SlimSAM-uniform-77")
215
  sam_processor = SamProcessor.from_pretrained("Zigeng/SlimSAM-uniform-77")
216
  MODEL_STORE["sam_model"] = sam_model
217
  MODEL_STORE["sam_processor"] = sam_processor
218
-
219
  sam_model = MODEL_STORE["sam_model"]
220
  sam_processor = MODEL_STORE["sam_processor"]
221
  return sam_model, sam_processor
222
 
223
 
224
-
225
- def classify_audio_file(audio_path: str, model_key: str) -> str:
226
- audio_classifier = get_audio_pipeline(model_key)
227
- prediction_list = audio_classifier(audio_path)
228
-
229
- result_lines = ["Топ-5 предсказаний:"]
230
- for prediction_index, prediction_item in enumerate(prediction_list[:5], start=1):
231
- label_value = prediction_item["label"]
232
- score_value = prediction_item["score"]
233
- result_lines.append(
234
- f"{prediction_index}. {label_value}: {score_value:.4f}"
235
- )
236
-
237
- return "\n".join(result_lines)
238
-
239
-
240
- def classify_audio_zero_shot_clap(audio_path: str, label_texts: str) -> str:
241
-
242
- clap_pipeline = get_zero_shot_audio_pipeline()
243
-
244
- label_list = [
245
- label_item.strip()
246
- for label_item in label_texts.split(",")
247
- if label_item.strip()
248
- ]
249
- if not label_list:
250
- return "Не задано ни одной текстовой метки для zero-shot классификации."
251
-
252
- prediction_list = clap_pipeline(
253
- audio_path,
254
- candidate_labels=label_list,
255
- )
256
-
257
- result_lines = ["Zero-Shot Audio Classification (CLAP):"]
258
- for prediction_index, prediction_item in enumerate(prediction_list, start=1):
259
- label_value = prediction_item["label"]
260
- score_value = prediction_item["score"]
261
- result_lines.append(
262
- f"{prediction_index}. {label_value}: {score_value:.4f}"
263
- )
264
-
265
- return "\n".join(result_lines)
266
-
267
-
268
- def recognize_speech(audio_path: str, model_key: str) -> str:
269
- speech_pipeline = get_audio_pipeline(model_key)
270
-
271
- prediction_result = speech_pipeline(audio_path)
272
-
273
- return prediction_result["text"]
274
-
275
-
276
- def synthesize_speech(text_value: str, model_key: str):
277
- if model_key == "Google TTS":
278
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as file_object:
279
- text_to_speech_engine = gTTS(text=text_value, lang="ru")
280
- text_to_speech_engine.save(file_object.name)
281
- return file_object.name
282
- elif model_key == "mms":
283
- model = VitsModel.from_pretrained("facebook/mms-tts-rus")
284
- tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-rus")
285
-
286
- inputs = tokenizer(text_value, return_tensors="pt")
287
- with torch.no_grad():
288
- output = model(**inputs).waveform
289
-
290
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
291
- sf.write(f.name, output.numpy().squeeze(), model.config.sampling_rate)
292
- return f.name
293
-
294
- raise ValueError(f"Неизвестная модель: {model_key}")
295
-
296
-
297
-
298
  def detect_objects_on_image(image_object, model_key: str):
299
  detector_pipeline = get_vision_pipeline(model_key)
300
  detection_results = detector_pipeline(image_object)
301
-
302
  drawer_object = ImageDraw.Draw(image_object)
 
303
  for detection_item in detection_results:
304
  box_data = detection_item["box"]
305
  label_value = detection_item["label"]
306
  score_value = detection_item["score"]
307
-
308
  drawer_object.rectangle(
309
  [
310
  box_data["xmin"],
@@ -320,7 +166,6 @@ def detect_objects_on_image(image_object, model_key: str):
320
  f"{label_value}: {score_value:.2f}",
321
  fill="red",
322
  )
323
-
324
  return image_object
325
 
326
 
@@ -333,7 +178,6 @@ def segment_image(image_object):
333
  def estimate_image_depth(image_object):
334
  depth_pipeline = get_vision_pipeline("depth_estimation")
335
  depth_output = depth_pipeline(image_object)
336
-
337
  predicted_depth_tensor = depth_output["predicted_depth"]
338
 
339
  if predicted_depth_tensor.ndim == 3:
@@ -351,10 +195,8 @@ def estimate_image_depth(image_object):
351
  mode="bicubic",
352
  align_corners=False,
353
  )
354
-
355
  depth_array = resized_depth_tensor.squeeze().cpu().numpy()
356
  max_value = float(depth_array.max())
357
-
358
  if max_value <= 0.0:
359
  return Image.new("L", image_object.size, color=0)
360
 
@@ -372,50 +214,42 @@ def generate_image_caption(image_object, model_key: str) -> str:
372
  def answer_visual_question(image_object, question_text: str, model_key: str) -> str:
373
  if image_object is None:
374
  return "Пожалуйста, сначала загрузите изображение."
375
-
376
  if not question_text.strip():
377
  return "Пожалуйста, введите вопрос об изображении."
378
 
379
  if model_key == "vqa_blip_base":
380
  blip_model, blip_processor = get_blip_vqa_components()
381
-
382
  inputs = blip_processor(
383
  images=image_object,
384
  text=question_text,
385
  return_tensors="pt",
386
  )
387
-
388
  with torch.no_grad():
389
  output_ids = blip_model.generate(**inputs)
390
-
391
  decoded_answers = blip_processor.batch_decode(
392
  output_ids,
393
  skip_special_tokens=True,
394
  )
395
  answer_text = decoded_answers[0] if decoded_answers else ""
396
-
397
  return answer_text or "Модель не смогла сгенерировать ответ."
398
 
399
  vqa_pipeline = get_vision_pipeline(model_key)
400
-
401
  vqa_result = vqa_pipeline(
402
  image=image_object,
403
  question=question_text,
404
  )
405
-
406
  top_item = vqa_result[0]
407
  answer_text = top_item["answer"]
408
  confidence_value = top_item["score"]
409
-
410
  return f"{answer_text} (confidence: {confidence_value:.3f})"
411
 
 
412
  def perform_zero_shot_classification(
413
  image_object,
414
  class_texts: str,
415
  clip_key: str,
416
  ) -> str:
417
  clip_model, clip_processor = get_clip_components(clip_key)
418
-
419
  class_list = [
420
  class_name.strip()
421
  for class_name in class_texts.split(",")
@@ -430,7 +264,6 @@ def perform_zero_shot_classification(
430
  return_tensors="pt",
431
  padding=True,
432
  )
433
-
434
  with torch.no_grad():
435
  clip_outputs = clip_model(**input_batch)
436
  logits_per_image = clip_outputs.logits_per_image
@@ -440,7 +273,6 @@ def perform_zero_shot_classification(
440
  for class_index, class_name in enumerate(class_list):
441
  probability_value = probability_tensor[0][class_index].item()
442
  result_lines.append(f"{class_name}: {probability_value:.4f}")
443
-
444
  return "\n".join(result_lines)
445
 
446
 
@@ -450,12 +282,10 @@ def retrieve_best_image(
450
  clip_key: str,
451
  ) -> Tuple[str, Image.Image | None]:
452
  image_list = _normalize_gallery_images(gallery_value)
453
-
454
  if not image_list or not query_text.strip():
455
  return "Пожалуйста, загрузите изображения и введите запрос", None
456
 
457
  clip_model, clip_processor = get_clip_components(clip_key)
458
-
459
  image_inputs = clip_processor(
460
  images=image_list,
461
  return_tensors="pt",
@@ -463,10 +293,10 @@ def retrieve_best_image(
463
  )
464
  with torch.no_grad():
465
  image_features = clip_model.get_image_features(**image_inputs)
466
- image_features = image_features / image_features.norm(
467
- dim=-1,
468
- keepdim=True,
469
- )
470
 
471
  text_inputs = clip_processor(
472
  text=[query_text],
@@ -475,10 +305,10 @@ def retrieve_best_image(
475
  )
476
  with torch.no_grad():
477
  text_features = clip_model.get_text_features(**text_inputs)
478
- text_features = text_features / text_features.norm(
479
- dim=-1,
480
- keepdim=True,
481
- )
482
 
483
  similarity_tensor = image_features @ text_features.T
484
  best_index_tensor = similarity_tensor.argmax()
@@ -498,12 +328,10 @@ def segment_image_with_sam_points(
498
  ) -> Image.Image:
499
  if image_object is None:
500
  raise ValueError("Изображение не передано в segment_image_with_sam_points")
501
-
502
  if not point_coordinates_list:
503
  return Image.new("L", image_object.size, color=0)
504
 
505
  sam_model, sam_processor = get_sam_components()
506
-
507
  batched_points: List[List[List[int]]] = [point_coordinates_list]
508
  batched_labels: List[List[int]] = [[1 for _ in point_coordinates_list]]
509
 
@@ -513,7 +341,6 @@ def segment_image_with_sam_points(
513
  input_labels=batched_labels,
514
  return_tensors="pt",
515
  )
516
-
517
  with torch.no_grad():
518
  sam_outputs = sam_model(**sam_inputs, multimask_output=True)
519
 
@@ -522,47 +349,37 @@ def segment_image_with_sam_points(
522
  sam_inputs["original_sizes"].cpu(),
523
  sam_inputs["reshaped_input_sizes"].cpu(),
524
  )
525
-
526
  batch_masks_tensor = processed_masks_list[0]
527
-
528
  if batch_masks_tensor.ndim != 3 or batch_masks_tensor.shape[0] == 0:
529
  return Image.new("L", image_object.size, color=0)
530
 
531
  first_mask_tensor = batch_masks_tensor[0]
532
  mask_array = first_mask_tensor.numpy()
533
-
534
  binary_mask_array = (mask_array > 0.5).astype("uint8") * 255
535
-
536
  mask_image = Image.fromarray(binary_mask_array, mode="L")
537
  return mask_image
538
 
539
 
540
  def segment_image_with_sam_points_ui(image_object, coordinates_text: str) -> Image.Image:
541
-
542
  if image_object is None:
543
  return None
544
-
545
  coordinates_text_clean = coordinates_text.strip()
546
  if not coordinates_text_clean:
547
  return Image.new("L", image_object.size, color=0)
548
 
549
  point_coordinates_list: List[List[int]] = []
550
-
551
  for raw_pair in coordinates_text_clean.replace("\n", ";").split(";"):
552
  raw_pair_clean = raw_pair.strip()
553
  if not raw_pair_clean:
554
  continue
555
-
556
  parts = raw_pair_clean.split(",")
557
  if len(parts) != 2:
558
  continue
559
-
560
  try:
561
  x_value = int(parts[0].strip())
562
  y_value = int(parts[1].strip())
563
  except ValueError:
564
  continue
565
-
566
  point_coordinates_list.append([x_value, y_value])
567
 
568
  if not point_coordinates_list:
@@ -574,7 +391,6 @@ def segment_image_with_sam_points_ui(image_object, coordinates_text: str) -> Image.Image:
574
  def parse_point_coordinates_text(coordinates_text: str) -> List[List[int]]:
575
  if not coordinates_text.strip():
576
  return []
577
-
578
  point_list: List[List[int]] = []
579
  for raw_pair in coordinates_text.split(";"):
580
  cleaned_pair = raw_pair.strip()
@@ -589,125 +405,13 @@ def parse_point_coordinates_text(coordinates_text: str) -> List[List[int]]:
589
  except ValueError:
590
  continue
591
  point_list.append([x_value, y_value])
592
-
593
  return point_list
594
 
 
595
  def build_interface():
596
  with gr.Blocks(title="Multimodal AI Demo", theme=gr.themes.Soft()) as demo_block:
597
  gr.Markdown("# AI модели")
598
 
599
- with gr.Tab("Классификация аудио"):
600
- gr.Markdown("## Классификация аудио")
601
- with gr.Row():
602
- audio_input_component = gr.Audio(
603
- label="Загрузите аудиофайл",
604
- type="filepath",
605
- )
606
- audio_model_selector = gr.Dropdown(
607
- choices=["audio_classifier", "emotion_classifier"],
608
- label="Выберите модель",
609
- value="audio_classifier",
610
- info=(
611
- "audio_classifier - общая классификация (курс)"
612
- "emotion_classifier - эмоции в речи "
613
- ),
614
- )
615
- audio_classify_button = gr.Button("Применить")
616
-
617
- audio_output_component = gr.Textbox(
618
- label="Результаты классификации",
619
- lines=10,
620
- )
621
-
622
- audio_classify_button.click(
623
- fn=classify_audio_file,
624
- inputs=[audio_input_component, audio_model_selector],
625
- outputs=audio_output_component,
626
- )
627
-
628
- with gr.Tab("Zero-Shot аудио"):
629
- gr.Markdown("## Zero-Shot аудио классификатор")
630
- with gr.Row():
631
- clap_audio_input_component = gr.Audio(
632
- label="Загрузите аудиофайл",
633
- type="filepath",
634
- )
635
- clap_label_texts_component = gr.Textbox(
636
- label="Кандидатные метки (через запятую)",
637
- placeholder="лай собаки, шум дождя, музыка, разговор",
638
- lines=2,
639
- )
640
- clap_button = gr.Button("Применить")
641
-
642
- clap_output_component = gr.Textbox(
643
- label="Результаты zero-shot классификации",
644
- lines=10,
645
- )
646
-
647
- clap_button.click(
648
- fn=classify_audio_zero_shot_clap,
649
- inputs=[clap_audio_input_component, clap_label_texts_component],
650
- outputs=clap_output_component,
651
- )
652
-
653
- with gr.Tab("Распознавание речи"):
654
- gr.Markdown("## Распознавание реч")
655
- with gr.Row():
656
- asr_audio_input_component = gr.Audio(
657
- label="Загрузите аудио с речью",
658
- type="filepath",
659
- )
660
- asr_model_selector = gr.Dropdown(
661
- choices=["whisper", "wav2vec2"],
662
- label="Выберите модель",
663
- value="whisper",
664
- info=(
665
- "whisper - distil-whisper/distil-small.en (курс),\n"
666
- "wav2vec2 - openai/whisper-small"
667
- ),
668
- )
669
- asr_button = gr.Button("Применить")
670
-
671
- asr_output_component = gr.Textbox(
672
- label="Транскрипция",
673
- lines=5,
674
- )
675
-
676
- asr_button.click(
677
- fn=recognize_speech,
678
- inputs=[asr_audio_input_component, asr_model_selector],
679
- outputs=asr_output_component,
680
- )
681
- with gr.Tab("Синтез речи"):
682
- gr.Markdown("## Text-to-Speech")
683
- with gr.Row():
684
- tts_text_component = gr.Textbox(
685
- label="Введите текст для синтеза",
686
- placeholder="Введите текст на русском или английском языке...",
687
- lines=3,
688
- )
689
- tts_model_selector = gr.Dropdown(
690
- choices=["mms", "Google TTS"],
691
- label="Выберите модель",
692
- value="mms",
693
- info=(
694
- "facebook/mms-tts-rus\n"
695
- "Google TTS"
696
- ),
697
- )
698
- tts_button = gr.Button("Применить")
699
-
700
- tts_audio_output_component = gr.Audio(
701
- label="Синтезированная речь",
702
- type="filepath",
703
- )
704
-
705
- tts_button.click(
706
- fn=synthesize_speech,
707
- inputs=[tts_text_component, tts_model_selector],
708
- outputs=tts_audio_output_component,
709
- )
710
-
711
  with gr.Tab("Детекция объектов"):
712
  gr.Markdown("## Детекция объектов")
713
  with gr.Row():
@@ -724,15 +428,13 @@ def build_interface():
724
  value="object_detection_conditional_detr",
725
  info=(
726
  "object_detection_conditional_detr - microsoft/conditional-detr-resnet-50\n"
727
- "object_detection_yolos_small - hustvl/yolos-small"
728
  ),
729
  )
730
  object_detect_button = gr.Button("Применить")
731
-
732
- object_output_image = gr.Image(
733
- label="Результат",
734
- )
735
-
736
  object_detect_button.click(
737
  fn=detect_objects_on_image,
738
  inputs=[object_input_image, object_model_selector],
@@ -747,11 +449,9 @@ def build_interface():
747
  type="pil",
748
  )
749
  segmentation_button = gr.Button("Применить")
750
-
751
- segmentation_output_image = gr.Image(
752
- label="Маска",
753
- )
754
-
755
  segmentation_button.click(
756
  fn=segment_image,
757
  inputs=segmentation_input_image,
@@ -761,17 +461,14 @@ def build_interface():
761
  with gr.Tab("Глубина"):
762
  gr.Markdown("## Глубина (Depth Estimation)")
763
  with gr.Row():
764
-
765
  depth_input_image = gr.Image(
766
  label="Загрузите изображение",
767
  type="pil",
768
  )
769
  depth_button = gr.Button("Применить")
770
-
771
- depth_output_image = gr.Image(
772
- label="Глубины",
773
- )
774
-
775
  depth_button.click(
776
  fn=estimate_image_depth,
777
  inputs=depth_input_image,
@@ -793,17 +490,15 @@ def build_interface():
793
  label="Модель",
794
  value="captioning_blip_base",
795
  info=(
796
- "captioning_blip_base - Salesforce/blip-image-captioning-base (курс)\n"
797
  "captioning_blip_large - Salesforce/blip-image-captioning-large"
798
  ),
799
  )
800
  caption_button = gr.Button("Применить")
801
-
802
- caption_output_text = gr.Textbox(
803
- label="Описание изображения",
804
- lines=3,
805
- )
806
-
807
  caption_button.click(
808
  fn=generate_image_caption,
809
  inputs=[caption_input_image, caption_model_selector],
@@ -831,16 +526,14 @@ def build_interface():
831
  value="vqa_blip_base",
832
  info=(
833
  "vqa_blip_base - Salesforce/blip-vqa-base (курс)\n"
834
- "vqa_vilt_b32 - dandelin/vilt-b32-finetuned-vqa"
835
  ),
836
  )
837
  vqa_button = gr.Button("Ответить на вопрос")
838
-
839
- vqa_output_text = gr.Textbox(
840
- label="Ответ",
841
- lines=3,
842
- )
843
-
844
  vqa_button.click(
845
  fn=answer_visual_question,
846
  inputs=[vqa_input_image, vqa_question_text, vqa_model_selector],
@@ -868,16 +561,14 @@ def build_interface():
868
  value="clip_large_patch14",
869
  info=(
870
  "clip_large_patch14 - openai/clip-vit-large-patch14 (курс)\n"
871
- "clip_base_patch32 - openai/clip-vit-base-patch32"
872
  ),
873
  )
874
  zero_shot_button = gr.Button("Применить")
875
-
876
- zero_shot_output_text = gr.Textbox(
877
- label="Результаты",
878
- lines=10,
879
- )
880
-
881
  zero_shot_button.click(
882
  fn=perform_zero_shot_classification,
883
  inputs=[zero_shot_input_image, zero_shot_classes_text, clip_model_selector],
@@ -887,7 +578,6 @@ def build_interface():
887
  with gr.Tab("Поиск изображений"):
888
  gr.Markdown("## Поиск изображений")
889
  with gr.Row():
890
-
891
  retrieval_dir = gr.File(
892
  label="Загрузите папку с изображениями",
893
  file_count="directory",
@@ -908,18 +598,16 @@ def build_interface():
908
  value="clip_large_patch14",
909
  info=(
910
  "clip_large_patch14 - openai/clip-vit-large-patch14 (курс)\n"
911
- "clip_base_patch32 - openai/clip-vit-base-patch32 (альтернатива)"
912
  ),
913
  )
914
  retrieval_button = gr.Button("Поиск")
915
-
916
- retrieval_output_text = gr.Textbox(
917
- label="Результат",
918
- )
919
- retrieval_output_image = gr.Image(
920
- label="Наиболее подходящее изображение",
921
- )
922
-
923
  retrieval_button.click(
924
  fn=retrieve_best_image,
925
  inputs=[retrieval_dir, retrieval_query_text, retrieval_clip_selector],
@@ -930,11 +618,11 @@ def build_interface():
930
  gr.Markdown("### Задачи:")
931
  gr.Markdown(
932
  """
933
- - Аудио: классификация, распознавание речи, синтез речи
934
- - Компьютерное зрение: детекция объектов, сегментация, оценка глубины, генерация описаний изображений
935
  - Мультимодальные задачи: вопросы к изображению, zero-shot классификация изображений, поиск по изображениям по текстовому запросу
936
- """
937
  )
 
938
  return demo_block
939
 
940
 
 