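"""Gradio demo of multimodal Hugging Face models.

Tabs cover audio classification, speech recognition, text-to-speech,
object detection, image segmentation, image captioning, visual question
answering, zero-shot image classification, and CLIP-based image retrieval.
"""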

import gradio as gr
import torch
from transformers import (
    pipeline, CLIPModel, CLIPProcessor,
    VitsModel, AutoTokenizer
)
from PIL import ImageDraw
import soundfile as sf
from gtts import gTTS
import tempfile


# Lazy model cache: each entry is loaded on first use and reused afterwards.
models = {}


def load_audio_model(model_name):
    """Load and cache an audio model by short name."""
    if model_name not in models:
        if model_name == "whisper":
            models[model_name] = pipeline(
                "automatic-speech-recognition",
                model="openai/whisper-small"
            )
        elif model_name == "wav2vec2":
            models[model_name] = pipeline(
                "automatic-speech-recognition",
                model="bond005/wav2vec2-large-ru-golos"
            )
        elif model_name == "audio_classifier":
            models[model_name] = pipeline(
                "audio-classification",
                model="MIT/ast-finetuned-audioset-10-10-0.4593"
            )
        elif model_name == "emotion_classifier":
            models[model_name] = pipeline(
                "audio-classification",
                model="superb/hubert-large-superb-er"
            )
    return models[model_name]


def load_image_model(model_name):
    """Load and cache a vision or multimodal model by short name."""
    if model_name not in models:
        if model_name == "object_detection":
            models[model_name] = pipeline("object-detection", model="facebook/detr-resnet-50")
        elif model_name == "segmentation":
            models[model_name] = pipeline("image-segmentation", model="nvidia/segformer-b0-finetuned-ade-512-512")
        elif model_name == "captioning":
            models[model_name] = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
        elif model_name == "vqa":
            models[model_name] = pipeline("visual-question-answering", model="dandelin/vilt-b32-finetuned-vqa")
        elif model_name == "clip":
            # CLIP needs both the model and its processor; cache them separately.
            models[model_name] = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
            models[f"{model_name}_processor"] = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    return models[model_name]


def audio_classification(audio_file, model_type):
    """Classify an audio file and format the top-5 predictions."""
    classifier = load_audio_model(model_type)
    results = classifier(audio_file)

    output = "Top-5 predictions:\n"
    for i, result in enumerate(results[:5]):
        output += f"{i+1}. {result['label']}: {result['score']:.4f}\n"

    return output


def speech_recognition(audio_file, model_type):
    """Transcribe speech from an audio file with the selected ASR model."""
    asr_pipeline = load_audio_model(model_type)

    if model_type == "whisper":
        # Pin the decoding language so short clips are not misdetected.
        result = asr_pipeline(audio_file, generate_kwargs={"language": "russian"})
    else:
        result = asr_pipeline(audio_file)

    return result['text']


def text_to_speech(text, model_type):
    """Synthesize Russian speech and return the path of the generated file."""
    if model_type == "silero":
        # Silero TTS is fetched via torch.hub; the first call downloads the model.
        model, _ = torch.hub.load(repo_or_dir='snakers4/silero-models',
                                  model='silero_tts',
                                  language='ru',
                                  speaker='ru_v3')
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            model.save_wav(text=text, speaker='aidar', sample_rate=48000, audio_path=f.name)
            return f.name

    elif model_type == "gtts":
        # gTTS produces MP3 audio, so use a matching file suffix.
        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as f:
            tts = gTTS(text=text, lang='ru')
            tts.save(f.name)
            return f.name

    elif model_type == "mms":
        model = VitsModel.from_pretrained("facebook/mms-tts-rus")
        tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-rus")

        inputs = tokenizer(text, return_tensors="pt")
        with torch.no_grad():
            output = model(**inputs).waveform

        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            sf.write(f.name, output.numpy().squeeze(), model.config.sampling_rate)
            return f.name


def object_detection(image):
    """Detect objects with DETR and draw labelled boxes on the image."""
    detector = load_image_model("object_detection")
    results = detector(image)

    draw = ImageDraw.Draw(image)
    for result in results:
        box = result['box']
        label = result['label']
        score = result['score']

        draw.rectangle([box['xmin'], box['ymin'], box['xmax'], box['ymax']],
                       outline='red', width=3)
        draw.text((box['xmin'], box['ymin']),
                  f"{label}: {score:.2f}", fill='red')

    return image


def image_segmentation(image):
    segmenter = load_image_model("segmentation")
    results = segmenter(image)

    # The pipeline returns one mask per segment; display the first one.
    return results[0]['mask']


def image_captioning(image):
    captioner = load_image_model("captioning")
    result = captioner(image)
    return result[0]['generated_text']


def visual_question_answering(image, question):
    vqa_pipeline = load_image_model("vqa")
    result = vqa_pipeline(image=image, question=question)
    return f"{result[0]['answer']} (confidence: {result[0]['score']:.3f})"


def zero_shot_classification(image, classes):
    """Score an image against comma-separated candidate labels with CLIP."""
    model = load_image_model("clip")
    processor = models["clip_processor"]

    class_list = [cls.strip() for cls in classes.split(",")]

    inputs = processor(text=class_list, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
        logits_per_image = outputs.logits_per_image
        probs = logits_per_image.softmax(dim=1)

    result = "Zero-Shot Classification Results:\n"
    for i, cls in enumerate(class_list):
        result += f"{cls}: {probs[0][i].item():.4f}\n"

    return result


def image_retrieval(images, query):
    """Return the gallery image most similar to a text query, using CLIP."""
    if not images or not query:
        # Two outputs are wired to this function, so always return a pair.
        return "Please upload images and enter a query", None

    # Depending on the Gradio version, gr.Gallery may pass (image, caption)
    # tuples; keep only the image part.
    images = [img[0] if isinstance(img, (tuple, list)) else img for img in images]

    model = load_image_model("clip")
    processor = models["clip_processor"]

    # Embed the images and L2-normalize so dot products become cosine similarities.
    image_inputs = processor(images=images, return_tensors="pt", padding=True)
    with torch.no_grad():
        image_embeddings = model.get_image_features(**image_inputs)
        image_embeddings = image_embeddings / image_embeddings.norm(dim=-1, keepdim=True)

    # Embed the query the same way.
    text_inputs = processor(text=[query], return_tensors="pt", padding=True)
    with torch.no_grad():
        text_embeddings = model.get_text_features(**text_inputs)
        text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)

    # Cosine similarity of each image against the query, shape (num_images, 1).
    similarities = image_embeddings @ text_embeddings.T

    best_idx = similarities.argmax().item()
    best_score = similarities[best_idx].item()

    return f"Best image: #{best_idx + 1} (similarity: {best_score:.4f})", images[best_idx]


with gr.Blocks(title="Multimodal AI Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎯 Multimodal AI Models")
    gr.Markdown("A demo of computer vision and audio processing tasks built with Hugging Face Transformers")
with gr.Tab("🎵 Классификация аудио"): |
|
|
gr.Markdown("## Zero-Shot Audio Classification") |
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
audio_input = gr.Audio(label="Загрузите аудиофайл", type="filepath") |
|
|
audio_model_dropdown = gr.Dropdown( |
|
|
choices=["audio_classifier", "emotion_classifier"], |
|
|
label="Выберите модель", |
|
|
value="audio_classifier", |
|
|
info="audio_classifier - общая классификация, emotion_classifier - эмоции в речи" |
|
|
) |
|
|
classify_btn = gr.Button("Классифицировать") |
|
|
with gr.Column(): |
|
|
audio_output = gr.Textbox(label="Результаты классификации", lines=10) |
|
|
|
|
|
classify_btn.click( |
|
|
fn=audio_classification, |
|
|
inputs=[audio_input, audio_model_dropdown], |
|
|
outputs=audio_output |
|
|
) |
|
|
|
|
|
with gr.Tab("🗣️ Распознавание речи"): |
|
|
gr.Markdown("## Automatic Speech Recognition (ASR)") |
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
asr_audio_input = gr.Audio(label="Загрузите аудио с речью", type="filepath") |
|
|
asr_model_dropdown = gr.Dropdown( |
|
|
choices=["whisper", "wav2vec2"], |
|
|
label="Выберите модель", |
|
|
value="whisper", |
|
|
info="whisper - многоязычная, wav2vec2 - специализированная для русского" |
|
|
) |
|
|
transcribe_btn = gr.Button("Транскрибировать") |
|
|
with gr.Column(): |
|
|
asr_output = gr.Textbox(label="Транскрипция", lines=5) |
|
|
|
|
|
transcribe_btn.click( |
|
|
fn=speech_recognition, |
|
|
inputs=[asr_audio_input, asr_model_dropdown], |
|
|
outputs=asr_output |
|
|
) |
|
|
|
|
|
with gr.Tab("🔊 Синтез речи"): |
|
|
gr.Markdown("## Text-to-Speech (TTS)") |
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
tts_text_input = gr.Textbox( |
|
|
label="Введите текст для синтеза", |
|
|
placeholder="Введите текст на русском языке...", |
|
|
lines=3 |
|
|
) |
|
|
tts_model_dropdown = gr.Dropdown( |
|
|
choices=["silero", "gtts", "mms"], |
|
|
label="Выберите модель", |
|
|
value="silero", |
|
|
info="silero - высокое качество, gtts - Google TTS, mms - Facebook MMS" |
|
|
) |
|
|
synthesize_btn = gr.Button("Синтезировать речь") |
|
|
with gr.Column(): |
|
|
tts_output = gr.Audio(label="Синтезированная речь") |
|
|
|
|
|
synthesize_btn.click( |
|
|
fn=text_to_speech, |
|
|
inputs=[tts_text_input, tts_model_dropdown], |
|
|
outputs=tts_output |
|
|
) |
|
|
|
|
|
with gr.Tab("📦 Детекция объектов"): |
|
|
gr.Markdown("## Object Detection") |
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
obj_detection_input = gr.Image(label="Загрузите изображение", type="pil") |
|
|
detect_btn = gr.Button("Обнаружить объекты") |
|
|
with gr.Column(): |
|
|
obj_detection_output = gr.Image(label="Результат детекции") |
|
|
|
|
|
detect_btn.click( |
|
|
fn=object_detection, |
|
|
inputs=obj_detection_input, |
|
|
outputs=obj_detection_output |
|
|
) |
|
|
|
|
|
with gr.Tab("🎨 Сегментация"): |
|
|
gr.Markdown("## Image Segmentation") |
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
seg_input = gr.Image(label="Загрузите изображение", type="pil") |
|
|
segment_btn = gr.Button("Сегментировать") |
|
|
with gr.Column(): |
|
|
seg_output = gr.Image(label="Маска сегментации") |
|
|
|
|
|
segment_btn.click( |
|
|
fn=image_segmentation, |
|
|
inputs=seg_input, |
|
|
outputs=seg_output |
|
|
) |
|
|
|
|
|
with gr.Tab("📝 Описание изображений"): |
|
|
gr.Markdown("## Image Captioning") |
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
caption_input = gr.Image(label="Загрузите изображение", type="pil") |
|
|
caption_btn = gr.Button("Сгенерировать описание") |
|
|
with gr.Column(): |
|
|
caption_output = gr.Textbox(label="Описание изображения", lines=3) |
|
|
|
|
|
caption_btn.click( |
|
|
fn=image_captioning, |
|
|
inputs=caption_input, |
|
|
outputs=caption_output |
|
|
) |
|
|
|
|
|
with gr.Tab("❓ Визуальные вопросы"): |
|
|
gr.Markdown("## Visual Question Answering") |
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
vqa_image_input = gr.Image(label="Загрузите изображение", type="pil") |
|
|
vqa_question_input = gr.Textbox( |
|
|
label="Вопрос об изображении", |
|
|
placeholder="Что происходит на этом изображении?", |
|
|
lines=2 |
|
|
) |
|
|
vqa_btn = gr.Button("Ответить на вопрос") |
|
|
with gr.Column(): |
|
|
vqa_output = gr.Textbox(label="Ответ", lines=3) |
|
|
|
|
|
vqa_btn.click( |
|
|
fn=visual_question_answering, |
|
|
inputs=[vqa_image_input, vqa_question_input], |
|
|
outputs=vqa_output |
|
|
) |
|
|
|
|
|
with gr.Tab("🎯 Zero-Shot классификация"): |
|
|
gr.Markdown("## Zero-Shot Image Classification") |
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
zs_image_input = gr.Image(label="Загрузите изображение", type="pil") |
|
|
zs_classes_input = gr.Textbox( |
|
|
label="Классы для классификации (через запятую)", |
|
|
placeholder="человек, машина, дерево, здание, животное", |
|
|
lines=2 |
|
|
) |
|
|
zs_classify_btn = gr.Button("Классифицировать") |
|
|
with gr.Column(): |
|
|
zs_output = gr.Textbox(label="Результаты классификации", lines=10) |
|
|
|
|
|
zs_classify_btn.click( |
|
|
fn=zero_shot_classification, |
|
|
inputs=[zs_image_input, zs_classes_input], |
|
|
outputs=zs_output |
|
|
) |
|
|
|
|
|
with gr.Tab("🔍 Поиск изображений"): |
|
|
gr.Markdown("## Image Retrieval") |
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
retrieval_images_input = gr.Gallery( |
|
|
label="Загрузите изображения для поиска", |
|
|
type="pil" |
|
|
) |
|
|
retrieval_query_input = gr.Textbox( |
|
|
label="Текстовый запрос", |
|
|
placeholder="описание того, что вы ищете...", |
|
|
lines=2 |
|
|
) |
|
|
retrieval_btn = gr.Button("Найти изображение") |
|
|
with gr.Column(): |
|
|
retrieval_output_text = gr.Textbox(label="Результат поиска") |
|
|
retrieval_output_image = gr.Image(label="Найденное изображение") |
|
|
|
|
|
retrieval_btn.click( |
|
|
fn=image_retrieval, |
|
|
inputs=[retrieval_images_input, retrieval_query_input], |
|
|
outputs=[retrieval_output_text, retrieval_output_image] |
|
|
) |
|
|
|
|
|
gr.Markdown("---") |
|
|
gr.Markdown("### 📊 Поддерживаемые задачи:") |
|
|
gr.Markdown(""" |
|
|
- **🎵 Аудио**: Классификация, распознавание речи, синтез речи |
|
|
- **👁️ Компьютерное зрение**: Детекция объектов, сегментация, описание изображений |
|
|
- **🤖 Мультимодальные**: Визуальные вопросы, zero-shot классификация, поиск по изображениям |
|
|
""") |
|
|
|
|
|
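
# Note: share=True additionally exposes the app through a temporary public
# gradio.live link; drop it to serve the demo locally only.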
if __name__ == "__main__":
    demo.launch(share=True)