421_space / app.py
mi55th's picture
Update app.py
7edcdca verified
# app.py
import gradio as gr
import torch
import torchaudio
from transformers import (
pipeline, AutoProcessor, AutoModelForSpeechSeq2Seq,
AutoImageProcessor, AutoModelForObjectDetection,
BlipForQuestionAnswering, BlipProcessor, CLIPModel, CLIPProcessor,
VitsModel, AutoTokenizer
)
from PIL import Image, ImageDraw
import requests
import numpy as np
import soundfile as sf
from gtts import gTTS
import tempfile
import os
from sentence_transformers import SentenceTransformer
# Инициализация моделей (ленивая загрузка)
models = {}
def load_image_model(model_name):
if model_name not in models:
if model_name == "object_detection":
models[model_name] = pipeline("object-detection", model="facebook/detr-resnet-50")
elif model_name == "segmentation":
models[model_name] = pipeline("image-segmentation", model="nvidia/segformer-b0-finetuned-ade-512-512")
elif model_name == "captioning":
models[model_name] = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
elif model_name == "vqa":
models[model_name] = pipeline("visual-question-answering", model="dandelin/vilt-b32-finetuned-vqa")
elif model_name == "clip":
models[model_name] = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
models[f"{model_name}_processor"] = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
return models[model_name]
# Функции для обработки изображений
def object_detection(image):
detector = load_image_model("object_detection")
results = detector(image)
# Рисуем bounding boxes
draw = ImageDraw.Draw(image)
for result in results:
box = result['box']
label = result['label']
score = result['score']
draw.rectangle([box['xmin'], box['ymin'], box['xmax'], box['ymax']],
outline='red', width=3)
draw.text((box['xmin'], box['ymin']),
f"{label}: {score:.2f}", fill='red')
return image
def image_segmentation(image):
segmenter = load_image_model("segmentation")
results = segmenter(image)
# Возвращаем первую маску сегментации
return results[0]['mask']
def image_captioning(image):
captioner = load_image_model("captioning")
result = captioner(image)
return result[0]['generated_text']
def visual_question_answering(image, question):
vqa_pipeline = load_image_model("vqa")
result = vqa_pipeline(image, question)
return f"{result[0]['answer']} (confidence: {result[0]['score']:.3f})"
def zero_shot_classification(image, classes):
model = load_image_model("clip")
processor = models["clip_processor"]
class_list = [cls.strip() for cls in classes.split(",")]
inputs = processor(text=class_list, images=image, return_tensors="pt", padding=True)
with torch.no_grad():
outputs = model(**inputs)
logits_per_image = outputs.logits_per_image
probs = logits_per_image.softmax(dim=1)
result = "Zero-Shot Classification Results:\n"
for i, cls in enumerate(class_list):
result += f"{cls}: {probs[0][i].item():.4f}\n"
return result
def image_retrieval(images, query):
if not images or not query:
return "Пожалуйста, загрузите изображения и введите запрос"
# Используем CLIP для поиска
model = load_image_model("clip")
processor = models["clip_processor"]
# Обрабатываем все изображения
image_inputs = processor(images=images, return_tensors="pt", padding=True)
with torch.no_grad():
image_embeddings = model.get_image_features(**image_inputs)
image_embeddings = image_embeddings / image_embeddings.norm(dim=-1, keepdim=True)
# Обрабатываем текстовый запрос
text_inputs = processor(text=[query], return_tensors="pt", padding=True)
with torch.no_grad():
text_embeddings = model.get_text_features(**text_inputs)
text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
# Вычисляем схожести
similarities = (image_embeddings @ text_embeddings.T)
# Находим лучшее изображение
best_idx = similarities.argmax().item()
best_score = similarities[best_idx].item()
return f"Лучшее изображение: #{best_idx + 1} (схожесть: {best_score:.4f})", images[best_idx]
# Создаем интерфейс Gradio
with gr.Blocks(title="Multimodal AI Demo", theme=gr.themes.Soft()) as demo:
gr.Markdown("# Нестеров Владимир ")
gr.Markdown("Демонстрация различных задач компьютерного зрения с использованием Hugging Face Transformers")
with gr.Tab("📦 Детекция объектов"):
gr.Markdown("## Object Detection")
with gr.Row():
with gr.Column():
obj_detection_input = gr.Image(label="Загрузите изображение", type="pil")
detect_btn = gr.Button("Обнаружить объекты")
with gr.Column():
obj_detection_output = gr.Image(label="Результат детекции")
detect_btn.click(
fn=object_detection,
inputs=obj_detection_input,
outputs=obj_detection_output
)
with gr.Tab("🎨 Сегментация"):
gr.Markdown("## Image Segmentation")
with gr.Row():
with gr.Column():
seg_input = gr.Image(label="Загрузите изображение", type="pil")
segment_btn = gr.Button("Сегментировать")
with gr.Column():
seg_output = gr.Image(label="Маска сегментации")
segment_btn.click(
fn=image_segmentation,
inputs=seg_input,
outputs=seg_output
)
with gr.Tab("📝 Описание изображений"):
gr.Markdown("## Image Captioning")
with gr.Row():
with gr.Column():
caption_input = gr.Image(label="Загрузите изображение", type="pil")
caption_btn = gr.Button("Сгенерировать описание")
with gr.Column():
caption_output = gr.Textbox(label="Описание изображения", lines=3)
caption_btn.click(
fn=image_captioning,
inputs=caption_input,
outputs=caption_output
)
with gr.Tab("❓ Визуальные вопросы"):
gr.Markdown("## Visual Question Answering")
with gr.Row():
with gr.Column():
vqa_image_input = gr.Image(label="Загрузите изображение", type="pil")
vqa_question_input = gr.Textbox(
label="Вопрос об изображении",
placeholder="Что происходит на этом изображении?",
lines=2
)
vqa_btn = gr.Button("Ответить на вопрос")
with gr.Column():
vqa_output = gr.Textbox(label="Ответ", lines=3)
vqa_btn.click(
fn=visual_question_answering,
inputs=[vqa_image_input, vqa_question_input],
outputs=vqa_output
)
with gr.Tab("🎯 Zero-Shot классификация"):
gr.Markdown("## Zero-Shot Image Classification")
with gr.Row():
with gr.Column():
zs_image_input = gr.Image(label="Загрузите изображение", type="pil")
zs_classes_input = gr.Textbox(
label="Классы для классификации (через запятую)",
placeholder="человек, машина, дерево, здание, животное",
lines=2
)
zs_classify_btn = gr.Button("Классифицировать")
with gr.Column():
zs_output = gr.Textbox(label="Результаты классификации", lines=10)
zs_classify_btn.click(
fn=zero_shot_classification,
inputs=[zs_image_input, zs_classes_input],
outputs=zs_output
)
gr.Markdown("---")
gr.Markdown("### 📊 Поддерживаемые задачи:")
gr.Markdown("""
- **👁️ Компьютерное зрение**: Детекция объектов, сегментация, описание изображений
- **🤖 Мультимодальные**: Визуальные вопросы, zero-shot классификация
""")
if __name__ == "__main__":
demo.launch(share=True)