Spaces:
Running
Running
| import tempfile | |
| from typing import List, Tuple, Any | |
| import gradio as gr | |
| import soundfile as sf | |
| import torch | |
| import torch.nn.functional as torch_functional | |
| from gtts import gTTS | |
| from PIL import Image, ImageDraw | |
| from transformers import ( | |
| AutoTokenizer, | |
| CLIPModel, | |
| CLIPProcessor, | |
| SamModel, | |
| SamProcessor, | |
| VitsModel, | |
| pipeline, | |
| BlipForQuestionAnswering, | |
| BlipProcessor, | |
| ) | |
# Process-wide cache of loaded models/processors/pipelines, keyed by model
# identifier, so heavy checkpoints are downloaded and instantiated only once.
MODEL_STORE = {}
def _normalize_gallery_images(gallery_value: Any) -> List[Image.Image]:
    """Coerce a Gradio gallery/file widget value into a flat list of PIL images.

    Accepted entry shapes: PIL images, file-path strings, (image, caption)
    lists/tuples, and dicts with an "image" or "value" key. Entries that are
    unreadable or of any other shape are silently skipped.
    """
    images: List[Image.Image] = []
    for entry in gallery_value or []:
        if isinstance(entry, Image.Image):
            images.append(entry)
        elif isinstance(entry, str):
            # A file path; skip paths that cannot be opened as images.
            try:
                images.append(Image.open(entry).convert("RGB"))
            except Exception:
                pass
        elif isinstance(entry, (list, tuple)):
            if entry and isinstance(entry[0], Image.Image):
                images.append(entry[0])
        elif isinstance(entry, dict):
            candidate = entry.get("image") or entry.get("value")
            if isinstance(candidate, Image.Image):
                images.append(candidate)
    return images
def get_audio_pipeline(model_key: str):
    """Return a cached Hugging Face audio pipeline for *model_key*.

    The pipeline is created on first use and stored in MODEL_STORE.

    Raises:
        ValueError: if *model_key* is not a known audio model key.
    """
    if model_key in MODEL_STORE:
        return MODEL_STORE[model_key]
    # model key -> (pipeline task, checkpoint name)
    # NOTE(review): the "wav2vec2" key actually loads a Whisper checkpoint
    # ("openai/whisper-small") — confirm whether this mapping is intentional.
    pipeline_specs = {
        "whisper": (
            "automatic-speech-recognition",
            "distil-whisper/distil-small.en",
        ),
        "wav2vec2": (
            "automatic-speech-recognition",
            "openai/whisper-small",
        ),
        "audio_classifier": (
            "audio-classification",
            "MIT/ast-finetuned-audioset-10-10-0.4593",
        ),
        "emotion_classifier": (
            "audio-classification",
            "superb/hubert-large-superb-er",
        ),
    }
    if model_key not in pipeline_specs:
        raise ValueError(f"Неизвестный тип аудио модели: {model_key}")
    task_name, checkpoint_name = pipeline_specs[model_key]
    audio_pipeline = pipeline(task=task_name, model=checkpoint_name)
    MODEL_STORE[model_key] = audio_pipeline
    return audio_pipeline
def get_zero_shot_audio_pipeline():
    """Return the cached CLAP zero-shot audio classification pipeline,
    creating it on first use."""
    if "audio_zero_shot_clap" not in MODEL_STORE:
        MODEL_STORE["audio_zero_shot_clap"] = pipeline(
            task="zero-shot-audio-classification",
            model="laion/clap-htsat-unfused",
        )
    return MODEL_STORE["audio_zero_shot_clap"]
def get_blip_vqa_components() -> Tuple[BlipForQuestionAnswering, BlipProcessor]:
    """Return the cached BLIP VQA model and processor, loading both from
    "Salesforce/blip-vqa-base" on first use."""
    if "blip_vqa_model" not in MODEL_STORE or "blip_vqa_processor" not in MODEL_STORE:
        checkpoint = "Salesforce/blip-vqa-base"
        MODEL_STORE["blip_vqa_processor"] = BlipProcessor.from_pretrained(checkpoint)
        MODEL_STORE["blip_vqa_model"] = BlipForQuestionAnswering.from_pretrained(checkpoint)
    return MODEL_STORE["blip_vqa_model"], MODEL_STORE["blip_vqa_processor"]
def get_vision_pipeline(model_key: str):
    """Return a cached Hugging Face vision pipeline for *model_key*.

    Supported keys cover object detection, segmentation, depth estimation,
    captioning and VQA; the pipeline is created on first use and cached.

    Raises:
        ValueError: if *model_key* is not a known vision model key.
    """
    if model_key in MODEL_STORE:
        return MODEL_STORE[model_key]
    # model key -> (pipeline task, checkpoint name)
    pipeline_specs = {
        "object_detection_conditional_detr": (
            "object-detection",
            "microsoft/conditional-detr-resnet-50",
        ),
        "object_detection_yolos_small": (
            "object-detection",
            "hustvl/yolos-small",
        ),
        "segmentation": (
            "image-segmentation",
            "nvidia/segformer-b0-finetuned-ade-512-512",
        ),
        "depth_estimation": (
            "depth-estimation",
            "Intel/dpt-hybrid-midas",
        ),
        "captioning_blip_base": (
            "image-to-text",
            "Salesforce/blip-image-captioning-base",
        ),
        "captioning_blip_large": (
            "image-to-text",
            "Salesforce/blip-image-captioning-large",
        ),
        "vqa_blip_base": (
            "visual-question-answering",
            "Salesforce/blip-vqa-base",
        ),
        "vqa_vilt_b32": (
            "visual-question-answering",
            "dandelin/vilt-b32-finetuned-vqa",
        ),
    }
    if model_key not in pipeline_specs:
        raise ValueError(f"Неизвестный тип визуальной модели: {model_key}")
    task_name, checkpoint_name = pipeline_specs[model_key]
    vision_pipeline = pipeline(task=task_name, model=checkpoint_name)
    MODEL_STORE[model_key] = vision_pipeline
    return vision_pipeline
def get_clip_components(clip_key: str) -> Tuple[CLIPModel, CLIPProcessor]:
    """Return the cached CLIP model and processor for *clip_key*.

    Raises:
        ValueError: if *clip_key* is not a known CLIP variant.
    """
    model_cache_key = f"clip_model_{clip_key}"
    processor_cache_key = f"clip_processor_{clip_key}"
    if model_cache_key not in MODEL_STORE or processor_cache_key not in MODEL_STORE:
        checkpoint_by_key = {
            "clip_large_patch14": "openai/clip-vit-large-patch14",
            "clip_base_patch32": "openai/clip-vit-base-patch32",
        }
        if clip_key not in checkpoint_by_key:
            raise ValueError(f"Неизвестный вариант CLIP модели: {clip_key}")
        checkpoint = checkpoint_by_key[clip_key]
        MODEL_STORE[model_cache_key] = CLIPModel.from_pretrained(checkpoint)
        MODEL_STORE[processor_cache_key] = CLIPProcessor.from_pretrained(checkpoint)
    return MODEL_STORE[model_cache_key], MODEL_STORE[processor_cache_key]
def get_silero_tts_model():
    """Return the cached Silero Russian TTS model, loading it through
    torch.hub on first use."""
    if "silero_tts_model" not in MODEL_STORE:
        loaded_model, _example_text = torch.hub.load(
            repo_or_dir="snakers4/silero-models",
            model="silero_tts",
            language="ru",
            speaker="ru_v3",
        )
        MODEL_STORE["silero_tts_model"] = loaded_model
    return MODEL_STORE["silero_tts_model"]
def get_mms_tts_components():
    """Return the cached facebook/mms-tts-rus text-to-speech pipeline,
    creating it on first use."""
    if "mms_tts_pipeline" not in MODEL_STORE:
        MODEL_STORE["mms_tts_pipeline"] = pipeline(
            task="text-to-speech",
            model="facebook/mms-tts-rus",
        )
    return MODEL_STORE["mms_tts_pipeline"]
def get_sam_components() -> Tuple[SamModel, SamProcessor]:
    """Return the cached SlimSAM model and processor, loading both from
    "Zigeng/SlimSAM-uniform-77" on first use."""
    if "sam_model" not in MODEL_STORE or "sam_processor" not in MODEL_STORE:
        checkpoint = "Zigeng/SlimSAM-uniform-77"
        MODEL_STORE["sam_model"] = SamModel.from_pretrained(checkpoint)
        MODEL_STORE["sam_processor"] = SamProcessor.from_pretrained(checkpoint)
    return MODEL_STORE["sam_model"], MODEL_STORE["sam_processor"]
def classify_audio_file(audio_path: str, model_key: str) -> str:
    """Classify an audio file and return the top-5 predictions as text,
    one "rank. label: score" line per prediction."""
    classifier = get_audio_pipeline(model_key)
    predictions = classifier(audio_path)
    lines = ["Топ-5 предсказаний:"]
    lines.extend(
        f"{rank}. {item['label']}: {item['score']:.4f}"
        for rank, item in enumerate(predictions[:5], start=1)
    )
    return "\n".join(lines)
def classify_audio_zero_shot_clap(audio_path: str, label_texts: str) -> str:
    """Zero-shot classify an audio file against comma-separated labels via CLAP.

    Returns a user-facing error string when no labels were provided,
    otherwise one "rank. label: score" line per candidate label.
    """
    clap_pipeline = get_zero_shot_audio_pipeline()
    candidate_labels = [
        label.strip() for label in label_texts.split(",") if label.strip()
    ]
    if not candidate_labels:
        return "Не задано ни одной текстовой метки для zero-shot классификации."
    predictions = clap_pipeline(audio_path, candidate_labels=candidate_labels)
    lines = ["Zero-Shot Audio Classification (CLAP):"]
    lines.extend(
        f"{rank}. {item['label']}: {item['score']:.4f}"
        for rank, item in enumerate(predictions, start=1)
    )
    return "\n".join(lines)
def recognize_speech(audio_path: str, model_key: str) -> str:
    """Transcribe speech from an audio file with the selected ASR pipeline."""
    asr_pipeline = get_audio_pipeline(model_key)
    return asr_pipeline(audio_path)["text"]
def synthesize_speech(text_value: str, model_key: str):
    """Synthesize Russian speech for *text_value* and return an audio file path.

    Args:
        text_value: Text to synthesize.
        model_key: "Google TTS" (gTTS) or "mms" (facebook/mms-tts-rus).

    Returns:
        Path to a temporary audio file containing the synthesized speech.
        NOTE(review): the gTTS branch writes MP3 data into a ".wav"-suffixed
        file — confirm downstream players tolerate this.

    Raises:
        ValueError: for an unknown *model_key*.
    """
    if model_key == "Google TTS":
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as file_object:
            text_to_speech_engine = gTTS(text=text_value, lang="ru")
            text_to_speech_engine.save(file_object.name)
            return file_object.name
    elif model_key == "mms":
        # Fix: cache the MMS model/tokenizer in MODEL_STORE (consistent with
        # the other loaders) instead of re-downloading them on every call.
        if "mms_tts_model" not in MODEL_STORE or "mms_tts_tokenizer" not in MODEL_STORE:
            MODEL_STORE["mms_tts_model"] = VitsModel.from_pretrained("facebook/mms-tts-rus")
            MODEL_STORE["mms_tts_tokenizer"] = AutoTokenizer.from_pretrained("facebook/mms-tts-rus")
        model = MODEL_STORE["mms_tts_model"]
        tokenizer = MODEL_STORE["mms_tts_tokenizer"]
        inputs = tokenizer(text_value, return_tensors="pt")
        with torch.no_grad():
            output = model(**inputs).waveform
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            sf.write(f.name, output.numpy().squeeze(), model.config.sampling_rate)
            return f.name
    raise ValueError(f"Неизвестная модель: {model_key}")
def detect_objects_on_image(image_object, model_key: str):
    """Run object detection and return an annotated copy of the image.

    Each detection is drawn as a red bounding box with a "label: score" tag.

    Fix: draw on a copy of the input image instead of mutating the caller's
    PIL image in place.
    """
    detector_pipeline = get_vision_pipeline(model_key)
    detection_results = detector_pipeline(image_object)
    annotated_image = image_object.copy()
    drawer_object = ImageDraw.Draw(annotated_image)
    for detection_item in detection_results:
        box_data = detection_item["box"]
        label_value = detection_item["label"]
        score_value = detection_item["score"]
        drawer_object.rectangle(
            [
                box_data["xmin"],
                box_data["ymin"],
                box_data["xmax"],
                box_data["ymax"],
            ],
            outline="red",
            width=3,
        )
        drawer_object.text(
            (box_data["xmin"], box_data["ymin"]),
            f"{label_value}: {score_value:.2f}",
            fill="red",
        )
    return annotated_image
def segment_image(image_object):
    """Run semantic segmentation and return the mask of the first segment.

    NOTE(review): only the first returned segment's mask is used — confirm
    whether a combined overlay of all segments is wanted.
    """
    segmenter = get_vision_pipeline("segmentation")
    return segmenter(image_object)[0]["mask"]
def estimate_image_depth(image_object):
    """Estimate per-pixel depth for a PIL image and return it as grayscale.

    The predicted depth map is resized back to the input resolution and
    linearly scaled to 0-255 (brighter = larger predicted depth value).

    Returns:
        An "L" mode PIL image; an all-black image when the depth map has no
        positive values (avoids a divide-by-zero during normalization).

    Raises:
        ValueError: if the pipeline returns a depth tensor that is neither
            2-D nor 3-D.
    """
    depth_pipeline = get_vision_pipeline("depth_estimation")
    depth_output = depth_pipeline(image_object)
    predicted_depth_tensor = depth_output["predicted_depth"]
    # interpolate() needs a 4-D (N, C, H, W) tensor; promote 3-D/2-D outputs.
    if predicted_depth_tensor.ndim == 3:
        predicted_depth_tensor = predicted_depth_tensor.unsqueeze(1)
    elif predicted_depth_tensor.ndim == 2:
        predicted_depth_tensor = predicted_depth_tensor.unsqueeze(0).unsqueeze(0)
    else:
        raise ValueError(
            f"Неожиданная размерность predicted_depth: {predicted_depth_tensor.shape}"
        )
    # PIL's size is (width, height); interpolate expects (height, width).
    resized_depth_tensor = torch_functional.interpolate(
        predicted_depth_tensor,
        size=image_object.size[::-1],
        mode="bicubic",
        align_corners=False,
    )
    depth_array = resized_depth_tensor.squeeze().cpu().numpy()
    max_value = float(depth_array.max())
    if max_value <= 0.0:
        # Degenerate depth map: return a black image instead of dividing by zero.
        return Image.new("L", image_object.size, color=0)
    normalized_depth_array = (depth_array * 255.0 / max_value).astype("uint8")
    depth_image = Image.fromarray(normalized_depth_array, mode="L")
    return depth_image
def generate_image_caption(image_object, model_key: str) -> str:
    """Generate a natural-language caption for an image with the selected
    image-to-text pipeline."""
    captioner = get_vision_pipeline(model_key)
    return captioner(image_object)[0]["generated_text"]
def answer_visual_question(image_object, question_text: str, model_key: str) -> str:
    """Answer a free-form question about an image (VQA).

    Uses generative BLIP for "vqa_blip_base" and a classification-style
    pipeline (with confidence) for other model keys. Returns a user-facing
    message when the image or question is missing.
    """
    if image_object is None:
        return "Пожалуйста, сначала загрузите изображение."
    if not question_text.strip():
        return "Пожалуйста, введите вопрос об изображении."
    if model_key != "vqa_blip_base":
        vqa_pipeline = get_vision_pipeline(model_key)
        predictions = vqa_pipeline(image=image_object, question=question_text)
        best = predictions[0]
        return f"{best['answer']} (confidence: {best['score']:.3f})"
    # BLIP is generative: encode image+question, decode the generated ids.
    blip_model, blip_processor = get_blip_vqa_components()
    encoded = blip_processor(
        images=image_object,
        text=question_text,
        return_tensors="pt",
    )
    with torch.no_grad():
        generated_ids = blip_model.generate(**encoded)
    decoded = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)
    answer = decoded[0] if decoded else ""
    return answer or "Модель не смогла сгенерировать ответ."
def perform_zero_shot_classification(
    image_object,
    class_texts: str,
    clip_key: str,
) -> str:
    """Zero-shot classify an image against comma-separated class names via CLIP.

    Returns one "class: probability" line per class, or a user-facing error
    string when no classes were provided.
    """
    clip_model, clip_processor = get_clip_components(clip_key)
    candidate_classes = [
        name.strip() for name in class_texts.split(",") if name.strip()
    ]
    if not candidate_classes:
        return "Не задано ни одного класса для классификации."
    encoded = clip_processor(
        text=candidate_classes,
        images=image_object,
        return_tensors="pt",
        padding=True,
    )
    with torch.no_grad():
        outputs = clip_model(**encoded)
    # Softmax over the per-class logits gives probabilities for one image.
    probabilities = outputs.logits_per_image.softmax(dim=1)
    lines = ["Zero-Shot Classification Results:"]
    lines.extend(
        f"{name}: {probabilities[0][index].item():.4f}"
        for index, name in enumerate(candidate_classes)
    )
    return "\n".join(lines)
def retrieve_best_image(
    gallery_value: Any,
    query_text: str,
    clip_key: str,
) -> Tuple[str, Image.Image | None]:
    """Find the uploaded image most similar to *query_text* using CLIP.

    Returns a human-readable summary and the best-matching PIL image, or a
    user-facing error message and None when inputs are missing.
    """
    candidate_images = _normalize_gallery_images(gallery_value)
    if not candidate_images or not query_text.strip():
        return "Пожалуйста, загрузите изображения и введите запрос", None
    clip_model, clip_processor = get_clip_components(clip_key)
    with torch.no_grad():
        image_batch = clip_processor(
            images=candidate_images,
            return_tensors="pt",
            padding=True,
        )
        image_embeddings = clip_model.get_image_features(**image_batch)
        # L2-normalize so the dot product below is cosine similarity.
        image_embeddings = image_embeddings / image_embeddings.norm(dim=-1, keepdim=True)
        text_batch = clip_processor(
            text=[query_text],
            return_tensors="pt",
            padding=True,
        )
        text_embeddings = clip_model.get_text_features(**text_batch)
        text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
        similarity_scores = image_embeddings @ text_embeddings.T
    best_index = int(similarity_scores.argmax().item())
    best_score = similarity_scores[best_index].item()
    summary = (
        f"Лучшее изображение: #{best_index + 1} "
        f"(схожесть: {best_score:.4f})"
    )
    return summary, candidate_images[best_index]
def segment_image_with_sam_points(
    image_object,
    point_coordinates_list: List[List[int]],
) -> Image.Image:
    """Segment an image with SAM using foreground point prompts.

    Args:
        image_object: Input PIL image.
        point_coordinates_list: [x, y] pixel coordinates of foreground points.

    Returns:
        A binary "L" mode PIL mask (255 inside the segment); an all-black
        mask when no points are given or the model returns no masks.

    Raises:
        ValueError: when *image_object* is None.
    """
    if image_object is None:
        raise ValueError("Изображение не передано в segment_image_with_sam_points")
    if not point_coordinates_list:
        return Image.new("L", image_object.size, color=0)
    sam_model, sam_processor = get_sam_components()
    # One image in the batch; all points labelled 1 (foreground).
    batched_points: List[List[List[int]]] = [point_coordinates_list]
    batched_labels: List[List[int]] = [[1 for _ in point_coordinates_list]]
    # Fix: SamProcessor takes the keyword `images`, not `image` — the
    # original call failed at runtime with the wrong keyword.
    sam_inputs = sam_processor(
        images=image_object,
        input_points=batched_points,
        input_labels=batched_labels,
        return_tensors="pt",
    )
    with torch.no_grad():
        sam_outputs = sam_model(**sam_inputs, multimask_output=True)
    processed_masks_list = sam_processor.image_processor.post_process_masks(
        sam_outputs.pred_masks.squeeze(1).cpu(),
        sam_inputs["original_sizes"].cpu(),
        sam_inputs["reshaped_input_sizes"].cpu(),
    )
    batch_masks_tensor = processed_masks_list[0]
    if batch_masks_tensor.ndim != 3 or batch_masks_tensor.shape[0] == 0:
        return Image.new("L", image_object.size, color=0)
    first_mask_tensor = batch_masks_tensor[0]
    binary_mask_array = (first_mask_tensor.numpy() > 0.5).astype("uint8") * 255
    return Image.fromarray(binary_mask_array, mode="L")
def segment_image_with_sam_points_ui(image_object, coordinates_text: str) -> Image.Image:
    """Gradio wrapper: parse "x,y" pairs (";"- or newline-separated) and run SAM.

    Args:
        image_object: Input PIL image, or None when nothing is uploaded.
        coordinates_text: Point prompts, e.g. "100,150; 200,250".

    Returns:
        The SAM mask; an all-black mask when no valid points were parsed;
        None when no image was provided.
    """
    if image_object is None:
        return None
    # Fix: delegate to the shared parser instead of duplicating its logic.
    # Newlines are normalized to ";" first, preserving the old behavior.
    point_coordinates_list = parse_point_coordinates_text(
        coordinates_text.replace("\n", ";")
    )
    if not point_coordinates_list:
        return Image.new("L", image_object.size, color=0)
    return segment_image_with_sam_points(image_object, point_coordinates_list)
def parse_point_coordinates_text(coordinates_text: str) -> List[List[int]]:
    """Parse "x,y; x,y" style text into a list of [x, y] integer points.

    Generalized: pairs may be separated by ";" or by newlines, matching the
    SAM UI input format (previously newline-separated pairs were dropped).
    Malformed pairs (wrong field count, non-integer values) are skipped.

    Args:
        coordinates_text: Raw user input, e.g. "100, 150; 200,250".

    Returns:
        A list of [x, y] integer pairs; empty for blank input.
    """
    if not coordinates_text.strip():
        return []
    point_list: List[List[int]] = []
    for raw_pair in coordinates_text.replace("\n", ";").split(";"):
        cleaned_pair = raw_pair.strip()
        if not cleaned_pair:
            continue
        coordinate_parts = cleaned_pair.split(",")
        if len(coordinate_parts) != 2:
            continue
        try:
            x_value = int(coordinate_parts[0].strip())
            y_value = int(coordinate_parts[1].strip())
        except ValueError:
            continue
        point_list.append([x_value, y_value])
    return point_list
def build_interface():
    """Assemble the Gradio Blocks UI with one tab per vision task.

    Tabs: object detection, depth estimation, image captioning, visual
    question answering, CLIP zero-shot classification, and text-to-image
    retrieval over an uploaded folder. (A segmentation tab exists but is
    currently commented out.)

    Returns:
        The assembled gr.Blocks application (not launched).
    """
    with gr.Blocks(title="Multimodal AI Demo") as demo_block:
        gr.Markdown("# AI модели")
        # --- Object detection: draw boxes with DETR or YOLOS ---
        with gr.Tab("Детекция объектов"):
            gr.Markdown("## Детекция объектов")
            with gr.Row():
                object_input_image = gr.Image(
                    label="Загрузите изображение",
                    type="pil",
                )
                object_model_selector = gr.Dropdown(
                    choices=[
                        "object_detection_conditional_detr",
                        "object_detection_yolos_small",
                    ],
                    label="Модель",
                    value="object_detection_conditional_detr",
                    info=(
                        "object_detection_conditional_detr - microsoft/conditional-detr-resnet-50\n"
                        "object_detection_yolos_small - hustvl/yolos-small"
                    ),
                )
            object_detect_button = gr.Button("Применить")
            object_output_image = gr.Image(
                label="Результат",
            )
            object_detect_button.click(
                fn=detect_objects_on_image,
                inputs=[object_input_image, object_model_selector],
                outputs=object_output_image,
            )
        # --- Segmentation tab: disabled; kept for reference ---
        ##with gr.Tab("Сегментация"):
        ##    gr.Markdown("## Сегментация")
        ##    with gr.Row():
        ##        segmentation_input_image = gr.Image(
        ##            label="Загрузите изображение",
        ##            type="pil",
        ##        )
        ##    segmentation_button = gr.Button("Применить")
        ##
        ##    segmentation_output_image = gr.Image(
        ##        label="Маска",
        ##    )
        ##
        ##    segmentation_button.click(
        ##        fn=segment_image,
        ##        inputs=segmentation_input_image,
        ##        outputs=segmentation_output_image,
        ##    )
        # --- Depth estimation: grayscale depth map from MiDaS ---
        with gr.Tab("Глубина изображения"):
            gr.Markdown("## Глубина (Depth Estimation)")
            with gr.Row():
                depth_input_image = gr.Image(
                    label="Загрузите изображение",
                    type="pil",
                )
            depth_button = gr.Button("Применить")
            depth_output_image = gr.Image(
                label="Глубины",
            )
            depth_button.click(
                fn=estimate_image_depth,
                inputs=depth_input_image,
                outputs=depth_output_image,
            )
        # --- Image captioning with BLIP (base/large) ---
        with gr.Tab("Описание изображений"):
            gr.Markdown("## Описание изображений")
            with gr.Row():
                caption_input_image = gr.Image(
                    label="Загрузите изображение",
                    type="pil",
                )
                caption_model_selector = gr.Dropdown(
                    choices=[
                        "captioning_blip_base",
                        "captioning_blip_large",
                    ],
                    label="Модель",
                    value="captioning_blip_base",
                    info=(
                        "captioning_blip_base - Salesforce/blip-image-captioning-base (курс)\n"
                        "captioning_blip_large - Salesforce/blip-image-captioning-large"
                    ),
                )
            caption_button = gr.Button("Применить")
            caption_output_text = gr.Textbox(
                label="Описание изображения",
                lines=3,
            )
            caption_button.click(
                fn=generate_image_caption,
                inputs=[caption_input_image, caption_model_selector],
                outputs=caption_output_text,
            )
        # --- Visual question answering (BLIP / ViLT) ---
        with gr.Tab("Вопросы к изображению"):
            gr.Markdown("## Visual Question Answering")
            with gr.Row():
                vqa_input_image = gr.Image(
                    label="Загрузите изображение",
                    type="pil",
                )
                vqa_question_text = gr.Textbox(
                    label="Вопрос",
                    placeholder="Вопрос",
                    lines=2,
                )
                vqa_model_selector = gr.Dropdown(
                    choices=[
                        "vqa_blip_base",
                        "vqa_vilt_b32",
                    ],
                    label="Модель",
                    value="vqa_blip_base",
                    info=(
                        "vqa_blip_base - Salesforce/blip-vqa-base (курс)\n"
                        "vqa_vilt_b32 - dandelin/vilt-b32-finetuned-vqa"
                    ),
                )
            vqa_button = gr.Button("Ответить на вопрос")
            vqa_output_text = gr.Textbox(
                label="Ответ",
                lines=3,
            )
            vqa_button.click(
                fn=answer_visual_question,
                inputs=[vqa_input_image, vqa_question_text, vqa_model_selector],
                outputs=vqa_output_text,
            )
        # --- CLIP zero-shot classification over user-supplied class names ---
        with gr.Tab("Zero-Shot классификация"):
            gr.Markdown("## Zero-Shot классификация")
            with gr.Row():
                zero_shot_input_image = gr.Image(
                    label="Загрузите изображение",
                    type="pil",
                )
                zero_shot_classes_text = gr.Textbox(
                    label="Классы для классификации (через запятую)",
                    placeholder="человек, машина, дерево, здание, животное",
                    lines=2,
                )
                clip_model_selector = gr.Dropdown(
                    choices=[
                        "clip_large_patch14",
                        "clip_base_patch32",
                    ],
                    label="модель",
                    value="clip_large_patch14",
                    info=(
                        "clip_large_patch14 - openai/clip-vit-large-patch14 (курс)\n"
                        "clip_base_patch32 - openai/clip-vit-base-patch32"
                    ),
                )
            zero_shot_button = gr.Button("Применить")
            zero_shot_output_text = gr.Textbox(
                label="Результаты",
                lines=10,
            )
            zero_shot_button.click(
                fn=perform_zero_shot_classification,
                inputs=[zero_shot_input_image, zero_shot_classes_text, clip_model_selector],
                outputs=zero_shot_output_text,
            )
        # --- CLIP text-to-image retrieval over an uploaded directory ---
        with gr.Tab("Поиск изображений в папке"):
            gr.Markdown("## Поиск изображений в папке")
            with gr.Row():
                retrieval_dir = gr.File(
                    label="Загрузите папку с изображениями",
                    file_count="directory",
                    file_types=["image"],
                    type="filepath",
                )
                retrieval_query_text = gr.Textbox(
                    label="Текстовый запрос",
                    placeholder="описание того, что вы ищете...",
                    lines=2,
                )
                retrieval_clip_selector = gr.Dropdown(
                    choices=[
                        "clip_large_patch14",
                        "clip_base_patch32",
                    ],
                    label="модель",
                    value="clip_large_patch14",
                    info=(
                        "clip_large_patch14 - openai/clip-vit-large-patch14 (курс)\n"
                        "clip_base_patch32 - openai/clip-vit-base-patch32 (альтернатива)"
                    ),
                )
            retrieval_button = gr.Button("Поиск")
            retrieval_output_text = gr.Textbox(
                label="Результат",
            )
            retrieval_output_image = gr.Image(
                label="Наиболее подходящее изображение",
            )
            retrieval_button.click(
                fn=retrieve_best_image,
                inputs=[retrieval_dir, retrieval_query_text, retrieval_clip_selector],
                outputs=[retrieval_output_text, retrieval_output_image],
            )
    return demo_block
if __name__ == "__main__":
    # Build and launch the Gradio app; share=True also creates a public
    # gradio.live link in addition to the local server.
    interface_block = build_interface()
    interface_block.launch(share=True)