Spaces:
Running
Running
| import gradio as gr | |
| from transformers import pipeline | |
| import torch | |
| from PIL import Image, ImageDraw, ImageFont | |
| import numpy as np | |
| import functools | |
| import warnings | |
| import time | |
| import inspect | |
| from datetime import datetime | |
| from collections import OrderedDict | |
| warnings.filterwarnings("ignore") | |
| # LRU кэш для хранения загруженных моделей | |
| class LRUCache: | |
| """LRU (Least Recently Used) кэш для ограничения использования памяти""" | |
| def __init__(self, maxsize=5): | |
| """ | |
| Args: | |
| maxsize: Максимальное количество моделей в кэше | |
| """ | |
| self.cache = OrderedDict() | |
| self.maxsize = maxsize | |
| def get(self, key): | |
| """Получить модель из кэша""" | |
| if key not in self.cache: | |
| return None | |
| # Перемещаем элемент в конец (как недавно использованный) | |
| self.cache.move_to_end(key) | |
| return self.cache[key] | |
| def put(self, key, value): | |
| """Добавить модель в кэш""" | |
| if key in self.cache: | |
| # Если ключ уже есть, обновляем и перемещаем в конец | |
| self.cache.move_to_end(key) | |
| self.cache[key] = value | |
| else: | |
| # Если кэш полон, удаляем самый старый элемент (первый в OrderedDict) | |
| if len(self.cache) >= self.maxsize: | |
| oldest_key = next(iter(self.cache)) | |
| # Освобождаем память от модели | |
| old_value = self.cache.pop(oldest_key) | |
| del old_value | |
| # Также очищаем кэш CUDA если используется GPU | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| self.cache[key] = value | |
| def __contains__(self, key): | |
| """Проверка наличия ключа в кэше""" | |
| return key in self.cache | |
| def __getitem__(self, key): | |
| """Получить элемент через []""" | |
| value = self.get(key) | |
| if value is None: | |
| raise KeyError(key) | |
| return value | |
| def __setitem__(self, key, value): | |
| """Установить элемент через []""" | |
| self.put(key, value) | |
| def clear(self): | |
| """Очистить кэш""" | |
| self.cache.clear() | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| def size(self): | |
| """Текущий размер кэша""" | |
| return len(self.cache) | |
| # Создаем LRU кэш с максимальным размером 5 моделей | |
| # Можно изменить это значение в зависимости от доступной памяти | |
| model_cache = LRUCache(maxsize=2) | |
| # История выполнения моделей | |
| history = [] | |
| MAX_HISTORY_SIZE = 50 | |
| def get_pipeline(task, model_name, **kwargs): | |
| """Загрузка pipeline с LRU кэшированием""" | |
| cache_key = f"{task}_{model_name}" | |
| cached_model = model_cache.get(cache_key) | |
| if cached_model is None: | |
| try: | |
| cached_model = pipeline(task, model=model_name, **kwargs) | |
| model_cache.put(cache_key, cached_model) | |
| except Exception as e: | |
| raise Exception(f"Ошибка загрузки модели: {str(e)}") | |
| return cached_model | |
| def measure_time_and_save(task_name): | |
| """Декоратор для измерения времени выполнения и сохранения в историю""" | |
| def decorator(func): | |
| def wrapper(*args, **kwargs): | |
| start_time = time.time() | |
| timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") | |
| # Извлекаем model_name из аргументов | |
| model_name = None | |
| sig = inspect.signature(func) | |
| bound_args = sig.bind(*args, **kwargs) | |
| bound_args.apply_defaults() | |
| if 'model_name' in bound_args.arguments: | |
| model_name = bound_args.arguments['model_name'] | |
| # Создаем краткое описание входных данных | |
| input_preview = "" | |
| if args and len(args) > 0: | |
| first_arg = args[0] | |
| if isinstance(first_arg, str): | |
| input_preview = first_arg[:100] + ("..." if len(first_arg) > 100 else "") | |
| elif isinstance(first_arg, Image.Image): | |
| input_preview = f"Изображение ({first_arg.size[0]}x{first_arg.size[1]})" | |
| elif isinstance(first_arg, (tuple, list)) and len(first_arg) == 2: | |
| # Аудио файл (sample_rate, audio_data) | |
| input_preview = f"Аудио файл" | |
| else: | |
| input_preview = str(type(first_arg).__name__) | |
| # Выполняем функцию | |
| try: | |
| result = func(*args, **kwargs) | |
| if isinstance(result, str): | |
| output = result | |
| elif isinstance(result, tuple) and len(result) == 2: | |
| # Проверяем тип второго элемента | |
| if isinstance(result[1], Image.Image): | |
| # Результат с изображением (текст, изображение) | |
| output = result[0] if isinstance(result[0], str) else str(result[0])[:500] | |
| elif isinstance(result[1], (tuple, list)) and len(result[1]) == 2: | |
| # Аудио результат (sample_rate, audio_data) | |
| output = f"Аудио файл сгенерирован (sample_rate: {result[0]})" | |
| else: | |
| output = str(result)[:500] | |
| else: | |
| output = str(result)[:500] | |
| except Exception as e: | |
| output = f"Ошибка: {str(e)}" | |
| result = output | |
| # Измеряем время выполнения | |
| execution_time = time.time() - start_time | |
| # Сохраняем в историю | |
| history_entry = { | |
| "task_name": task_name, | |
| "model_name": model_name or "Не указана", | |
| "input_preview": input_preview, | |
| "output": output, | |
| "execution_time": round(execution_time, 3), | |
| "timestamp": timestamp | |
| } | |
| history.insert(0, history_entry) # Добавляем в начало | |
| # Ограничиваем размер истории | |
| if len(history) > MAX_HISTORY_SIZE: | |
| history.pop() | |
| return result | |
| return wrapper | |
| return decorator | |
| # ==================== ТЕКСТОВЫЕ ЗАДАЧИ ==================== | |
| def text_classifier(text, model_name): | |
| """Классификация текста""" | |
| try: | |
| classifier = get_pipeline("text-classification", model_name) | |
| result = classifier(text) | |
| if isinstance(result, list): | |
| result = result[0] | |
| return f"Метка: {result['label']}\nУверенность: {result['score']:.4f}" | |
| except Exception as e: | |
| return f"Ошибка: {str(e)}" | |
| def zero_shot_classifier(text, candidate_labels, model_name): | |
| """Zero-shot классификация""" | |
| try: | |
| classifier = get_pipeline("zero-shot-classification", model_name) | |
| labels = [label.strip() for label in candidate_labels.split(",")] | |
| result = classifier(text, labels) | |
| output = "Результаты классификации:\n" | |
| for label, score in zip(result['labels'], result['scores']): | |
| output += f"{label}: {score:.4f}\n" | |
| return output | |
| except Exception as e: | |
| return f"Ошибка: {str(e)}" | |
| def text_generator(prompt, max_length, model_name): | |
| """Генерация текста""" | |
| try: | |
| generator = get_pipeline("text-generation", model_name) | |
| result = generator(prompt, max_length=max_length, num_return_sequences=1, do_sample=True) | |
| return result[0]['generated_text'] | |
| except Exception as e: | |
| return f"Ошибка: {str(e)}" | |
| def text_unmasker(text, model_name): | |
| """Заполнение пропусков в тексте""" | |
| try: | |
| unmasker = get_pipeline("fill-mask", model_name) | |
| result = unmasker(text) | |
| output = "Варианты заполнения:\n" | |
| for i, item in enumerate(result[:5], 1): | |
| output += f"{i}. {item['sequence']} (уверенность: {item['score']:.4f})\n" | |
| return output | |
| except Exception as e: | |
| return f"Ошибка: {str(e)}" | |
| def ner_extractor(text, model_name): | |
| """Извлечение именованных сущностей""" | |
| try: | |
| ner = get_pipeline("ner", model_name, aggregation_strategy="simple") | |
| result = ner(text) | |
| if not result: | |
| return "Именованные сущности не найдены" | |
| output = "Найденные сущности:\n" | |
| for entity in result: | |
| output += f"{entity['word']}: {entity['entity_group']} (уверенность: {entity['score']:.4f})\n" | |
| return output | |
| except Exception as e: | |
| return f"Ошибка: {str(e)}" | |
| def question_answerer(question, context, model_name): | |
| """Ответ на вопрос по контексту""" | |
| try: | |
| qa = get_pipeline("question-answering", model_name) | |
| result = qa(question=question, context=context) | |
| return f"Ответ: {result['answer']}\nУверенность: {result['score']:.4f}" | |
| except Exception as e: | |
| return f"Ошибка: {str(e)}" | |
| def summarizer(text, max_length, min_length, model_name): | |
| """Суммаризация текста""" | |
| try: | |
| summarizer_pipe = get_pipeline("summarization", model_name) | |
| result = summarizer_pipe(text, max_length=max_length, min_length=min_length) | |
| if isinstance(result, list): | |
| result = result[0] | |
| return result['summary_text'] | |
| except Exception as e: | |
| return f"Ошибка: {str(e)}" | |
| def translator(text, model_name, src_lang=None, tgt_lang=None): | |
| """Перевод текста""" | |
| try: | |
| translator_pipe = get_pipeline("translation", model_name) | |
| # Для mBART моделей требуются src_lang и tgt_lang | |
| if "mbart" in model_name.lower(): | |
| if not src_lang or not tgt_lang: | |
| return "Ошибка: Для модели mBART необходимо указать исходный и целевой языки" | |
| result = translator_pipe(text, src_lang=src_lang, tgt_lang=tgt_lang) | |
| else: | |
| result = translator_pipe(text) | |
| if isinstance(result, list): | |
| result = result[0] | |
| return result['translation_text'] | |
| except Exception as e: | |
| return f"Ошибка: {str(e)}" | |
| # ==================== АУДИО ЗАДАЧИ ==================== | |
| def audio_classifier(audio, model_name): | |
| """Классификация аудио""" | |
| try: | |
| classifier = get_pipeline("audio-classification", model_name) | |
| result = classifier(audio) | |
| # audio-classification pipeline возвращает список словарей | |
| if not isinstance(result, list): | |
| result = [result] | |
| output = "Результаты классификации:\n" | |
| for item in result[:5]: | |
| output += f"{item['label']}: {item['score']:.4f}\n" | |
| return output | |
| except Exception as e: | |
| return f"Ошибка: {str(e)}" | |
| def audio_zero_shot_classifier(audio, candidate_labels, model_name): | |
| """Zero-shot классификация аудио""" | |
| try: | |
| classifier = get_pipeline("zero-shot-audio-classification", model_name) | |
| labels = [label.strip() for label in candidate_labels.split(",")] | |
| result = classifier(audio, candidate_labels=labels) | |
| output = "Результаты классификации:\n" | |
| # Парсинг результата: [{'score': ..., 'label': ...}, ...] | |
| if isinstance(result, list): | |
| for item in result: | |
| if isinstance(item, dict) and 'label' in item and 'score' in item: | |
| output += f"{item['label']}: {item['score']:.4f}\n" | |
| else: | |
| return f"Ошибка: Неожиданный формат результата от pipeline: {type(result)}. Ожидался список словарей с ключами 'label' и 'score'." | |
| return output | |
| except Exception as e: | |
| error_msg = str(e) | |
| if "Could not load model" in error_msg or "Unrecognized" in error_msg: | |
| return f"Ошибка: Модель '{model_name}' не поддерживается для zero-shot классификации аудио. Попробуйте другую модель, например 'laion/clap-htsat-unfused'." | |
| return f"Ошибка: {error_msg}" | |
| def speech_recognition(audio, model_name): | |
| """Распознавание речи""" | |
| try: | |
| asr = get_pipeline("automatic-speech-recognition", model_name) | |
| result = asr(audio) | |
| if isinstance(result, list): | |
| result = result[0] | |
| return result['text'] | |
| except Exception as e: | |
| return f"Ошибка: {str(e)}" | |
| def speech_synthesis(text, model_name): | |
| """Синтез речи""" | |
| try: | |
| import numpy as np | |
| import torch | |
| # Проверяем, что текст не пустой | |
| if not text or not text.strip(): | |
| raise ValueError("Текст для синтеза не может быть пустым") | |
| # Проверяем, является ли модель SpeechT5 | |
| if "speecht5" in model_name.lower(): | |
| try: | |
| # Для SpeechT5 нужны speaker_embeddings | |
| # Пробуем использовать pipeline с forward_params | |
| try: | |
| tts = get_pipeline("text-to-speech", model_name) | |
| # Пытаемся загрузить предобученные speaker embeddings из датасета | |
| try: | |
| from datasets import load_dataset | |
| embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation") | |
| speaker_embedding = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0) | |
| except Exception: | |
| # Если не удалось загрузить, создаем случайный speaker embedding | |
| # Сначала загружаем модель чтобы узнать размерность | |
| from transformers import SpeechT5ForTextToSpeech | |
| temp_model = SpeechT5ForTextToSpeech.from_pretrained(model_name) | |
| speaker_embedding_dim = temp_model.config.speaker_embedding_dim | |
| del temp_model | |
| # Создаем случайный speaker embedding | |
| speaker_embedding = torch.randn(1, speaker_embedding_dim) | |
| speaker_embedding = speaker_embedding / torch.norm(speaker_embedding, dim=1, keepdim=True) | |
| # Используем pipeline с forward_params | |
| result = tts(text, forward_params={"speaker_embeddings": speaker_embedding}) | |
| # Обрабатываем результат | |
| if isinstance(result, dict): | |
| audio_data = result.get("audio", result.get("raw", None)) | |
| sample_rate = result.get("sampling_rate", result.get("sample_rate", 16000)) | |
| if audio_data is None: | |
| raise ValueError("Не удалось извлечь аудио данные из результата pipeline") | |
| # Конвертируем в numpy array если нужно | |
| if isinstance(audio_data, torch.Tensor): | |
| audio_data = audio_data.numpy() | |
| elif not isinstance(audio_data, np.ndarray): | |
| audio_data = np.array(audio_data) | |
| # Убеждаемся, что это 1D массив | |
| if len(audio_data.shape) > 1: | |
| audio_data = audio_data.flatten() | |
| # Нормализуем в float32 | |
| if audio_data.dtype != np.float32: | |
| audio_data = audio_data.astype(np.float32) | |
| # Нормализуем если значения выходят за пределы [-1, 1] | |
| max_val = np.abs(audio_data).max() | |
| if max_val > 1.0: | |
| audio_data = audio_data / max_val | |
| return (sample_rate, audio_data) | |
| else: | |
| raise ValueError(f"Неожиданный формат результата от pipeline: {type(result)}") | |
| except Exception as pipeline_error: | |
| # Если pipeline не работает, пробуем напрямую через модель | |
| from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan | |
| cache_key = f"tts_speecht5_{model_name}" | |
| cached = model_cache.get(cache_key) | |
| if cached is None: | |
| processor = SpeechT5Processor.from_pretrained(model_name) | |
| model = SpeechT5ForTextToSpeech.from_pretrained(model_name) | |
| vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan") | |
| # Пытаемся загрузить предобученные speaker embeddings | |
| try: | |
| from datasets import load_dataset | |
| embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation") | |
| speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0) | |
| except Exception: | |
| # Если не удалось, создаем случайный | |
| speaker_embedding_dim = model.config.speaker_embedding_dim | |
| speaker_embeddings = torch.randn(1, speaker_embedding_dim) | |
| speaker_embeddings = speaker_embeddings / torch.norm(speaker_embeddings, dim=1, keepdim=True) | |
| cached = (processor, model, vocoder, speaker_embeddings) | |
| model_cache.put(cache_key, cached) | |
| processor, model, vocoder, speaker_embeddings = cached | |
| inputs = processor(text=text, return_tensors="pt") | |
| # Убеждаемся, что speaker_embeddings имеют правильную форму и тип | |
| if not isinstance(speaker_embeddings, torch.Tensor): | |
| speaker_embeddings = torch.tensor(speaker_embeddings) | |
| if speaker_embeddings.device != inputs["input_ids"].device: | |
| speaker_embeddings = speaker_embeddings.to(inputs["input_ids"].device) | |
| with torch.no_grad(): | |
| # Вызываем generate_speech с именованным параметром | |
| speech = model.generate_speech( | |
| inputs["input_ids"], | |
| speaker_embeddings=speaker_embeddings, | |
| vocoder=vocoder | |
| ) | |
| # Конвертируем в numpy и нормализуем | |
| audio_data = speech.numpy() | |
| # Убеждаемся, что это 1D массив | |
| if len(audio_data.shape) > 1: | |
| audio_data = audio_data.flatten() | |
| # Нормализуем в диапазон [-1, 1] если нужно | |
| if audio_data.dtype != np.float32: | |
| audio_data = audio_data.astype(np.float32) | |
| # Нормализуем если значения выходят за пределы [-1, 1] | |
| max_val = np.abs(audio_data).max() | |
| if max_val > 1.0: | |
| audio_data = audio_data / max_val | |
| sample_rate = 16000 | |
| return (sample_rate, audio_data) | |
| except Exception as e: | |
| error_msg = str(e) | |
| if "ImportError" in str(type(e)) or "ModuleNotFoundError" in str(type(e)): | |
| raise Exception(f"Ошибка: Не удалось импортировать необходимые модули для SpeechT5. Убедитесь, что transformers установлен: {error_msg}") | |
| raise Exception(f"Ошибка синтеза речи с SpeechT5: {error_msg}") | |
| # Используем стандартный pipeline для других моделей | |
| tts = get_pipeline("text-to-speech", model_name) | |
| result = tts(text) | |
| # Pipeline может возвращать словарь или кортеж | |
| if isinstance(result, dict): | |
| # Стандартный формат: {"audio": array, "sampling_rate": int} | |
| audio_data = result.get("audio", result.get("raw", None)) | |
| sample_rate = result.get("sampling_rate", result.get("sample_rate", 22050)) | |
| if audio_data is None: | |
| raise ValueError("Не удалось извлечь аудио данные из результата pipeline") | |
| # Конвертируем в numpy array если нужно | |
| if isinstance(audio_data, torch.Tensor): | |
| audio_data = audio_data.numpy() | |
| elif not isinstance(audio_data, np.ndarray): | |
| audio_data = np.array(audio_data) | |
| # Убеждаемся, что это 1D массив | |
| if len(audio_data.shape) > 1: | |
| audio_data = audio_data.flatten() | |
| # Нормализуем в float32 | |
| if audio_data.dtype != np.float32: | |
| audio_data = audio_data.astype(np.float32) | |
| # Нормализуем если значения выходят за пределы [-1, 1] | |
| max_val = np.abs(audio_data).max() | |
| if max_val > 1.0: | |
| audio_data = audio_data / max_val | |
| return (sample_rate, audio_data) | |
| elif isinstance(result, tuple) and len(result) == 2: | |
| # Уже в правильном формате (sample_rate, audio_data) | |
| sample_rate, audio_data = result | |
| # Конвертируем в numpy если нужно | |
| if isinstance(audio_data, torch.Tensor): | |
| audio_data = audio_data.numpy() | |
| elif not isinstance(audio_data, np.ndarray): | |
| audio_data = np.array(audio_data) | |
| # Убеждаемся, что это 1D массив | |
| if len(audio_data.shape) > 1: | |
| audio_data = audio_data.flatten() | |
| # Нормализуем в float32 | |
| if audio_data.dtype != np.float32: | |
| audio_data = audio_data.astype(np.float32) | |
| # Нормализуем если значения выходят за пределы [-1, 1] | |
| max_val = np.abs(audio_data).max() | |
| if max_val > 1.0: | |
| audio_data = audio_data / max_val | |
| return (sample_rate, audio_data) | |
| else: | |
| raise ValueError(f"Неожиданный формат результата от pipeline: {type(result)}. Ожидался словарь с ключами 'audio' и 'sampling_rate' или кортеж (sample_rate, audio_data).") | |
| except Exception as e: | |
| error_msg = str(e) | |
| if "speaker_embeddings" in error_msg.lower(): | |
| if "speecht5" in model_name.lower(): | |
| return f"Ошибка: Модель SpeechT5 требует speaker_embeddings. Они должны генерироваться автоматически, но произошла ошибка: {error_msg}" | |
| return f"Ошибка: Модель '{model_name}' требует speaker_embeddings. Для SpeechT5 они генерируются автоматически, но для других моделей может потребоваться дополнительная настройка." | |
| if "does not appear to have a file named" in error_msg or "Unrecognized model" in error_msg: | |
| return f"Ошибка: Модель '{model_name}' не поддерживается библиотекой transformers для синтеза речи. Попробуйте использовать модель 'microsoft/speecht5_tts'." | |
| if "negative output size" in error_msg.lower() or "input size 0" in error_msg.lower(): | |
| return f"Ошибка: Проблема с обработкой текста моделью '{model_name}'. Возможные причины: неподдерживаемый язык, пустой текст после обработки, или проблема с токенизацией. Попробуйте использовать другой текст или модель." | |
| raise Exception(f"Ошибка синтеза речи: {error_msg}") | |
| # ==================== ЗАДАЧИ С ИЗОБРАЖЕНИЯМИ ==================== | |
| def object_detection(image, model_name): | |
| """Обнаружение объектов на изображении""" | |
| try: | |
| detector = get_pipeline("object-detection", model_name) | |
| result = detector(image) | |
| # Создаем копию изображения для визуализации | |
| img_with_boxes = image.copy() | |
| draw = ImageDraw.Draw(img_with_boxes) | |
| # Цвета для разных объектов | |
| colors = ['red', 'blue', 'green', 'yellow', 'orange', 'purple', 'cyan', 'magenta'] | |
| output = "Обнаруженные объекты:\n" | |
| for i, item in enumerate(result): | |
| box = item['box'] | |
| label = item['label'] | |
| score = item['score'] | |
| # Обрабатываем различные форматы координат | |
| if isinstance(box, dict): | |
| # Словарь с ключами 'xmin', 'ymin', 'xmax', 'ymax' | |
| xmin = box.get('xmin', box.get('x1', 0)) | |
| ymin = box.get('ymin', box.get('y1', 0)) | |
| xmax = box.get('xmax', box.get('x2', 0)) | |
| ymax = box.get('ymax', box.get('y2', 0)) | |
| elif isinstance(box, (list, tuple)) and len(box) >= 4: | |
| # Список [xmin, ymin, xmax, ymax] или [xcenter, ycenter, width, height] | |
| if box[2] > box[0] and box[3] > box[1]: | |
| # Вероятно [xmin, ymin, xmax, ymax] | |
| xmin, ymin, xmax, ymax = box[0], box[1], box[2], box[3] | |
| else: | |
| # Вероятно [xcenter, ycenter, width, height] | |
| xcenter, ycenter, width, height = box[0], box[1], box[2], box[3] | |
| xmin = xcenter - width / 2 | |
| ymin = ycenter - height / 2 | |
| xmax = xcenter + width / 2 | |
| ymax = ycenter + height / 2 | |
| else: | |
| # Неизвестный формат, пропускаем | |
| output += f"{label}: уверенность {score:.4f}, координаты {box}\n" | |
| continue | |
| # Проверяем и ограничиваем координаты границами изображения | |
| img_width, img_height = img_with_boxes.size | |
| xmin = max(0, min(xmin, img_width)) | |
| ymin = max(0, min(ymin, img_height)) | |
| xmax = max(0, min(xmax, img_width)) | |
| ymax = max(0, min(ymax, img_height)) | |
| # Проверяем, что координаты валидны | |
| if xmax <= xmin or ymax <= ymin: | |
| output += f"{label}: уверенность {score:.4f}, координаты {box} (некорректные)\n" | |
| continue | |
| # Рисуем прямоугольник | |
| color = colors[i % len(colors)] | |
| draw.rectangle([xmin, ymin, xmax, ymax], outline=color, width=3) | |
| # Добавляем текст с меткой и уверенностью | |
| text = f"{label}: {score:.2f}" | |
| try: | |
| # Пытаемся использовать системный шрифт | |
| font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", 16) | |
| except: | |
| try: | |
| font = ImageFont.load_default() | |
| except: | |
| font = None | |
| # Получаем размер текста | |
| if font: | |
| bbox = draw.textbbox((0, 0), text, font=font) | |
| text_width = bbox[2] - bbox[0] | |
| text_height = bbox[3] - bbox[1] | |
| else: | |
| text_width = len(text) * 6 | |
| text_height = 12 | |
| # Рисуем фон для текста (проверяем границы) | |
| text_y = max(0, ymin - text_height - 4) | |
| text_x_end = min(img_width, xmin + text_width + 4) | |
| draw.rectangle([xmin, text_y, text_x_end, ymin], fill=color) | |
| draw.text((xmin + 2, text_y + 2), text, fill='white', font=font) | |
| output += f"{label}: уверенность {score:.4f}, координаты {box}\n" | |
| return output, img_with_boxes | |
| except Exception as e: | |
| return f"Ошибка: {str(e)}", image | |
| def image_segmentation(image, model_name): | |
| """Сегментация изображения""" | |
| try: | |
| segmenter = get_pipeline("image-segmentation", model_name) | |
| result = segmenter(image) | |
| # Создаем копию изображения для визуализации | |
| img_with_segments = image.copy().convert("RGBA") | |
| # Генерируем цвета для сегментов | |
| np.random.seed(42) # Для воспроизводимости | |
| output = "Сегменты:\n" | |
| overlay = Image.new("RGBA", image.size, (0, 0, 0, 0)) | |
| draw = ImageDraw.Draw(overlay) | |
| # Список для хранения информации о сегментах (для добавления текста) | |
| segments_info = [] | |
| for i, item in enumerate(result): | |
| label = item['label'] | |
| score = item['score'] | |
| # Генерируем полупрозрачный цвет для сегмента | |
| color = tuple(np.random.randint(0, 255, 3)) + (128,) # RGBA с прозрачностью | |
| color_rgb = color[:3] # RGB цвет для текста | |
| # Проверяем наличие маски | |
| if 'mask' in item: | |
| mask = item['mask'] | |
| # Преобразуем маску в numpy array | |
| if isinstance(mask, Image.Image): | |
| mask_array = np.array(mask) | |
| elif isinstance(mask, np.ndarray): | |
| mask_array = mask | |
| else: | |
| mask_array = np.array(mask) | |
| # Нормализуем маску, если нужно | |
| if mask_array.dtype != np.uint8: | |
| if mask_array.max() <= 1.0: | |
| mask_array = (mask_array * 255).astype(np.uint8) | |
| else: | |
| mask_array = mask_array.astype(np.uint8) | |
| # Находим центр маски для размещения текста | |
| if len(mask_array.shape) == 2: # Grayscale mask | |
| mask_bool = mask_array > 0 | |
| elif len(mask_array.shape) == 3 and mask_array.shape[2] == 1: | |
| mask_bool = mask_array[:, :, 0] > 0 | |
| else: | |
| if mask_array.shape[2] >= 1: | |
| mask_bool = mask_array[:, :, 0] > 0 | |
| else: | |
| mask_bool = np.zeros(mask_array.shape[:2], dtype=bool) | |
| # Вычисляем центр маски | |
| if np.any(mask_bool): | |
| y_coords, x_coords = np.where(mask_bool) | |
| if len(y_coords) > 0 and len(x_coords) > 0: | |
| center_y = int(np.mean(y_coords)) | |
| center_x = int(np.mean(x_coords)) | |
| # Масштабируем координаты, если маска другого размера | |
| if mask_array.shape[:2] != image.size[::-1]: | |
| scale_y = image.size[1] / mask_array.shape[0] | |
| scale_x = image.size[0] / mask_array.shape[1] | |
| center_y = int(center_y * scale_y) | |
| center_x = int(center_x * scale_x) | |
| segments_info.append({ | |
| 'label': label, | |
| 'score': score, | |
| 'center': (center_x, center_y), | |
| 'color': color_rgb | |
| }) | |
| # Создаем цветную маску | |
| if len(mask_array.shape) == 2: # Grayscale mask | |
| # Создаем RGBA маску | |
| colored_mask = np.zeros((mask_array.shape[0], mask_array.shape[1], 4), dtype=np.uint8) | |
| # Применяем цвет только там, где маска не равна нулю | |
| mask_bool = mask_array > 0 | |
| colored_mask[mask_bool, :3] = color[:3] | |
| colored_mask[mask_bool, 3] = 128 # Альфа-канал | |
| elif len(mask_array.shape) == 3 and mask_array.shape[2] == 1: | |
| # Маска с одним каналом | |
| colored_mask = np.zeros((mask_array.shape[0], mask_array.shape[1], 4), dtype=np.uint8) | |
| mask_bool = mask_array[:, :, 0] > 0 | |
| colored_mask[mask_bool, :3] = color[:3] | |
| colored_mask[mask_bool, 3] = 128 | |
| else: | |
| # Многоканальная маска | |
| colored_mask = np.zeros((mask_array.shape[0], mask_array.shape[1], 4), dtype=np.uint8) | |
| # Используем первый канал как маску | |
| if mask_array.shape[2] >= 1: | |
| mask_bool = mask_array[:, :, 0] > 0 | |
| colored_mask[mask_bool, :3] = color[:3] | |
| colored_mask[mask_bool, 3] = 128 | |
| # Убеждаемся, что размеры совпадают | |
| if colored_mask.shape[:2] == img_with_segments.size[::-1]: | |
| mask_img = Image.fromarray(colored_mask, mode='RGBA') | |
| overlay = Image.alpha_composite(overlay, mask_img) | |
| elif colored_mask.shape[:2] != overlay.size[::-1]: | |
| # Изменяем размер маски, если нужно | |
| mask_img = Image.fromarray(colored_mask, mode='RGBA') | |
| mask_img = mask_img.resize(overlay.size, Image.Resampling.LANCZOS) | |
| overlay = Image.alpha_composite(overlay, mask_img) | |
| output += f"{label}: уверенность {score:.4f}\n" | |
| # Накладываем overlay на исходное изображение | |
| if overlay.size == img_with_segments.size: | |
| img_with_segments = Image.alpha_composite(img_with_segments, overlay) | |
| # Добавляем текстовые метки с цветами на изображение | |
| draw_final = ImageDraw.Draw(img_with_segments) | |
| # Загружаем шрифт | |
| try: | |
| font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", 18) | |
| except: | |
| try: | |
| font = ImageFont.load_default() | |
| except: | |
| font = None | |
| for seg_info in segments_info: | |
| label = seg_info['label'] | |
| score = seg_info['score'] | |
| center_x, center_y = seg_info['center'] | |
| color_rgb = seg_info['color'] | |
| # Формируем текст метки | |
| text = f"{label}: {score:.2f}" | |
| # Получаем размер текста | |
| if font: | |
| bbox = draw_final.textbbox((0, 0), text, font=font) | |
| text_width = bbox[2] - bbox[0] | |
| text_height = bbox[3] - bbox[1] | |
| else: | |
| text_width = len(text) * 7 | |
| text_height = 14 | |
| # Вычисляем позицию текста (центрируем относительно центра сегмента) | |
| text_x = center_x - text_width // 2 | |
| text_y = center_y - text_height // 2 | |
| # Ограничиваем координаты границами изображения | |
| img_width, img_height = img_with_segments.size | |
| text_x = max(2, min(text_x, img_width - text_width - 2)) | |
| text_y = max(2, min(text_y, img_height - text_height - 2)) | |
| # Рисуем фон для текста (полупрозрачный черный для читаемости) | |
| padding = 4 | |
| draw_final.rectangle( | |
| [text_x - padding, text_y - padding, | |
| text_x + text_width + padding, text_y + text_height + padding], | |
| fill=(0, 0, 0, 180) # Полупрозрачный черный фон | |
| ) | |
| # Рисуем текст цветом сегмента | |
| draw_final.text( | |
| (text_x, text_y), | |
| text, | |
| fill=color_rgb + (255,), # RGB + альфа для RGBA | |
| font=font | |
| ) | |
| # Конвертируем обратно в RGB для отображения | |
| img_with_segments = img_with_segments.convert("RGB") | |
| return output, img_with_segments | |
| except Exception as e: | |
| return f"Ошибка: {str(e)}", image | |
| def image_captioning(image, model_name): | |
| """Описание изображения""" | |
| try: | |
| captioner = get_pipeline("image-to-text", model_name) | |
| result = captioner(image) | |
| if isinstance(result, list): | |
| result = result[0] | |
| return result['generated_text'] | |
| except Exception as e: | |
| return f"Ошибка: {str(e)}" | |
| def visual_qa(image, question, model_name): | |
| """Визуальный вопрос-ответ""" | |
| try: | |
| vqa = get_pipeline("visual-question-answering", model_name) | |
| result = vqa(image=image, question=question) | |
| if isinstance(result, list): | |
| result = result[0] | |
| return f"Ответ: {result['answer']}" | |
| except Exception as e: | |
| return f"Ошибка: {str(e)}" | |
| def image_zero_shot_classification(image, candidate_labels, model_name): | |
| """Zero-shot классификация изображений""" | |
| try: | |
| labels = [label.strip() for label in candidate_labels.split(",")] | |
| # Проверяем, является ли модель LAION (проверяем ДО вызова get_pipeline) | |
| model_name_lower = model_name.lower() | |
| if "laion/" in model_name_lower or "laion5b" in model_name_lower or "laion" in model_name_lower: | |
| # Используем OpenCLIP для LAION моделей | |
| try: | |
| import open_clip | |
| except ImportError: | |
| return f"Ошибка: Для работы с LAION моделями требуется библиотека open-clip-torch. Установите её: pip install open-clip-torch" | |
| cache_key = f"clip_laion_{model_name}" | |
| cached = model_cache.get(cache_key) | |
| if cached is None: | |
| # Определяем имя модели и веса для OpenCLIP | |
| if "xlm-roberta-base-ViT-B-32" in model_name or "xlm-roberta-base" in model_name: | |
| clip_model_name = "xlm-roberta-base-ViT-B-32" | |
| pretrained = "laion5b_s13b_b90k" | |
| else: | |
| # Пытаемся извлечь информацию из имени модели | |
| clip_model_name = "xlm-roberta-base-ViT-B-32" | |
| pretrained = "laion5b_s13b_b90k" | |
| model, _, preprocess = open_clip.create_model_and_transforms( | |
| clip_model_name, | |
| pretrained=pretrained | |
| ) | |
| tokenizer = open_clip.get_tokenizer(clip_model_name) | |
| model.eval() | |
| cached = (model, preprocess, tokenizer) | |
| model_cache.put(cache_key, cached) | |
| model, preprocess, tokenizer = cached | |
| # Обрабатываем изображение и тексты | |
| image_tensor = preprocess(image).unsqueeze(0) | |
| text_tokens = tokenizer(labels) | |
| with torch.no_grad(): | |
| image_features = model.encode_image(image_tensor) | |
| text_features = model.encode_text(text_tokens) | |
| # Нормализуем признаки | |
| image_features = image_features / image_features.norm(dim=-1, keepdim=True) | |
| text_features = text_features / text_features.norm(dim=-1, keepdim=True) | |
| # Вычисляем косинусное сходство (логиты) | |
| logits_per_image = (image_features @ text_features.T) * 100 # Масштабируем для лучшей точности | |
| probs = logits_per_image.softmax(dim=1) | |
| output = "Результаты классификации:\n" | |
| for label, prob in zip(labels, probs[0]): | |
| output += f"{label}: {prob.item():.4f}\n" | |
| return output | |
| else: | |
| # Используем стандартный pipeline | |
| classifier = get_pipeline("zero-shot-image-classification", model_name) | |
| result = classifier(image, candidate_labels=labels) | |
| output = "Результаты классификации:\n" | |
| # Парсим результат | |
| if isinstance(result, list): | |
| # Результат - список словарей с 'score' и 'label' | |
| for item in result: | |
| if isinstance(item, dict) and 'label' in item and 'score' in item: | |
| output += f"{item['label']}: {item['score']:.4f}\n" | |
| else: | |
| return f"Ошибка: Неожиданный формат элемента в результате: {item}" | |
| elif isinstance(result, dict): | |
| # Результат - словарь с 'labels' и 'scores' | |
| if 'labels' in result and 'scores' in result: | |
| for label, score in zip(result['labels'], result['scores']): | |
| output += f"{label}: {score:.4f}\n" | |
| else: | |
| return f"Ошибка: Неожиданный формат результата от pipeline: {result}. Ожидался словарь с ключами 'labels' и 'scores' или список словарей." | |
| else: | |
| return f"Ошибка: Неожиданный формат результата от pipeline: {type(result)}. Ожидался словарь с ключами 'labels' и 'scores' или список словарей. Получен: {result}" | |
| return output | |
| except Exception as e: | |
| error_msg = str(e) | |
| if "Could not load model" in error_msg or "Unrecognized" in error_msg: | |
| if "laion" in model_name.lower(): | |
| return f"Ошибка: Модель '{model_name}' требует библиотеку open-clip-torch. Убедитесь, что она установлена: pip install open-clip-torch" | |
| return f"Ошибка: Модель '{model_name}' не поддерживается для zero-shot классификации изображений. Попробуйте другую модель, например 'openai/clip-vit-base-patch32'." | |
| if "open_clip" in error_msg or "open-clip" in error_msg or "ModuleNotFoundError" in str(type(e)): | |
| return f"Ошибка: Для работы с LAION моделями требуется библиотека open-clip-torch. Установите её: pip install open-clip-torch" | |
| return f"Ошибка: {error_msg}" | |
| # ==================== ФУНКЦИИ ДЛЯ ИСТОРИИ ==================== | |
| def get_history_display(): | |
| """Форматирует историю для отображения в таблице""" | |
| if not history: | |
| return [] | |
| # Форматируем данные для таблицы | |
| display_data = [] | |
| for entry in history: | |
| # Обрезаем длинные результаты для отображения | |
| output_preview = entry['output'] | |
| if len(output_preview) > 100: | |
| output_preview = output_preview[:100] + "..." | |
| display_data.append([ | |
| entry['task_name'], | |
| entry['model_name'], | |
| f"{entry['execution_time']} сек", | |
| entry['timestamp'], | |
| entry['input_preview'], | |
| output_preview | |
| ]) | |
| return display_data | |
| def clear_history(): | |
| """Очищает историю""" | |
| global history | |
| history.clear() | |
| return get_history_display(), [] | |
| # ==================== GRADIO ИНТЕРФЕЙС ==================== | |
| with gr.Blocks(title="Трансформеры Hugging Face", theme=gr.themes.Soft()) as demo: | |
| gr.Markdown("# 🤖 Трансформеры Hugging Face") | |
| gr.Markdown("Выберите вкладку для работы с различными типами трансформеров") | |
| with gr.Tabs(): | |
| # ========== ВКЛАДКА 1: ТЕКСТОВЫЕ ТРАНСФОРМЕРЫ ========== | |
| with gr.Tab("📝 Текстовые трансформеры"): | |
| gr.Markdown("## Работа с текстовыми моделями") | |
| with gr.Accordion("Классификатор", open=True): | |
| with gr.Row(): | |
| with gr.Column(): | |
| text_classifier_input = gr.Textbox( | |
| label="Введите текст для классификации", | |
| placeholder="I love this app!", | |
| value="I love this app!" | |
| ) | |
| text_classifier_model = gr.Dropdown( | |
| choices=[ | |
| "distilbert-base-uncased-finetuned-sst-2-english", | |
| "cardiffnlp/twitter-roberta-base-sentiment-latest", | |
| "nlptown/bert-base-multilingual-uncased-sentiment" | |
| ], | |
| value="distilbert-base-uncased-finetuned-sst-2-english", | |
| label="Выберите модель" | |
| ) | |
| text_classifier_btn = gr.Button("Классифицировать", variant="primary") | |
| with gr.Column(): | |
| text_classifier_output = gr.Textbox(label="Результат", lines=8) | |
| text_classifier_btn.click( | |
| text_classifier, | |
| inputs=[text_classifier_input, text_classifier_model], | |
| outputs=text_classifier_output | |
| ) | |
| with gr.Accordion("Zero-shot классификатор", open=False): | |
| with gr.Row(): | |
| with gr.Column(): | |
| zs_text_input = gr.Textbox( | |
| label="Введите текст", | |
| placeholder="I just finished reading a great book", | |
| value="I just finished reading a great book" | |
| ) | |
| zs_text_labels = gr.Textbox( | |
| label="Кандидаты (через запятую)", | |
| placeholder="positive, negative, neutral", | |
| value="positive, negative, neutral" | |
| ) | |
| zs_text_model = gr.Dropdown( | |
| choices=[ | |
| "facebook/bart-large-mnli", | |
| "typeform/distilbert-base-uncased-mnli", | |
| "valhalla/distilbart-mnli-12-3" | |
| ], | |
| value="facebook/bart-large-mnli", | |
| label="Выберите модель" | |
| ) | |
| zs_text_btn = gr.Button("Классифицировать", variant="primary") | |
| with gr.Column(): | |
| zs_text_output = gr.Textbox(label="Результат", lines=8) | |
| zs_text_btn.click( | |
| zero_shot_classifier, | |
| inputs=[zs_text_input, zs_text_labels, zs_text_model], | |
| outputs=zs_text_output | |
| ) | |
| with gr.Accordion("Генератор текста", open=False): | |
| with gr.Row(): | |
| with gr.Column(): | |
| text_gen_input = gr.Textbox( | |
| label="Промпт", | |
| placeholder="In the distant future", | |
| value="In the distant future" | |
| ) | |
| text_gen_length = gr.Slider(20, 200, value=50, step=10, label="Максимальная длина") | |
| text_gen_model = gr.Dropdown( | |
| choices=["gpt2", "distilgpt2", "EleutherAI/gpt-neo-125M"], | |
| value="gpt2", | |
| label="Выберите модель" | |
| ) | |
| text_gen_btn = gr.Button("Сгенерировать", variant="primary") | |
| with gr.Column(): | |
| text_gen_output = gr.Textbox(label="Сгенерированный текст", lines=12) | |
| text_gen_btn.click( | |
| text_generator, | |
| inputs=[text_gen_input, text_gen_length, text_gen_model], | |
| outputs=text_gen_output | |
| ) | |
| with gr.Accordion("Unmasker (заполнение пропусков)", open=False): | |
| with gr.Row(): | |
| with gr.Column(): | |
| unmasker_input = gr.Textbox( | |
| label="Текст с [MASK]", | |
| placeholder="I love [MASK] programming", | |
| value="I love [MASK] programming" | |
| ) | |
| unmasker_model = gr.Dropdown( | |
| choices=[ | |
| "bert-base-uncased", | |
| "distilbert-base-uncased", | |
| "bert-base-multilingual-uncased" | |
| ], | |
| value="bert-base-uncased", | |
| label="Выберите модель" | |
| ) | |
| unmasker_btn = gr.Button("Заполнить", variant="primary") | |
| with gr.Column(): | |
| unmasker_output = gr.Textbox(label="Результат", lines=10) | |
| unmasker_btn.click( | |
| text_unmasker, | |
| inputs=[unmasker_input, unmasker_model], | |
| outputs=unmasker_output | |
| ) | |
| with gr.Accordion("NER (Извлечение именованных сущностей)", open=False): | |
| with gr.Row(): | |
| with gr.Column(): | |
| ner_input = gr.Textbox( | |
| label="Введите текст", | |
| placeholder="My name is John, I work at Microsoft in Seattle", | |
| value="My name is John, I work at Microsoft in Seattle" | |
| ) | |
| ner_model = gr.Dropdown( | |
| choices=[ | |
| "dslim/bert-base-NER", | |
| "dbmdz/bert-large-cased-finetuned-conll03-english", | |
| "Jean-Baptiste/roberta-large-ner-english" | |
| ], | |
| value="dslim/bert-base-NER", | |
| label="Выберите модель" | |
| ) | |
| ner_btn = gr.Button("Извлечь сущности", variant="primary") | |
| with gr.Column(): | |
| ner_output = gr.Textbox(label="Найденные сущности", lines=10) | |
| ner_btn.click( | |
| ner_extractor, | |
| inputs=[ner_input, ner_model], | |
| outputs=ner_output | |
| ) | |
| with gr.Accordion("Question Answering", open=False): | |
| with gr.Row(): | |
| with gr.Column(): | |
| qa_question = gr.Textbox( | |
| label="Вопрос", | |
| placeholder="What color is the sky?", | |
| value="What color is the sky?" | |
| ) | |
| qa_context = gr.Textbox( | |
| label="Контекст", | |
| placeholder="The sky is blue due to light scattering", | |
| value="The sky is blue due to light scattering in the atmosphere", | |
| lines=3 | |
| ) | |
| qa_model = gr.Dropdown( | |
| choices=[ | |
| "distilbert-base-uncased-distilled-squad", | |
| "deepset/roberta-base-squad2", | |
| "bert-large-uncased-whole-word-masking-finetuned-squad" | |
| ], | |
| value="distilbert-base-uncased-distilled-squad", | |
| label="Выберите модель" | |
| ) | |
| qa_btn = gr.Button("Ответить", variant="primary") | |
| with gr.Column(): | |
| qa_output = gr.Textbox(label="Ответ", lines=8) | |
| qa_btn.click( | |
| question_answerer, | |
| inputs=[qa_question, qa_context, qa_model], | |
| outputs=qa_output | |
| ) | |
| with gr.Accordion("Суммаризатор", open=False): | |
| with gr.Row(): | |
| with gr.Column(): | |
| summarizer_input = gr.Textbox( | |
| label="Текст для суммаризации", | |
| placeholder="Enter a long text...", | |
| value="Artificial intelligence is a field of computer science that focuses on creating intelligent machines. Machine learning is a subset of artificial intelligence that enables systems to automatically learn and improve from experience.", | |
| lines=5 | |
| ) | |
| summarizer_max = gr.Slider(20, 200, value=50, step=10, label="Максимальная длина") | |
| summarizer_min = gr.Slider(10, 100, value=20, step=10, label="Минимальная длина") | |
| summarizer_model = gr.Dropdown( | |
| choices=[ | |
| "facebook/bart-large-cnn", | |
| "google/pegasus-xsum", | |
| "t5-small" | |
| ], | |
| value="facebook/bart-large-cnn", | |
| label="Выберите модель" | |
| ) | |
| summarizer_btn = gr.Button("Суммаризировать", variant="primary") | |
| with gr.Column(): | |
| summarizer_output = gr.Textbox(label="Краткое содержание", lines=12) | |
| summarizer_btn.click( | |
| summarizer, | |
| inputs=[summarizer_input, summarizer_max, summarizer_min, summarizer_model], | |
| outputs=summarizer_output | |
| ) | |
| with gr.Accordion("Переводчик", open=False): | |
| with gr.Row(): | |
| with gr.Column(): | |
| translator_input = gr.Textbox( | |
| label="Текст для перевода", | |
| placeholder="Hello, how are you?", | |
| value="Hello, how are you?", | |
| lines=3 | |
| ) | |
| translator_model = gr.Dropdown( | |
| choices=[ | |
| "Helsinki-NLP/opus-mt-en-ru", | |
| "Helsinki-NLP/opus-mt-ru-en", | |
| "facebook/mbart-large-50-many-to-many-mmt" | |
| ], | |
| value="Helsinki-NLP/opus-mt-en-ru", | |
| label="Выберите модель" | |
| ) | |
| translator_src_lang = gr.Dropdown( | |
| choices=[ | |
| "en_XX", "ru_RU", "de_DE", "fr_XX", "es_XX", | |
| "it_IT", "pt_XX", "ja_XX", "ko_KR", "zh_CN", | |
| "ar_AR", "hi_IN", "tr_TR", "vi_VN", "th_TH" | |
| ], | |
| value="en_XX", | |
| label="Исходный язык (для mBART)", | |
| visible=False | |
| ) | |
| translator_tgt_lang = gr.Dropdown( | |
| choices=[ | |
| "en_XX", "ru_RU", "de_DE", "fr_XX", "es_XX", | |
| "it_IT", "pt_XX", "ja_XX", "ko_KR", "zh_CN", | |
| "ar_AR", "hi_IN", "tr_TR", "vi_VN", "th_TH" | |
| ], | |
| value="ru_RU", | |
| label="Целевой язык (для mBART)", | |
| visible=False | |
| ) | |
| translator_btn = gr.Button("Перевести", variant="primary") | |
| with gr.Column(): | |
| translator_output = gr.Textbox(label="Перевод", lines=8) | |
| def update_lang_visibility(model_name): | |
| """Показывает/скрывает поля языков в зависимости от модели""" | |
| is_mbart = "mbart" in model_name.lower() | |
| return gr.update(visible=is_mbart), gr.update(visible=is_mbart) | |
| translator_model.change( | |
| fn=update_lang_visibility, | |
| inputs=translator_model, | |
| outputs=[translator_src_lang, translator_tgt_lang] | |
| ) | |
| translator_btn.click( | |
| translator, | |
| inputs=[translator_input, translator_model, translator_src_lang, translator_tgt_lang], | |
| outputs=translator_output | |
| ) | |
| # ========== ВКЛАДКА 2: АУДИО ТРАНСФОРМЕРЫ ========== | |
| with gr.Tab("🎵 Аудио трансформеры"): | |
| gr.Markdown("## Работа с аудио моделями") | |
| with gr.Accordion("Классификатор", open=True): | |
| with gr.Row(): | |
| with gr.Column(): | |
| audio_classifier_input = gr.Audio( | |
| label="Загрузите аудио файл", | |
| type="filepath" | |
| ) | |
| audio_classifier_model = gr.Dropdown( | |
| choices=[ | |
| "MIT/ast-finetuned-audioset-10-10-0.4593", | |
| "ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition", | |
| "superb/hubert-base-superb-er" | |
| ], | |
| value="MIT/ast-finetuned-audioset-10-10-0.4593", | |
| label="Выберите модель" | |
| ) | |
| audio_classifier_btn = gr.Button("Классифицировать", variant="primary") | |
| with gr.Column(): | |
| audio_classifier_output = gr.Textbox(label="Результат", lines=6) | |
| audio_classifier_btn.click( | |
| audio_classifier, | |
| inputs=[audio_classifier_input, audio_classifier_model], | |
| outputs=audio_classifier_output | |
| ) | |
| with gr.Accordion("Zero-shot классификатор", open=False): | |
| with gr.Row(): | |
| with gr.Column(): | |
| zs_audio_input = gr.Audio( | |
| label="Загрузите аудио файл", | |
| type="filepath" | |
| ) | |
| zs_audio_labels = gr.Textbox( | |
| label="Кандидаты (через запятую)", | |
| placeholder="music, speech, noise", | |
| value="music, speech, noise" | |
| ) | |
| zs_audio_model = gr.Dropdown( | |
| choices=[ | |
| "laion/clap-htsat-unfused", | |
| "laion/clap-htsat-fused" | |
| ], | |
| value="laion/clap-htsat-unfused", | |
| label="Выберите модель" | |
| ) | |
| zs_audio_btn = gr.Button("Классифицировать", variant="primary") | |
| with gr.Column(): | |
| zs_audio_output = gr.Textbox(label="Результат", lines=6) | |
| zs_audio_btn.click( | |
| audio_zero_shot_classifier, | |
| inputs=[zs_audio_input, zs_audio_labels, zs_audio_model], | |
| outputs=zs_audio_output | |
| ) | |
| with gr.Accordion("Распознавание речи", open=False): | |
| with gr.Row(): | |
| with gr.Column(): | |
| asr_input = gr.Audio( | |
| label="Загрузите аудио файл с речью", | |
| type="filepath" | |
| ) | |
| asr_model = gr.Dropdown( | |
| choices=[ | |
| "openai/whisper-base", | |
| "facebook/wav2vec2-base-960h", | |
| "jonatasgrosman/wav2vec2-large-xlsr-53-russian" | |
| ], | |
| value="openai/whisper-base", | |
| label="Выберите модель" | |
| ) | |
| asr_btn = gr.Button("Распознать", variant="primary") | |
| with gr.Column(): | |
| asr_output = gr.Textbox(label="Распознанный текст", lines=5) | |
| asr_btn.click( | |
| speech_recognition, | |
| inputs=[asr_input, asr_model], | |
| outputs=asr_output | |
| ) | |
| with gr.Accordion("Синтез речи", open=False): | |
| with gr.Row(): | |
| with gr.Column(): | |
| tts_input = gr.Textbox( | |
| label="Введите текст для синтеза", | |
| placeholder="Hello, this is a speech synthesis test", | |
| value="Hello, this is a speech synthesis test", | |
| lines=3 | |
| ) | |
| tts_model = gr.Dropdown( | |
| choices=[ | |
| "microsoft/speecht5_tts", | |
| "facebook/mms-tts-eng", | |
| "facebook/mms-tts-rus" | |
| ], | |
| value="microsoft/speecht5_tts", | |
| label="Выберите модель" | |
| ) | |
| tts_btn = gr.Button("Синтезировать", variant="primary") | |
| with gr.Column(): | |
| tts_output = gr.Audio(label="Сгенерированное аудио") | |
| tts_btn.click( | |
| speech_synthesis, | |
| inputs=[tts_input, tts_model], | |
| outputs=tts_output | |
| ) | |
| # ========== ВКЛАДКА 3: ИЗОБРАЖЕНИЯ ТРАНСФОРМЕРЫ ========== | |
| with gr.Tab("🖼️ Изображения трансформеры"): | |
| gr.Markdown("## Работа с моделями для изображений") | |
| with gr.Accordion("Обнаружение объектов", open=True): | |
| with gr.Row(): | |
| with gr.Column(): | |
| obj_det_input = gr.Image( | |
| label="Загрузите изображение", | |
| type="pil" | |
| ) | |
| obj_det_model = gr.Dropdown( | |
| choices=[ | |
| "facebook/detr-resnet-50", | |
| "hustvl/yolos-tiny" | |
| ], | |
| value="facebook/detr-resnet-50", | |
| label="Выберите модель" | |
| ) | |
| obj_det_btn = gr.Button("Обнаружить объекты", variant="primary") | |
| with gr.Column(): | |
| obj_det_output = gr.Textbox(label="Результат", lines=8) | |
| obj_det_image = gr.Image(label="Изображение с результатами", type="pil") | |
| obj_det_btn.click( | |
| object_detection, | |
| inputs=[obj_det_input, obj_det_model], | |
| outputs=[obj_det_output, obj_det_image] | |
| ) | |
| with gr.Accordion("Сегментация изображений", open=False): | |
| with gr.Row(): | |
| with gr.Column(): | |
| seg_input = gr.Image( | |
| label="Загрузите изображение", | |
| type="pil" | |
| ) | |
| seg_model = gr.Dropdown( | |
| choices=[ | |
| "facebook/detr-resnet-50-panoptic", | |
| "facebook/maskformer-swin-base-ade" | |
| ], | |
| value="facebook/detr-resnet-50-panoptic", | |
| label="Выберите модель" | |
| ) | |
| seg_btn = gr.Button("Сегментировать", variant="primary") | |
| with gr.Column(): | |
| seg_output = gr.Textbox(label="Результат", lines=8) | |
| seg_image = gr.Image(label="Изображение с результатами", type="pil") | |
| seg_btn.click( | |
| image_segmentation, | |
| inputs=[seg_input, seg_model], | |
| outputs=[seg_output, seg_image] | |
| ) | |
| with gr.Accordion("Описание изображений", open=False): | |
| with gr.Row(): | |
| with gr.Column(): | |
| caption_input = gr.Image( | |
| label="Загрузите изображение", | |
| type="pil" | |
| ) | |
| caption_model = gr.Dropdown( | |
| choices=[ | |
| "nlpconnect/vit-gpt2-image-captioning", | |
| "Salesforce/blip-image-captioning-base", | |
| "microsoft/git-base" | |
| ], | |
| value="nlpconnect/vit-gpt2-image-captioning", | |
| label="Выберите модель" | |
| ) | |
| caption_btn = gr.Button("Сгенерировать описание", variant="primary") | |
| with gr.Column(): | |
| caption_output = gr.Textbox(label="Описание изображения", lines=3) | |
| caption_btn.click( | |
| image_captioning, | |
| inputs=[caption_input, caption_model], | |
| outputs=caption_output | |
| ) | |
| with gr.Accordion("Визуальный вопрос-ответ", open=False): | |
| with gr.Row(): | |
| with gr.Column(): | |
| vqa_image = gr.Image( | |
| label="Загрузите изображение", | |
| type="pil" | |
| ) | |
| vqa_question = gr.Textbox( | |
| label="Вопрос", | |
| placeholder="What is in the image?", | |
| value="What is in the image?" | |
| ) | |
| vqa_model = gr.Dropdown( | |
| choices=[ | |
| "dandelin/vilt-b32-finetuned-vqa", | |
| "Salesforce/blip-vqa-base" | |
| ], | |
| value="dandelin/vilt-b32-finetuned-vqa", | |
| label="Выберите модель" | |
| ) | |
| vqa_btn = gr.Button("Ответить", variant="primary") | |
| with gr.Column(): | |
| vqa_output = gr.Textbox(label="Ответ", lines=3) | |
| vqa_btn.click( | |
| visual_qa, | |
| inputs=[vqa_image, vqa_question, vqa_model], | |
| outputs=vqa_output | |
| ) | |
| with gr.Accordion("Zero-shot классификация", open=False): | |
| with gr.Row(): | |
| with gr.Column(): | |
| zs_image_input = gr.Image( | |
| label="Загрузите изображение", | |
| type="pil" | |
| ) | |
| zs_image_labels = gr.Textbox( | |
| label="Кандидаты (через запятую)", | |
| placeholder="cat, dog, bird", | |
| value="cat, dog, bird" | |
| ) | |
| zs_image_model = gr.Dropdown( | |
| choices=[ | |
| "openai/clip-vit-base-patch32", | |
| "laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k" | |
| ], | |
| value="openai/clip-vit-base-patch32", | |
| label="Выберите модель" | |
| ) | |
| zs_image_btn = gr.Button("Классифицировать", variant="primary") | |
| with gr.Column(): | |
| zs_image_output = gr.Textbox(label="Результат", lines=6) | |
| zs_image_btn.click( | |
| image_zero_shot_classification, | |
| inputs=[zs_image_input, zs_image_labels, zs_image_model], | |
| outputs=zs_image_output | |
| ) | |
| # ========== ВКЛАДКА 4: ИСТОРИЯ ========== | |
| with gr.Tab("📜 История"): | |
| gr.Markdown("## История выполнения моделей") | |
| gr.Markdown("Здесь отображается история всех выполненных операций с моделями. История автоматически обновляется при каждом вызове.") | |
| with gr.Row(): | |
| history_clear_btn = gr.Button("Очистить историю", variant="stop") | |
| history_refresh_btn = gr.Button("Обновить", variant="secondary") | |
| history_table = gr.Dataframe( | |
| label="История выполнения", | |
| headers=["Задача", "Модель", "Время выполнения", "Дата/Время", "Входные данные", "Результат"], | |
| value=get_history_display(), | |
| interactive=False, | |
| wrap=True | |
| ) | |
| history_json = gr.JSON( | |
| label="Полная история (JSON)", | |
| value=history if history else [] | |
| ) | |
| def update_history_display(): | |
| """Обновляет отображение истории""" | |
| display_data = get_history_display() | |
| json_data = history if history else [] | |
| return display_data, json_data | |
| history_refresh_btn.click( | |
| fn=update_history_display, | |
| outputs=[history_table, history_json] | |
| ) | |
| history_clear_btn.click( | |
| fn=clear_history, | |
| outputs=[history_table, history_json] | |
| ) | |
| if __name__ == "__main__": | |
| demo.launch(ssr_mode=False) | |