"""Gradio demo that exposes several Hugging Face audio and vision models."""

import logging
import tempfile
from typing import Any, Dict, List, Optional, Tuple

import gradio as gr
import soundfile as sf
import torch
import torch.nn.functional as torch_functional
from gtts import gTTS
from PIL import Image, ImageDraw
from transformers import (
    AutoTokenizer,
    BlipForQuestionAnswering,
    BlipProcessor,
    CLIPModel,
    CLIPProcessor,
    SamModel,
    SamProcessor,
    VitsModel,
    pipeline,
)

logger = logging.getLogger(__name__)

# Process-wide cache of loaded models/pipelines, keyed by a short string.
MODEL_STORE: Dict[str, Any] = {}

# model_key -> (pipeline task, checkpoint name) for get_audio_pipeline.
# NOTE(review): the "wav2vec2" key loads a Whisper checkpoint, not a wav2vec2
# one. Kept as in the original; confirm whether this is intentional.
_AUDIO_PIPELINE_SPECS: Dict[str, Tuple[str, str]] = {
    "whisper": ("automatic-speech-recognition", "distil-whisper/distil-small.en"),
    "wav2vec2": ("automatic-speech-recognition", "openai/whisper-small"),
    "audio_classifier": (
        "audio-classification",
        "MIT/ast-finetuned-audioset-10-10-0.4593",
    ),
    "emotion_classifier": (
        "audio-classification",
        "superb/hubert-large-superb-er",
    ),
}

# model_key -> (pipeline task, checkpoint name) for get_vision_pipeline.
_VISION_PIPELINE_SPECS: Dict[str, Tuple[str, str]] = {
    "object_detection_conditional_detr": (
        "object-detection",
        "microsoft/conditional-detr-resnet-50",
    ),
    "object_detection_yolos_small": ("object-detection", "hustvl/yolos-small"),
    "segmentation": (
        "image-segmentation",
        "nvidia/segformer-b0-finetuned-ade-512-512",
    ),
    "depth_estimation": ("depth-estimation", "Intel/dpt-hybrid-midas"),
    "captioning_blip_base": (
        "image-to-text",
        "Salesforce/blip-image-captioning-base",
    ),
    "captioning_blip_large": (
        "image-to-text",
        "Salesforce/blip-image-captioning-large",
    ),
    "vqa_blip_base": ("visual-question-answering", "Salesforce/blip-vqa-base"),
    "vqa_vilt_b32": (
        "visual-question-answering",
        "dandelin/vilt-b32-finetuned-vqa",
    ),
}

# CLIP variant key -> checkpoint name for get_clip_components.
_CLIP_CHECKPOINTS: Dict[str, str] = {
    "clip_large_patch14": "openai/clip-vit-large-patch14",
    "clip_base_patch32": "openai/clip-vit-base-patch32",
}

_MMS_TTS_CHECKPOINT = "facebook/mms-tts-rus"


def _load_rgb_image(image_path: str) -> Optional[Image.Image]:
    """Open *image_path* and return an RGB copy, or ``None`` if unreadable."""
    try:
        # The context manager closes the file handle; ``convert`` returns an
        # independent in-memory image, so nothing is left open.
        with Image.open(image_path) as opened_image:
            return opened_image.convert("RGB")
    except Exception as error:  # best-effort: one bad file must not abort the batch
        logger.warning("Skipping unreadable image %r: %s", image_path, error)
        return None


def _coerce_gallery_item(item: Any) -> Optional[Image.Image]:
    """Turn a single gallery entry into a PIL image, or ``None`` if unsupported."""
    if isinstance(item, Image.Image):
        return item
    if isinstance(item, str):
        return _load_rgb_image(item)
    return None


def _normalize_gallery_images(gallery_value: Any) -> List[Image.Image]:
    """Convert a gallery/file component value into a list of PIL images.

    Accepted entry shapes: a PIL image, a file path, a non-empty list/tuple
    whose first element is an image or path, or a dict holding an image or
    path under ``"image"`` or ``"value"``. Entries that cannot be converted
    are skipped. A single image or path (not wrapped in a list) is accepted
    as a one-element gallery.

    Returns:
        The images that could be resolved, in input order.
    """
    if isinstance(gallery_value, (str, Image.Image)):
        # A bare string would otherwise be iterated character by character.
        gallery_value = [gallery_value]
    if not gallery_value:
        return []

    normalized_images: List[Image.Image] = []
    for item in gallery_value:
        if isinstance(item, (list, tuple)):
            candidate = item[0] if item else None
        elif isinstance(item, dict):
            candidate = item.get("image") or item.get("value")
        else:
            candidate = item
        image_object = _coerce_gallery_item(candidate)
        if image_object is not None:
            normalized_images.append(image_object)
    return normalized_images


def _reserve_temp_path(suffix: str) -> str:
    """Create an empty temporary file and return its path.

    The handle is closed before returning so that other libraries can write
    to the path on every platform (an open handle blocks this on Windows).
    The file is not deleted automatically.
    """
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as file_object:
        return file_object.name


def get_audio_pipeline(model_key: str):
    """Return the cached audio pipeline for *model_key*, loading it on first use.

    Raises:
        ValueError: if *model_key* is not a known audio model key.
    """
    if model_key in MODEL_STORE:
        return MODEL_STORE[model_key]
    if model_key not in _AUDIO_PIPELINE_SPECS:
        raise ValueError(f"Неизвестный тип аудио модели: {model_key}")
    task_name, checkpoint_name = _AUDIO_PIPELINE_SPECS[model_key]
    audio_pipeline = pipeline(task=task_name, model=checkpoint_name)
    MODEL_STORE[model_key] = audio_pipeline
    return audio_pipeline


def get_zero_shot_audio_pipeline():
    """Return the cached CLAP zero-shot audio classification pipeline."""
    if "audio_zero_shot_clap" not in MODEL_STORE:
        MODEL_STORE["audio_zero_shot_clap"] = pipeline(
            task="zero-shot-audio-classification",
            model="laion/clap-htsat-unfused",
        )
    return MODEL_STORE["audio_zero_shot_clap"]


def get_blip_vqa_components() -> Tuple[BlipForQuestionAnswering, BlipProcessor]:
    """Return the cached ``(model, processor)`` pair for BLIP VQA."""
    if "blip_vqa_model" not in MODEL_STORE or "blip_vqa_processor" not in MODEL_STORE:
        MODEL_STORE["blip_vqa_processor"] = BlipProcessor.from_pretrained(
            "Salesforce/blip-vqa-base"
        )
        MODEL_STORE["blip_vqa_model"] = BlipForQuestionAnswering.from_pretrained(
            "Salesforce/blip-vqa-base"
        )
    return MODEL_STORE["blip_vqa_model"], MODEL_STORE["blip_vqa_processor"]


def get_vision_pipeline(model_key: str):
    """Return the cached vision pipeline for *model_key*, loading it on first use.

    Raises:
        ValueError: if *model_key* is not a known vision model key.
    """
    if model_key in MODEL_STORE:
        return MODEL_STORE[model_key]
    if model_key not in _VISION_PIPELINE_SPECS:
        raise ValueError(f"Неизвестный тип визуальной модели: {model_key}")
    task_name, checkpoint_name = _VISION_PIPELINE_SPECS[model_key]
    vision_pipeline = pipeline(task=task_name, model=checkpoint_name)
    MODEL_STORE[model_key] = vision_pipeline
    return vision_pipeline


def get_clip_components(clip_key: str) -> Tuple[CLIPModel, CLIPProcessor]:
    """Return the cached ``(model, processor)`` pair for a CLIP variant.

    Raises:
        ValueError: if *clip_key* is not a known CLIP variant key.
    """
    model_store_key_model = f"clip_model_{clip_key}"
    model_store_key_processor = f"clip_processor_{clip_key}"
    if (
        model_store_key_model not in MODEL_STORE
        or model_store_key_processor not in MODEL_STORE
    ):
        if clip_key not in _CLIP_CHECKPOINTS:
            raise ValueError(f"Неизвестный вариант CLIP модели: {clip_key}")
        clip_name = _CLIP_CHECKPOINTS[clip_key]
        MODEL_STORE[model_store_key_model] = CLIPModel.from_pretrained(clip_name)
        MODEL_STORE[model_store_key_processor] = CLIPProcessor.from_pretrained(
            clip_name
        )
    return MODEL_STORE[model_store_key_model], MODEL_STORE[model_store_key_processor]


def get_silero_tts_model():
    """Return the cached Silero Russian TTS model loaded through ``torch.hub``."""
    if "silero_tts_model" not in MODEL_STORE:
        # torch.hub.load returns a pair here; only the first element is kept.
        silero_model, _ = torch.hub.load(
            repo_or_dir="snakers4/silero-models",
            model="silero_tts",
            language="ru",
            speaker="ru_v3",
        )
        MODEL_STORE["silero_tts_model"] = silero_model
    return MODEL_STORE["silero_tts_model"]


def get_mms_tts_components():
    """Return the cached text-to-speech pipeline for ``facebook/mms-tts-rus``."""
    if "mms_tts_pipeline" not in MODEL_STORE:
        MODEL_STORE["mms_tts_pipeline"] = pipeline(
            task="text-to-speech",
            model=_MMS_TTS_CHECKPOINT,
        )
    return MODEL_STORE["mms_tts_pipeline"]


def _get_mms_vits_components() -> Tuple[VitsModel, Any]:
    """Return the cached ``(VitsModel, tokenizer)`` pair used by synthesize_speech."""
    if "mms_vits_model" not in MODEL_STORE or "mms_vits_tokenizer" not in MODEL_STORE:
        MODEL_STORE["mms_vits_model"] = VitsModel.from_pretrained(_MMS_TTS_CHECKPOINT)
        MODEL_STORE["mms_vits_tokenizer"] = AutoTokenizer.from_pretrained(
            _MMS_TTS_CHECKPOINT
        )
    return MODEL_STORE["mms_vits_model"], MODEL_STORE["mms_vits_tokenizer"]


def get_sam_components() -> Tuple[SamModel, SamProcessor]:
    """Return the cached ``(model, processor)`` pair for SlimSAM."""
    if "sam_model" not in MODEL_STORE or "sam_processor" not in MODEL_STORE:
        MODEL_STORE["sam_model"] = SamModel.from_pretrained(
            "Zigeng/SlimSAM-uniform-77"
        )
        MODEL_STORE["sam_processor"] = SamProcessor.from_pretrained(
            "Zigeng/SlimSAM-uniform-77"
        )
    return MODEL_STORE["sam_model"], MODEL_STORE["sam_processor"]


def _split_comma_separated(raw_text: Optional[str]) -> List[str]:
    """Split *raw_text* on commas, dropping blank entries; ``None`` gives ``[]``."""
    if not raw_text:
        return []
    return [part.strip() for part in raw_text.split(",") if part.strip()]


def classify_audio_file(audio_path: str, model_key: str) -> str:
    """Classify an audio file and return the top five predictions as text."""
    audio_classifier = get_audio_pipeline(model_key)
    prediction_list = audio_classifier(audio_path)
    result_lines = ["Топ-5 предсказаний:"]
    result_lines.extend(
        f"{prediction_index}. {prediction_item['label']}: "
        f"{prediction_item['score']:.4f}"
        for prediction_index, prediction_item in enumerate(
            prediction_list[:5], start=1
        )
    )
    return "\n".join(result_lines)


def classify_audio_zero_shot_clap(audio_path: str, label_texts: str) -> str:
    """Score an audio file against comma-separated text labels with CLAP.

    Returns a user-facing message instead of running the model when no
    non-blank label is supplied.
    """
    label_list = _split_comma_separated(label_texts)
    if not label_list:
        return "Не задано ни одной текстовой метки для zero-shot классификации."

    # Load the model only after the input has been validated.
    clap_pipeline = get_zero_shot_audio_pipeline()
    prediction_list = clap_pipeline(audio_path, candidate_labels=label_list)
    result_lines = ["Zero-Shot Audio Classification (CLAP):"]
    result_lines.extend(
        f"{prediction_index}. {prediction_item['label']}: "
        f"{prediction_item['score']:.4f}"
        for prediction_index, prediction_item in enumerate(prediction_list, start=1)
    )
    return "\n".join(result_lines)


def recognize_speech(audio_path: str, model_key: str) -> str:
    """Transcribe an audio file with the speech pipeline chosen by *model_key*."""
    speech_pipeline = get_audio_pipeline(model_key)
    prediction_result = speech_pipeline(audio_path)
    return prediction_result["text"]


def synthesize_speech(text_value: str, model_key: str):
    """Synthesize Russian speech and return the path of the written audio file.

    Args:
        text_value: Text to speak; must contain non-whitespace characters.
        model_key: ``"Google TTS"`` (gTTS, writes an MP3 file) or ``"mms"``
            (``facebook/mms-tts-rus``, writes a WAV file).

    Raises:
        ValueError: if the text is empty or *model_key* is unknown.
    """
    if model_key not in ("Google TTS", "mms"):
        raise ValueError(f"Неизвестная модель: {model_key}")
    if not text_value or not text_value.strip():
        raise ValueError("Текст для синтеза речи не задан.")

    if model_key == "Google TTS":
        # gTTS emits MP3 data, so the file must carry an .mp3 suffix.
        output_path = _reserve_temp_path(".mp3")
        gTTS(text=text_value, lang="ru").save(output_path)
        return output_path

    model, tokenizer = _get_mms_vits_components()
    inputs = tokenizer(text_value, return_tensors="pt")
    with torch.no_grad():
        waveform = model(**inputs).waveform
    output_path = _reserve_temp_path(".wav")
    sf.write(
        output_path,
        waveform.cpu().numpy().squeeze(),
        model.config.sampling_rate,
    )
    return output_path


def detect_objects_on_image(image_object, model_key: str):
    """Run object detection and return a copy of the image with boxes drawn.

    The input image is left untouched. Returns ``None`` when no image is given.
    """
    if image_object is None:
        return None
    detector_pipeline = get_vision_pipeline(model_key)
    detection_results = detector_pipeline(image_object)

    # ``convert`` yields a new image, so the caller's object is not mutated.
    annotated_image = image_object.convert("RGB")
    drawer_object = ImageDraw.Draw(annotated_image)
    for detection_item in detection_results:
        box_data = detection_item["box"]
        top_left = (box_data["xmin"], box_data["ymin"])
        drawer_object.rectangle(
            [*top_left, box_data["xmax"], box_data["ymax"]],
            outline="red",
            width=3,
        )
        drawer_object.text(
            top_left,
            f"{detection_item['label']}: {detection_item['score']:.2f}",
            fill="red",
        )
    return annotated_image


def segment_image(image_object):
    """Return the mask of the first segment found, or ``None`` if there is none."""
    if image_object is None:
        return None
    segmentation_pipeline = get_vision_pipeline("segmentation")
    segmentation_results = segmentation_pipeline(image_object)
    if not segmentation_results:
        return None
    return segmentation_results[0]["mask"]


def estimate_image_depth(image_object):
    """Estimate depth and return it as an 8-bit grayscale image of the input size.

    Values are scaled so the maximum predicted depth maps to 255. Returns
    ``None`` when no image is given and an all-black image when the maximum
    predicted depth is not positive.

    Raises:
        ValueError: if the pipeline's ``predicted_depth`` is not 2-D or 3-D.
    """
    if image_object is None:
        return None
    depth_pipeline = get_vision_pipeline("depth_estimation")
    predicted_depth_tensor = depth_pipeline(image_object)["predicted_depth"]

    # interpolate() needs a 4-D tensor, so add the missing leading axes.
    if predicted_depth_tensor.ndim == 3:
        predicted_depth_tensor = predicted_depth_tensor.unsqueeze(1)
    elif predicted_depth_tensor.ndim == 2:
        predicted_depth_tensor = predicted_depth_tensor.unsqueeze(0).unsqueeze(0)
    else:
        raise ValueError(
            f"Неожиданная размерность predicted_depth: {predicted_depth_tensor.shape}"
        )

    resized_depth_tensor = torch_functional.interpolate(
        predicted_depth_tensor,
        size=image_object.size[::-1],  # PIL size is (width, height)
        mode="bicubic",
        align_corners=False,
    )
    depth_array = resized_depth_tensor.squeeze().cpu().numpy()
    max_value = float(depth_array.max())
    if max_value <= 0.0:
        return Image.new("L", image_object.size, color=0)

    # Bicubic resampling can overshoot below zero; clip so the uint8 cast
    # cannot wrap around.
    scaled_depth_array = (depth_array * 255.0 / max_value).clip(0.0, 255.0)
    return Image.fromarray(scaled_depth_array.astype("uint8"))


def generate_image_caption(image_object, model_key: str) -> str:
    """Generate a caption for the image with the pipeline chosen by *model_key*."""
    if image_object is None:
        return "Пожалуйста, сначала загрузите изображение."
    caption_pipeline = get_vision_pipeline(model_key)
    caption_result = caption_pipeline(image_object)
    return caption_result[0]["generated_text"]


def answer_visual_question(image_object, question_text: str, model_key: str) -> str:
    """Answer a question about an image.

    ``"vqa_blip_base"`` runs the BLIP model directly and returns its decoded
    answer; any other key goes through the matching vision pipeline and the
    top answer is returned together with its score. Missing image or question
    yields a user-facing prompt instead of an exception.
    """
    if image_object is None:
        return "Пожалуйста, сначала загрузите изображение."
    if not question_text or not question_text.strip():
        return "Пожалуйста, введите вопрос об изображении."

    if model_key == "vqa_blip_base":
        blip_model, blip_processor = get_blip_vqa_components()
        inputs = blip_processor(
            images=image_object,
            text=question_text,
            return_tensors="pt",
        )
        with torch.no_grad():
            output_ids = blip_model.generate(**inputs)
        decoded_answers = blip_processor.batch_decode(
            output_ids,
            skip_special_tokens=True,
        )
        answer_text = decoded_answers[0] if decoded_answers else ""
        return answer_text or "Модель не смогла сгенерировать ответ."

    vqa_pipeline = get_vision_pipeline(model_key)
    vqa_result = vqa_pipeline(image=image_object, question=question_text)
    if not vqa_result:
        return "Модель не смогла сгенерировать ответ."
    top_item = vqa_result[0]
    return f"{top_item['answer']} (confidence: {top_item['score']:.3f})"


def perform_zero_shot_classification(
    image_object,
    class_texts: str,
    clip_key: str,
) -> str:
    """Score an image against comma-separated class names with CLIP.

    Returns one line per class with its softmax probability, or a user-facing
    message when the image or the class list is missing.
    """
    if image_object is None:
        return "Пожалуйста, сначала загрузите изображение."
    class_list = _split_comma_separated(class_texts)
    if not class_list:
        return "Не задано ни одного класса для классификации."

    # Load the model only after the input has been validated.
    clip_model, clip_processor = get_clip_components(clip_key)
    input_batch = clip_processor(
        text=class_list,
        images=image_object,
        return_tensors="pt",
        padding=True,
    )
    with torch.no_grad():
        clip_outputs = clip_model(**input_batch)
    class_probabilities = clip_outputs.logits_per_image.softmax(dim=1)[0]

    result_lines = ["Zero-Shot Classification Results:"]
    result_lines.extend(
        f"{class_name}: {class_probabilities[class_index].item():.4f}"
        for class_index, class_name in enumerate(class_list)
    )
    return "\n".join(result_lines)


def retrieve_best_image(
    gallery_value: Any,
    query_text: str,
    clip_key: str,
) -> Tuple[str, Optional[Image.Image]]:
    """Find the gallery image whose CLIP embedding best matches *query_text*.

    Returns:
        ``(description, image)``; the image is ``None`` when there are no
        usable images or the query is blank.
    """
    image_list = _normalize_gallery_images(gallery_value)
    if not image_list or not query_text or not query_text.strip():
        return "Пожалуйста, загрузите изображения и введите запрос", None

    clip_model, clip_processor = get_clip_components(clip_key)
    image_inputs = clip_processor(
        images=image_list,
        return_tensors="pt",
        padding=True,
    )
    text_inputs = clip_processor(
        text=[query_text],
        return_tensors="pt",
        padding=True,
    )
    with torch.no_grad():
        image_features = clip_model.get_image_features(**image_inputs)
        text_features = clip_model.get_text_features(**text_inputs)

    # L2-normalise both sides so the dot product is a cosine similarity.
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    # One query -> one column; drop it to get a score per image.
    similarity_tensor = (image_features @ text_features.T).squeeze(-1)
    best_index_value = int(similarity_tensor.argmax().item())
    best_score_value = similarity_tensor[best_index_value].item()
    description_text = (
        f"Лучшее изображение: #{best_index_value + 1} "
        f"(схожесть: {best_score_value:.4f})"
    )
    return description_text, image_list[best_index_value]


def segment_image_with_sam_points(
    image_object,
    point_coordinates_list: List[List[int]],
) -> Image.Image:
    """Segment the object indicated by foreground points using SAM.

    Args:
        image_object: PIL image to segment.
        point_coordinates_list: ``[x, y]`` pixel coordinates; every point is
            treated as a foreground (label 1) prompt.

    Returns:
        An ``"L"`` mode mask (255 inside the object, 0 outside). When several
        candidate masks are produced, the one with the highest predicted IoU
        score is used. An all-black mask is returned when no points are given
        or no mask is produced.

    Raises:
        ValueError: if *image_object* is ``None``.
    """
    if image_object is None:
        raise ValueError("Изображение не передано в segment_image_with_sam_points")
    if not point_coordinates_list:
        return Image.new("L", image_object.size, color=0)

    sam_model, sam_processor = get_sam_components()
    batched_points: List[List[List[int]]] = [point_coordinates_list]
    batched_labels: List[List[int]] = [[1] * len(point_coordinates_list)]
    # SamProcessor's image argument is named ``images``.
    sam_inputs = sam_processor(
        images=image_object.convert("RGB"),
        input_points=batched_points,
        input_labels=batched_labels,
        return_tensors="pt",
    )
    with torch.no_grad():
        sam_outputs = sam_model(**sam_inputs, multimask_output=True)

    # post_process_masks resizes each per-image entry with a 2-D bilinear
    # interpolate, which needs a 4-D tensor, so pred_masks is passed unsqueezed.
    processed_masks_list = sam_processor.image_processor.post_process_masks(
        sam_outputs.pred_masks.cpu(),
        sam_inputs["original_sizes"].cpu(),
        sam_inputs["reshaped_input_sizes"].cpu(),
    )
    image_masks_tensor = processed_masks_list[0]
    if image_masks_tensor.ndim < 2 or image_masks_tensor.numel() == 0:
        return Image.new("L", image_object.size, color=0)

    # Flatten any leading axes into a plain list of (height, width) masks.
    candidate_masks = image_masks_tensor.reshape(-1, *image_masks_tensor.shape[-2:])
    quality_scores = sam_outputs.iou_scores.reshape(-1).cpu()
    if quality_scores.numel() == candidate_masks.shape[0]:
        best_mask_index = int(quality_scores.argmax().item())
    else:
        # Scores do not line up with the masks; fall back to the first mask.
        best_mask_index = 0

    mask_array = candidate_masks[best_mask_index].numpy()
    binary_mask_array = (mask_array > 0.5).astype("uint8") * 255
    return Image.fromarray(binary_mask_array)


def parse_point_coordinates_text(coordinates_text: str) -> List[List[int]]:
    """Parse ``"x,y; x,y"`` text into a list of ``[x, y]`` integer pairs.

    Pairs may be separated by semicolons or newlines. Malformed pairs are
    skipped; blank or ``None`` input yields an empty list.
    """
    if not coordinates_text or not coordinates_text.strip():
        return []

    point_list: List[List[int]] = []
    for raw_pair in coordinates_text.replace("\n", ";").split(";"):
        coordinate_parts = raw_pair.strip().split(",")
        if len(coordinate_parts) != 2:
            continue
        try:
            x_value = int(coordinate_parts[0].strip())
            y_value = int(coordinate_parts[1].strip())
        except ValueError:
            continue
        point_list.append([x_value, y_value])
    return point_list


def segment_image_with_sam_points_ui(
    image_object,
    coordinates_text: str,
) -> Optional[Image.Image]:
    """UI wrapper: parse point text and run SAM point segmentation.

    Returns ``None`` when no image is given and an all-black mask when the
    text contains no valid points.
    """
    if image_object is None:
        return None
    point_coordinates_list = parse_point_coordinates_text(coordinates_text)
    if not point_coordinates_list:
        return Image.new("L", image_object.size, color=0)
    return segment_image_with_sam_points(image_object, point_coordinates_list)


def build_interface():
    """Build and return the Gradio ``Blocks`` app with one tab per task."""
    with gr.Blocks(title="Multimodal AI Demo") as demo_block:
        gr.Markdown("# AI модели")

        with gr.Tab("Детекция объектов"):
            gr.Markdown("## Детекция объектов")
            with gr.Row():
                object_input_image = gr.Image(
                    label="Загрузите изображение",
                    type="pil",
                )
                object_model_selector = gr.Dropdown(
                    choices=[
                        "object_detection_conditional_detr",
                        "object_detection_yolos_small",
                    ],
                    label="Модель",
                    value="object_detection_conditional_detr",
                    info=(
                        "object_detection_conditional_detr - microsoft/conditional-detr-resnet-50\n"
                        "object_detection_yolos_small - hustvl/yolos-small"
                    ),
                )
            object_detect_button = gr.Button("Применить")
            object_output_image = gr.Image(
                label="Результат",
            )
            object_detect_button.click(
                fn=detect_objects_on_image,
                inputs=[object_input_image, object_model_selector],
                outputs=object_output_image,
            )

        # NOTE: the segmentation tab is disabled; the original wiring is kept
        # below for reference.
        # with gr.Tab("Сегментация"):
        #     gr.Markdown("## Сегментация")
        #     with gr.Row():
        #         segmentation_input_image = gr.Image(
        #             label="Загрузите изображение",
        #             type="pil",
        #         )
        #     segmentation_button = gr.Button("Применить")
        #     segmentation_output_image = gr.Image(
        #         label="Маска",
        #     )
        #     segmentation_button.click(
        #         fn=segment_image,
        #         inputs=segmentation_input_image,
        #         outputs=segmentation_output_image,
        #     )

        with gr.Tab("Глубина изображения"):
            gr.Markdown("## Глубина (Depth Estimation)")
            with gr.Row():
                depth_input_image = gr.Image(
                    label="Загрузите изображение",
                    type="pil",
                )
            depth_button = gr.Button("Применить")
            depth_output_image = gr.Image(
                label="Глубины",
            )
            depth_button.click(
                fn=estimate_image_depth,
                inputs=depth_input_image,
                outputs=depth_output_image,
            )

        with gr.Tab("Описание изображений"):
            gr.Markdown("## Описание изображений")
            with gr.Row():
                caption_input_image = gr.Image(
                    label="Загрузите изображение",
                    type="pil",
                )
                caption_model_selector = gr.Dropdown(
                    choices=[
                        "captioning_blip_base",
                        "captioning_blip_large",
                    ],
                    label="Модель",
                    value="captioning_blip_base",
                    info=(
                        "captioning_blip_base - Salesforce/blip-image-captioning-base (курс)\n"
                        "captioning_blip_large - Salesforce/blip-image-captioning-large"
                    ),
                )
            caption_button = gr.Button("Применить")
            caption_output_text = gr.Textbox(
                label="Описание изображения",
                lines=3,
            )
            caption_button.click(
                fn=generate_image_caption,
                inputs=[caption_input_image, caption_model_selector],
                outputs=caption_output_text,
            )

        with gr.Tab("Вопросы к изображению"):
            gr.Markdown("## Visual Question Answering")
            with gr.Row():
                vqa_input_image = gr.Image(
                    label="Загрузите изображение",
                    type="pil",
                )
                vqa_question_text = gr.Textbox(
                    label="Вопрос",
                    placeholder="Вопрос",
                    lines=2,
                )
                vqa_model_selector = gr.Dropdown(
                    choices=[
                        "vqa_blip_base",
                        "vqa_vilt_b32",
                    ],
                    label="Модель",
                    value="vqa_blip_base",
                    info=(
                        "vqa_blip_base - Salesforce/blip-vqa-base (курс)\n"
                        "vqa_vilt_b32 - dandelin/vilt-b32-finetuned-vqa"
                    ),
                )
            vqa_button = gr.Button("Ответить на вопрос")
            vqa_output_text = gr.Textbox(
                label="Ответ",
                lines=3,
            )
            vqa_button.click(
                fn=answer_visual_question,
                inputs=[vqa_input_image, vqa_question_text, vqa_model_selector],
                outputs=vqa_output_text,
            )

        with gr.Tab("Zero-Shot классификация"):
            gr.Markdown("## Zero-Shot классификация")
            with gr.Row():
                zero_shot_input_image = gr.Image(
                    label="Загрузите изображение",
                    type="pil",
                )
                zero_shot_classes_text = gr.Textbox(
                    label="Классы для классификации (через запятую)",
                    placeholder="человек, машина, дерево, здание, животное",
                    lines=2,
                )
                clip_model_selector = gr.Dropdown(
                    choices=[
                        "clip_large_patch14",
                        "clip_base_patch32",
                    ],
                    label="модель",
                    value="clip_large_patch14",
                    info=(
                        "clip_large_patch14 - openai/clip-vit-large-patch14 (курс)\n"
                        "clip_base_patch32 - openai/clip-vit-base-patch32"
                    ),
                )
            zero_shot_button = gr.Button("Применить")
            zero_shot_output_text = gr.Textbox(
                label="Результаты",
                lines=10,
            )
            zero_shot_button.click(
                fn=perform_zero_shot_classification,
                inputs=[
                    zero_shot_input_image,
                    zero_shot_classes_text,
                    clip_model_selector,
                ],
                outputs=zero_shot_output_text,
            )

        with gr.Tab("Поиск изображений в папке"):
            gr.Markdown("## Поиск изображений в папке")
            with gr.Row():
                retrieval_dir = gr.File(
                    label="Загрузите папку с изображениями",
                    file_count="directory",
                    file_types=["image"],
                    type="filepath",
                )
                retrieval_query_text = gr.Textbox(
                    label="Текстовый запрос",
                    placeholder="описание того, что вы ищете...",
                    lines=2,
                )
                retrieval_clip_selector = gr.Dropdown(
                    choices=[
                        "clip_large_patch14",
                        "clip_base_patch32",
                    ],
                    label="модель",
                    value="clip_large_patch14",
                    info=(
                        "clip_large_patch14 - openai/clip-vit-large-patch14 (курс)\n"
                        "clip_base_patch32 - openai/clip-vit-base-patch32 (альтернатива)"
                    ),
                )
            retrieval_button = gr.Button("Поиск")
            retrieval_output_text = gr.Textbox(
                label="Результат",
            )
            retrieval_output_image = gr.Image(
                label="Наиболее подходящее изображение",
            )
            retrieval_button.click(
                fn=retrieve_best_image,
                inputs=[retrieval_dir, retrieval_query_text, retrieval_clip_selector],
                outputs=[retrieval_output_text, retrieval_output_image],
            )

    return demo_block


if __name__ == "__main__":
    interface_block = build_interface()
    # NOTE(review): share=True asks Gradio for a public share link; confirm
    # this is intended before running outside a trusted environment.
    interface_block.launch(share=True)