Aleks Zhuravlev commited on
Commit
6d91b68
·
1 Parent(s): 58ef339

Add application file

Browse files
Files changed (2) hide show
  1. app.py +388 -0
  2. requirements.txt +12 -0
app.py ADDED
@@ -0,0 +1,388 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import gradio as gr
3
+ import torch
4
+ import torchaudio
5
+ from transformers import (
6
+ pipeline, AutoProcessor, AutoModelForSpeechSeq2Seq,
7
+ AutoImageProcessor, AutoModelForObjectDetection,
8
+ BlipForQuestionAnswering, BlipProcessor, CLIPModel, CLIPProcessor,
9
+ VitsModel, AutoTokenizer
10
+ )
11
+ from PIL import Image, ImageDraw
12
+ import requests
13
+ import numpy as np
14
+ import soundfile as sf
15
+ from gtts import gTTS
16
+ import tempfile
17
+ import os
18
+ from sentence_transformers import SentenceTransformer
19
+
20
+ # Инициализация моделей (ленивая загрузка)
21
+ models = {}
22
+
23
+ def load_audio_model(model_name):
24
+ if model_name not in models:
25
+ if model_name == "whisper":
26
+ models[model_name] = pipeline(
27
+ "automatic-speech-recognition",
28
+ model="openai/whisper-small"
29
+ )
30
+ elif model_name == "wav2vec2":
31
+ models[model_name] = pipeline(
32
+ "automatic-speech-recognition",
33
+ model="bond005/wav2vec2-large-ru-golos"
34
+ )
35
+ elif model_name == "audio_classifier":
36
+ models[model_name] = pipeline(
37
+ "audio-classification",
38
+ model="MIT/ast-finetuned-audioset-10-10-0.4593"
39
+ )
40
+ elif model_name == "emotion_classifier":
41
+ models[model_name] = pipeline(
42
+ "audio-classification",
43
+ model="superb/hubert-large-superb-er"
44
+ )
45
+ return models[model_name]
46
+
47
+ def load_image_model(model_name):
48
+ if model_name not in models:
49
+ if model_name == "object_detection":
50
+ models[model_name] = pipeline("object-detection", model="facebook/detr-resnet-50")
51
+ elif model_name == "segmentation":
52
+ models[model_name] = pipeline("image-segmentation", model="nvidia/segformer-b0-finetuned-ade-512-512")
53
+ elif model_name == "captioning":
54
+ models[model_name] = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
55
+ elif model_name == "vqa":
56
+ models[model_name] = pipeline("visual-question-answering", model="dandelin/vilt-b32-finetuned-vqa")
57
+ elif model_name == "clip":
58
+ models[model_name] = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
59
+ models[f"{model_name}_processor"] = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
60
+ return models[model_name]
61
+
62
+ # Функции для обработки аудио
63
+ def audio_classification(audio_file, model_type):
64
+ classifier = load_audio_model(model_type)
65
+ results = classifier(audio_file)
66
+
67
+ output = "Топ-5 предсказаний:\n"
68
+ for i, result in enumerate(results[:5]):
69
+ output += f"{i+1}. {result['label']}: {result['score']:.4f}\n"
70
+
71
+ return output
72
+
73
+ def speech_recognition(audio_file, model_type):
74
+ asr_pipeline = load_audio_model(model_type)
75
+
76
+ if model_type == "whisper":
77
+ result = asr_pipeline(audio_file, generate_kwargs={"language": "russian"})
78
+ else:
79
+ result = asr_pipeline(audio_file)
80
+
81
+ return result['text']
82
+
83
+ def text_to_speech(text, model_type):
84
+ if model_type == "silero":
85
+ # Silero TTS
86
+ model, _ = torch.hub.load(repo_or_dir='snakers4/silero-models',
87
+ model='silero_tts',
88
+ language='ru',
89
+ speaker='ru_v3')
90
+
91
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
92
+ model.save_wav(text=text, speaker='aidar', sample_rate=48000, audio_path=f.name)
93
+ return f.name
94
+
95
+ elif model_type == "gtts":
96
+ # Google TTS
97
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
98
+ tts = gTTS(text=text, lang='ru')
99
+ tts.save(f.name)
100
+ return f.name
101
+
102
+ elif model_type == "mms":
103
+ # Facebook MMS TTS
104
+ model = VitsModel.from_pretrained("facebook/mms-tts-rus")
105
+ tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-rus")
106
+
107
+ inputs = tokenizer(text, return_tensors="pt")
108
+ with torch.no_grad():
109
+ output = model(**inputs).waveform
110
+
111
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
112
+ sf.write(f.name, output.numpy().squeeze(), model.config.sampling_rate)
113
+ return f.name
114
+
115
+ # Функции для обработки изображений
116
+ def object_detection(image):
117
+ detector = load_image_model("object_detection")
118
+ results = detector(image)
119
+
120
+ # Рисуем bounding boxes
121
+ draw = ImageDraw.Draw(image)
122
+ for result in results:
123
+ box = result['box']
124
+ label = result['label']
125
+ score = result['score']
126
+
127
+ draw.rectangle([box['xmin'], box['ymin'], box['xmax'], box['ymax']],
128
+ outline='red', width=3)
129
+ draw.text((box['xmin'], box['ymin']),
130
+ f"{label}: {score:.2f}", fill='red')
131
+
132
+ return image
133
+
134
+ def image_segmentation(image):
135
+ segmenter = load_image_model("segmentation")
136
+ results = segmenter(image)
137
+
138
+ # Возвращаем первую маску сегментации
139
+ return results[0]['mask']
140
+
141
+ def image_captioning(image):
142
+ captioner = load_image_model("captioning")
143
+ result = captioner(image)
144
+ return result[0]['generated_text']
145
+
146
+ def visual_question_answering(image, question):
147
+ vqa_pipeline = load_image_model("vqa")
148
+ result = vqa_pipeline(image, question)
149
+ return f"{result[0]['answer']} (confidence: {result[0]['score']:.3f})"
150
+
151
+ def zero_shot_classification(image, classes):
152
+ model = load_image_model("clip")
153
+ processor = models["clip_processor"]
154
+
155
+ class_list = [cls.strip() for cls in classes.split(",")]
156
+
157
+ inputs = processor(text=class_list, images=image, return_tensors="pt", padding=True)
158
+ with torch.no_grad():
159
+ outputs = model(**inputs)
160
+ logits_per_image = outputs.logits_per_image
161
+ probs = logits_per_image.softmax(dim=1)
162
+
163
+ result = "Zero-Shot Classification Results:\n"
164
+ for i, cls in enumerate(class_list):
165
+ result += f"{cls}: {probs[0][i].item():.4f}\n"
166
+
167
+ return result
168
+
169
+ def image_retrieval(images, query):
170
+ if not images or not query:
171
+ return "Пожалуйста, загрузите изображения и введите запрос"
172
+
173
+ # Используем CLIP для поиска
174
+ model = load_image_model("clip")
175
+ processor = models["clip_processor"]
176
+
177
+ # Обрабатываем все изображения
178
+ image_inputs = processor(images=images, return_tensors="pt", padding=True)
179
+ with torch.no_grad():
180
+ image_embeddings = model.get_image_features(**image_inputs)
181
+ image_embeddings = image_embeddings / image_embeddings.norm(dim=-1, keepdim=True)
182
+
183
+ # Обрабатываем текстовый запрос
184
+ text_inputs = processor(text=[query], return_tensors="pt", padding=True)
185
+ with torch.no_grad():
186
+ text_embeddings = model.get_text_features(**text_inputs)
187
+ text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
188
+
189
+ # Вычисляем схожести
190
+ similarities = (image_embeddings @ text_embeddings.T)
191
+
192
+ # Находим лучшее изображение
193
+ best_idx = similarities.argmax().item()
194
+ best_score = similarities[best_idx].item()
195
+
196
+ return f"Лучшее изображение: #{best_idx + 1} (схожесть: {best_score:.4f})", images[best_idx]
197
+
198
+ # Создаем интерфейс Gradio
199
+ with gr.Blocks(title="Multimodal AI Demo", theme=gr.themes.Soft()) as demo:
200
+ gr.Markdown("# 🎯 Мультимодальные AI модели")
201
+ gr.Markdown("Демонстрация различных задач компьютерного зрения и обработки звука с использованием Hugging Face Transformers")
202
+
203
+ with gr.Tab("🎵 Классификация аудио"):
204
+ gr.Markdown("## Zero-Shot Audio Classification")
205
+ with gr.Row():
206
+ with gr.Column():
207
+ audio_input = gr.Audio(label="Загрузите аудиофайл", type="filepath")
208
+ audio_model_dropdown = gr.Dropdown(
209
+ choices=["audio_classifier", "emotion_classifier"],
210
+ label="Выберите модель",
211
+ value="audio_classifier",
212
+ info="audio_classifier - общая классификация, emotion_classifier - эмоции в речи"
213
+ )
214
+ classify_btn = gr.Button("Классифицировать")
215
+ with gr.Column():
216
+ audio_output = gr.Textbox(label="Результаты классификации", lines=10)
217
+
218
+ classify_btn.click(
219
+ fn=audio_classification,
220
+ inputs=[audio_input, audio_model_dropdown],
221
+ outputs=audio_output
222
+ )
223
+
224
+ with gr.Tab("🗣️ Распознавание речи"):
225
+ gr.Markdown("## Automatic Speech Recognition (ASR)")
226
+ with gr.Row():
227
+ with gr.Column():
228
+ asr_audio_input = gr.Audio(label="Загрузите аудио с речью", type="filepath")
229
+ asr_model_dropdown = gr.Dropdown(
230
+ choices=["whisper", "wav2vec2"],
231
+ label="Выберите модель",
232
+ value="whisper",
233
+ info="whisper - многоязычная, wav2vec2 - специализированная для русского"
234
+ )
235
+ transcribe_btn = gr.Button("Транскрибировать")
236
+ with gr.Column():
237
+ asr_output = gr.Textbox(label="Транскрипция", lines=5)
238
+
239
+ transcribe_btn.click(
240
+ fn=speech_recognition,
241
+ inputs=[asr_audio_input, asr_model_dropdown],
242
+ outputs=asr_output
243
+ )
244
+
245
+ with gr.Tab("🔊 Синтез речи"):
246
+ gr.Markdown("## Text-to-Speech (TTS)")
247
+ with gr.Row():
248
+ with gr.Column():
249
+ tts_text_input = gr.Textbox(
250
+ label="Введите текст для синтеза",
251
+ placeholder="Введите текст на русском языке...",
252
+ lines=3
253
+ )
254
+ tts_model_dropdown = gr.Dropdown(
255
+ choices=["silero", "gtts", "mms"],
256
+ label="Выберите модель",
257
+ value="silero",
258
+ info="silero - высокое качество, gtts - Google TTS, mms - Facebook MMS"
259
+ )
260
+ synthesize_btn = gr.Button("Синтезировать речь")
261
+ with gr.Column():
262
+ tts_output = gr.Audio(label="Синтезированная речь")
263
+
264
+ synthesize_btn.click(
265
+ fn=text_to_speech,
266
+ inputs=[tts_text_input, tts_model_dropdown],
267
+ outputs=tts_output
268
+ )
269
+
270
+ with gr.Tab("📦 Детекция объектов"):
271
+ gr.Markdown("## Object Detection")
272
+ with gr.Row():
273
+ with gr.Column():
274
+ obj_detection_input = gr.Image(label="Загрузите изображение", type="pil")
275
+ detect_btn = gr.Button("Обнаружить объекты")
276
+ with gr.Column():
277
+ obj_detection_output = gr.Image(label="Результат детекции")
278
+
279
+ detect_btn.click(
280
+ fn=object_detection,
281
+ inputs=obj_detection_input,
282
+ outputs=obj_detection_output
283
+ )
284
+
285
+ with gr.Tab("🎨 Сегментация"):
286
+ gr.Markdown("## Image Segmentation")
287
+ with gr.Row():
288
+ with gr.Column():
289
+ seg_input = gr.Image(label="Загрузите изображение", type="pil")
290
+ segment_btn = gr.Button("Сегментировать")
291
+ with gr.Column():
292
+ seg_output = gr.Image(label="Маска сегментации")
293
+
294
+ segment_btn.click(
295
+ fn=image_segmentation,
296
+ inputs=seg_input,
297
+ outputs=seg_output
298
+ )
299
+
300
+ with gr.Tab("📝 Описание изображений"):
301
+ gr.Markdown("## Image Captioning")
302
+ with gr.Row():
303
+ with gr.Column():
304
+ caption_input = gr.Image(label="Загрузите изображение", type="pil")
305
+ caption_btn = gr.Button("Сгенерировать описание")
306
+ with gr.Column():
307
+ caption_output = gr.Textbox(label="Описание изображения", lines=3)
308
+
309
+ caption_btn.click(
310
+ fn=image_captioning,
311
+ inputs=caption_input,
312
+ outputs=caption_output
313
+ )
314
+
315
+ with gr.Tab("❓ Визуальные вопросы"):
316
+ gr.Markdown("## Visual Question Answering")
317
+ with gr.Row():
318
+ with gr.Column():
319
+ vqa_image_input = gr.Image(label="Загрузите изображение", type="pil")
320
+ vqa_question_input = gr.Textbox(
321
+ label="Вопрос об изображении",
322
+ placeholder="Что происходит на этом изображении?",
323
+ lines=2
324
+ )
325
+ vqa_btn = gr.Button("Ответить на вопрос")
326
+ with gr.Column():
327
+ vqa_output = gr.Textbox(label="Ответ", lines=3)
328
+
329
+ vqa_btn.click(
330
+ fn=visual_question_answering,
331
+ inputs=[vqa_image_input, vqa_question_input],
332
+ outputs=vqa_output
333
+ )
334
+
335
+ with gr.Tab("🎯 Zero-Shot классификация"):
336
+ gr.Markdown("## Zero-Shot Image Classification")
337
+ with gr.Row():
338
+ with gr.Column():
339
+ zs_image_input = gr.Image(label="Загрузите изображение", type="pil")
340
+ zs_classes_input = gr.Textbox(
341
+ label="Классы для классификации (через запятую)",
342
+ placeholder="человек, машина, дерево, здание, животное",
343
+ lines=2
344
+ )
345
+ zs_classify_btn = gr.Button("Классифицировать")
346
+ with gr.Column():
347
+ zs_output = gr.Textbox(label="Результаты классификации", lines=10)
348
+
349
+ zs_classify_btn.click(
350
+ fn=zero_shot_classification,
351
+ inputs=[zs_image_input, zs_classes_input],
352
+ outputs=zs_output
353
+ )
354
+
355
+ with gr.Tab("🔍 Поиск изображений"):
356
+ gr.Markdown("## Image Retrieval")
357
+ with gr.Row():
358
+ with gr.Column():
359
+ retrieval_images_input = gr.Gallery(
360
+ label="Загрузите изображения для поиска",
361
+ type="pil"
362
+ )
363
+ retrieval_query_input = gr.Textbox(
364
+ label="Текстовый запрос",
365
+ placeholder="описание того, что вы ищете...",
366
+ lines=2
367
+ )
368
+ retrieval_btn = gr.Button("Найти изображение")
369
+ with gr.Column():
370
+ retrieval_output_text = gr.Textbox(label="Результат поиска")
371
+ retrieval_output_image = gr.Image(label="Найденное изображение")
372
+
373
+ retrieval_btn.click(
374
+ fn=image_retrieval,
375
+ inputs=[retrieval_images_input, retrieval_query_input],
376
+ outputs=[retrieval_output_text, retrieval_output_image]
377
+ )
378
+
379
+ gr.Markdown("---")
380
+ gr.Markdown("### 📊 Поддерживаемые задачи:")
381
+ gr.Markdown("""
382
+ - **🎵 Аудио**: Классификация, распознавание речи, синтез речи
383
+ - **👁️ Компьютерное зрение**: Детекция объектов, сегментация, описание изображений
384
+ - **🤖 Мультимодальные**: Визуальные вопросы, zero-shot классификация, поиск по изображениям
385
+ """)
386
+
387
+ if __name__ == "__main__":
388
+ demo.launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ torchaudio>=2.0.0
3
+ transformers>=4.30.0
4
+ gradio>=4.0.0
5
+ pillow>=9.0.0
6
+ numpy>=1.21.0
7
+ soundfile>=0.12.0
8
+ gtts>=2.3.0
9
+ sentence-transformers>=2.2.0
10
+ librosa>=0.10.0
11
+ requests>=2.28.0
12
+ accelerate>=0.20.0