|
|
|
|
|
import gradio as gr |
|
|
import torch |
|
|
import torchaudio |
|
|
from transformers import ( |
|
|
pipeline, AutoProcessor, AutoModelForSpeechSeq2Seq, |
|
|
AutoImageProcessor, AutoModelForObjectDetection, |
|
|
BlipForQuestionAnswering, BlipProcessor, CLIPModel, CLIPProcessor, |
|
|
VitsModel, AutoTokenizer |
|
|
) |
|
|
from PIL import Image, ImageDraw |
|
|
import requests |
|
|
import numpy as np |
|
|
import soundfile as sf |
|
|
from gtts import gTTS |
|
|
import tempfile |
|
|
import os |
|
|
from sentence_transformers import SentenceTransformer |
|
|
|
|
|
|
|
|
models = {} |
|
|
|
|
|
|
|
|
def load_image_model(model_name): |
|
|
if model_name not in models: |
|
|
if model_name == "object_detection": |
|
|
models[model_name] = pipeline("object-detection", model="facebook/detr-resnet-50") |
|
|
elif model_name == "segmentation": |
|
|
models[model_name] = pipeline("image-segmentation", model="nvidia/segformer-b0-finetuned-ade-512-512") |
|
|
elif model_name == "captioning": |
|
|
models[model_name] = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base") |
|
|
elif model_name == "vqa": |
|
|
models[model_name] = pipeline("visual-question-answering", model="dandelin/vilt-b32-finetuned-vqa") |
|
|
elif model_name == "clip": |
|
|
models[model_name] = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") |
|
|
models[f"{model_name}_processor"] = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") |
|
|
return models[model_name] |
|
|
|
|
|
|
|
|
|
|
|
def object_detection(image): |
|
|
detector = load_image_model("object_detection") |
|
|
results = detector(image) |
|
|
|
|
|
|
|
|
draw = ImageDraw.Draw(image) |
|
|
for result in results: |
|
|
box = result['box'] |
|
|
label = result['label'] |
|
|
score = result['score'] |
|
|
|
|
|
draw.rectangle([box['xmin'], box['ymin'], box['xmax'], box['ymax']], |
|
|
outline='red', width=3) |
|
|
draw.text((box['xmin'], box['ymin']), |
|
|
f"{label}: {score:.2f}", fill='red') |
|
|
|
|
|
return image |
|
|
|
|
|
def image_segmentation(image): |
|
|
segmenter = load_image_model("segmentation") |
|
|
results = segmenter(image) |
|
|
|
|
|
|
|
|
return results[0]['mask'] |
|
|
|
|
|
def image_captioning(image): |
|
|
captioner = load_image_model("captioning") |
|
|
result = captioner(image) |
|
|
return result[0]['generated_text'] |
|
|
|
|
|
def visual_question_answering(image, question): |
|
|
vqa_pipeline = load_image_model("vqa") |
|
|
result = vqa_pipeline(image, question) |
|
|
return f"{result[0]['answer']} (confidence: {result[0]['score']:.3f})" |
|
|
|
|
|
def zero_shot_classification(image, classes): |
|
|
model = load_image_model("clip") |
|
|
processor = models["clip_processor"] |
|
|
|
|
|
class_list = [cls.strip() for cls in classes.split(",")] |
|
|
|
|
|
inputs = processor(text=class_list, images=image, return_tensors="pt", padding=True) |
|
|
with torch.no_grad(): |
|
|
outputs = model(**inputs) |
|
|
logits_per_image = outputs.logits_per_image |
|
|
probs = logits_per_image.softmax(dim=1) |
|
|
|
|
|
result = "Zero-Shot Classification Results:\n" |
|
|
for i, cls in enumerate(class_list): |
|
|
result += f"{cls}: {probs[0][i].item():.4f}\n" |
|
|
|
|
|
return result |
|
|
|
|
|
def image_retrieval(images, query): |
|
|
if not images or not query: |
|
|
return "Пожалуйста, загрузите изображения и введите запрос" |
|
|
|
|
|
|
|
|
model = load_image_model("clip") |
|
|
processor = models["clip_processor"] |
|
|
|
|
|
|
|
|
image_inputs = processor(images=images, return_tensors="pt", padding=True) |
|
|
with torch.no_grad(): |
|
|
image_embeddings = model.get_image_features(**image_inputs) |
|
|
image_embeddings = image_embeddings / image_embeddings.norm(dim=-1, keepdim=True) |
|
|
|
|
|
|
|
|
text_inputs = processor(text=[query], return_tensors="pt", padding=True) |
|
|
with torch.no_grad(): |
|
|
text_embeddings = model.get_text_features(**text_inputs) |
|
|
text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True) |
|
|
|
|
|
|
|
|
similarities = (image_embeddings @ text_embeddings.T) |
|
|
|
|
|
|
|
|
best_idx = similarities.argmax().item() |
|
|
best_score = similarities[best_idx].item() |
|
|
|
|
|
return f"Лучшее изображение: #{best_idx + 1} (схожесть: {best_score:.4f})", images[best_idx] |
|
|
|
|
|
|
|
|
with gr.Blocks(title="Multimodal AI Demo", theme=gr.themes.Soft()) as demo: |
|
|
gr.Markdown("# Нестеров Владимир ") |
|
|
gr.Markdown("Демонстрация различных задач компьютерного зрения с использованием Hugging Face Transformers") |
|
|
|
|
|
|
|
|
with gr.Tab("📦 Детекция объектов"): |
|
|
gr.Markdown("## Object Detection") |
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
obj_detection_input = gr.Image(label="Загрузите изображение", type="pil") |
|
|
detect_btn = gr.Button("Обнаружить объекты") |
|
|
with gr.Column(): |
|
|
obj_detection_output = gr.Image(label="Результат детекции") |
|
|
|
|
|
detect_btn.click( |
|
|
fn=object_detection, |
|
|
inputs=obj_detection_input, |
|
|
outputs=obj_detection_output |
|
|
) |
|
|
|
|
|
with gr.Tab("🎨 Сегментация"): |
|
|
gr.Markdown("## Image Segmentation") |
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
seg_input = gr.Image(label="Загрузите изображение", type="pil") |
|
|
segment_btn = gr.Button("Сегментировать") |
|
|
with gr.Column(): |
|
|
seg_output = gr.Image(label="Маска сегментации") |
|
|
|
|
|
segment_btn.click( |
|
|
fn=image_segmentation, |
|
|
inputs=seg_input, |
|
|
outputs=seg_output |
|
|
) |
|
|
|
|
|
with gr.Tab("📝 Описание изображений"): |
|
|
gr.Markdown("## Image Captioning") |
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
caption_input = gr.Image(label="Загрузите изображение", type="pil") |
|
|
caption_btn = gr.Button("Сгенерировать описание") |
|
|
with gr.Column(): |
|
|
caption_output = gr.Textbox(label="Описание изображения", lines=3) |
|
|
|
|
|
caption_btn.click( |
|
|
fn=image_captioning, |
|
|
inputs=caption_input, |
|
|
outputs=caption_output |
|
|
) |
|
|
|
|
|
with gr.Tab("❓ Визуальные вопросы"): |
|
|
gr.Markdown("## Visual Question Answering") |
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
vqa_image_input = gr.Image(label="Загрузите изображение", type="pil") |
|
|
vqa_question_input = gr.Textbox( |
|
|
label="Вопрос об изображении", |
|
|
placeholder="Что происходит на этом изображении?", |
|
|
lines=2 |
|
|
) |
|
|
vqa_btn = gr.Button("Ответить на вопрос") |
|
|
with gr.Column(): |
|
|
vqa_output = gr.Textbox(label="Ответ", lines=3) |
|
|
|
|
|
vqa_btn.click( |
|
|
fn=visual_question_answering, |
|
|
inputs=[vqa_image_input, vqa_question_input], |
|
|
outputs=vqa_output |
|
|
) |
|
|
|
|
|
with gr.Tab("🎯 Zero-Shot классификация"): |
|
|
gr.Markdown("## Zero-Shot Image Classification") |
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
zs_image_input = gr.Image(label="Загрузите изображение", type="pil") |
|
|
zs_classes_input = gr.Textbox( |
|
|
label="Классы для классификации (через запятую)", |
|
|
placeholder="человек, машина, дерево, здание, животное", |
|
|
lines=2 |
|
|
) |
|
|
zs_classify_btn = gr.Button("Классифицировать") |
|
|
with gr.Column(): |
|
|
zs_output = gr.Textbox(label="Результаты классификации", lines=10) |
|
|
|
|
|
zs_classify_btn.click( |
|
|
fn=zero_shot_classification, |
|
|
inputs=[zs_image_input, zs_classes_input], |
|
|
outputs=zs_output |
|
|
) |
|
|
|
|
|
gr.Markdown("---") |
|
|
gr.Markdown("### 📊 Поддерживаемые задачи:") |
|
|
gr.Markdown(""" |
|
|
- **👁️ Компьютерное зрение**: Детекция объектов, сегментация, описание изображений |
|
|
- **🤖 Мультимодальные**: Визуальные вопросы, zero-shot классификация |
|
|
""") |
|
|
|
|
|
if __name__ == "__main__": |
|
|
demo.launch(share=True) |