Transformers / app.py
MinAA
cleanup
fce9d70
import gradio as gr
from transformers import pipeline
import torch
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import functools
import warnings
import time
import inspect
from datetime import datetime
from collections import OrderedDict
warnings.filterwarnings("ignore")
# LRU кэш для хранения загруженных моделей
class LRUCache:
"""LRU (Least Recently Used) кэш для ограничения использования памяти"""
def __init__(self, maxsize=5):
"""
Args:
maxsize: Максимальное количество моделей в кэше
"""
self.cache = OrderedDict()
self.maxsize = maxsize
def get(self, key):
"""Получить модель из кэша"""
if key not in self.cache:
return None
# Перемещаем элемент в конец (как недавно использованный)
self.cache.move_to_end(key)
return self.cache[key]
def put(self, key, value):
"""Добавить модель в кэш"""
if key in self.cache:
# Если ключ уже есть, обновляем и перемещаем в конец
self.cache.move_to_end(key)
self.cache[key] = value
else:
# Если кэш полон, удаляем самый старый элемент (первый в OrderedDict)
if len(self.cache) >= self.maxsize:
oldest_key = next(iter(self.cache))
# Освобождаем память от модели
old_value = self.cache.pop(oldest_key)
del old_value
# Также очищаем кэш CUDA если используется GPU
if torch.cuda.is_available():
torch.cuda.empty_cache()
self.cache[key] = value
def __contains__(self, key):
"""Проверка наличия ключа в кэше"""
return key in self.cache
def __getitem__(self, key):
"""Получить элемент через []"""
value = self.get(key)
if value is None:
raise KeyError(key)
return value
def __setitem__(self, key, value):
"""Установить элемент через []"""
self.put(key, value)
def clear(self):
"""Очистить кэш"""
self.cache.clear()
if torch.cuda.is_available():
torch.cuda.empty_cache()
def size(self):
"""Текущий размер кэша"""
return len(self.cache)
# Создаем LRU кэш с максимальным размером 5 моделей
# Можно изменить это значение в зависимости от доступной памяти
model_cache = LRUCache(maxsize=2)
# История выполнения моделей
history = []
MAX_HISTORY_SIZE = 50
def get_pipeline(task, model_name, **kwargs):
"""Загрузка pipeline с LRU кэшированием"""
cache_key = f"{task}_{model_name}"
cached_model = model_cache.get(cache_key)
if cached_model is None:
try:
cached_model = pipeline(task, model=model_name, **kwargs)
model_cache.put(cache_key, cached_model)
except Exception as e:
raise Exception(f"Ошибка загрузки модели: {str(e)}")
return cached_model
def measure_time_and_save(task_name):
"""Декоратор для измерения времени выполнения и сохранения в историю"""
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
start_time = time.time()
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
# Извлекаем model_name из аргументов
model_name = None
sig = inspect.signature(func)
bound_args = sig.bind(*args, **kwargs)
bound_args.apply_defaults()
if 'model_name' in bound_args.arguments:
model_name = bound_args.arguments['model_name']
# Создаем краткое описание входных данных
input_preview = ""
if args and len(args) > 0:
first_arg = args[0]
if isinstance(first_arg, str):
input_preview = first_arg[:100] + ("..." if len(first_arg) > 100 else "")
elif isinstance(first_arg, Image.Image):
input_preview = f"Изображение ({first_arg.size[0]}x{first_arg.size[1]})"
elif isinstance(first_arg, (tuple, list)) and len(first_arg) == 2:
# Аудио файл (sample_rate, audio_data)
input_preview = f"Аудио файл"
else:
input_preview = str(type(first_arg).__name__)
# Выполняем функцию
try:
result = func(*args, **kwargs)
if isinstance(result, str):
output = result
elif isinstance(result, tuple) and len(result) == 2:
# Проверяем тип второго элемента
if isinstance(result[1], Image.Image):
# Результат с изображением (текст, изображение)
output = result[0] if isinstance(result[0], str) else str(result[0])[:500]
elif isinstance(result[1], (tuple, list)) and len(result[1]) == 2:
# Аудио результат (sample_rate, audio_data)
output = f"Аудио файл сгенерирован (sample_rate: {result[0]})"
else:
output = str(result)[:500]
else:
output = str(result)[:500]
except Exception as e:
output = f"Ошибка: {str(e)}"
result = output
# Измеряем время выполнения
execution_time = time.time() - start_time
# Сохраняем в историю
history_entry = {
"task_name": task_name,
"model_name": model_name or "Не указана",
"input_preview": input_preview,
"output": output,
"execution_time": round(execution_time, 3),
"timestamp": timestamp
}
history.insert(0, history_entry) # Добавляем в начало
# Ограничиваем размер истории
if len(history) > MAX_HISTORY_SIZE:
history.pop()
return result
return wrapper
return decorator
# ==================== ТЕКСТОВЫЕ ЗАДАЧИ ====================
@measure_time_and_save("Классификатор текста")
def text_classifier(text, model_name):
"""Классификация текста"""
try:
classifier = get_pipeline("text-classification", model_name)
result = classifier(text)
if isinstance(result, list):
result = result[0]
return f"Метка: {result['label']}\nУверенность: {result['score']:.4f}"
except Exception as e:
return f"Ошибка: {str(e)}"
@measure_time_and_save("Zero-shot классификатор")
def zero_shot_classifier(text, candidate_labels, model_name):
"""Zero-shot классификация"""
try:
classifier = get_pipeline("zero-shot-classification", model_name)
labels = [label.strip() for label in candidate_labels.split(",")]
result = classifier(text, labels)
output = "Результаты классификации:\n"
for label, score in zip(result['labels'], result['scores']):
output += f"{label}: {score:.4f}\n"
return output
except Exception as e:
return f"Ошибка: {str(e)}"
@measure_time_and_save("Генератор текста")
def text_generator(prompt, max_length, model_name):
"""Генерация текста"""
try:
generator = get_pipeline("text-generation", model_name)
result = generator(prompt, max_length=max_length, num_return_sequences=1, do_sample=True)
return result[0]['generated_text']
except Exception as e:
return f"Ошибка: {str(e)}"
@measure_time_and_save("Unmasker (заполнение пропусков)")
def text_unmasker(text, model_name):
"""Заполнение пропусков в тексте"""
try:
unmasker = get_pipeline("fill-mask", model_name)
result = unmasker(text)
output = "Варианты заполнения:\n"
for i, item in enumerate(result[:5], 1):
output += f"{i}. {item['sequence']} (уверенность: {item['score']:.4f})\n"
return output
except Exception as e:
return f"Ошибка: {str(e)}"
@measure_time_and_save("NER (Извлечение именованных сущностей)")
def ner_extractor(text, model_name):
"""Извлечение именованных сущностей"""
try:
ner = get_pipeline("ner", model_name, aggregation_strategy="simple")
result = ner(text)
if not result:
return "Именованные сущности не найдены"
output = "Найденные сущности:\n"
for entity in result:
output += f"{entity['word']}: {entity['entity_group']} (уверенность: {entity['score']:.4f})\n"
return output
except Exception as e:
return f"Ошибка: {str(e)}"
@measure_time_and_save("Question Answering")
def question_answerer(question, context, model_name):
"""Ответ на вопрос по контексту"""
try:
qa = get_pipeline("question-answering", model_name)
result = qa(question=question, context=context)
return f"Ответ: {result['answer']}\nУверенность: {result['score']:.4f}"
except Exception as e:
return f"Ошибка: {str(e)}"
@measure_time_and_save("Суммаризатор")
def summarizer(text, max_length, min_length, model_name):
"""Суммаризация текста"""
try:
summarizer_pipe = get_pipeline("summarization", model_name)
result = summarizer_pipe(text, max_length=max_length, min_length=min_length)
if isinstance(result, list):
result = result[0]
return result['summary_text']
except Exception as e:
return f"Ошибка: {str(e)}"
@measure_time_and_save("Переводчик")
def translator(text, model_name, src_lang=None, tgt_lang=None):
"""Перевод текста"""
try:
translator_pipe = get_pipeline("translation", model_name)
# Для mBART моделей требуются src_lang и tgt_lang
if "mbart" in model_name.lower():
if not src_lang or not tgt_lang:
return "Ошибка: Для модели mBART необходимо указать исходный и целевой языки"
result = translator_pipe(text, src_lang=src_lang, tgt_lang=tgt_lang)
else:
result = translator_pipe(text)
if isinstance(result, list):
result = result[0]
return result['translation_text']
except Exception as e:
return f"Ошибка: {str(e)}"
# ==================== АУДИО ЗАДАЧИ ====================
@measure_time_and_save("Классификатор аудио")
def audio_classifier(audio, model_name):
"""Классификация аудио"""
try:
classifier = get_pipeline("audio-classification", model_name)
result = classifier(audio)
# audio-classification pipeline возвращает список словарей
if not isinstance(result, list):
result = [result]
output = "Результаты классификации:\n"
for item in result[:5]:
output += f"{item['label']}: {item['score']:.4f}\n"
return output
except Exception as e:
return f"Ошибка: {str(e)}"
@measure_time_and_save("Zero-shot классификатор аудио")
def audio_zero_shot_classifier(audio, candidate_labels, model_name):
"""Zero-shot классификация аудио"""
try:
classifier = get_pipeline("zero-shot-audio-classification", model_name)
labels = [label.strip() for label in candidate_labels.split(",")]
result = classifier(audio, candidate_labels=labels)
output = "Результаты классификации:\n"
# Парсинг результата: [{'score': ..., 'label': ...}, ...]
if isinstance(result, list):
for item in result:
if isinstance(item, dict) and 'label' in item and 'score' in item:
output += f"{item['label']}: {item['score']:.4f}\n"
else:
return f"Ошибка: Неожиданный формат результата от pipeline: {type(result)}. Ожидался список словарей с ключами 'label' и 'score'."
return output
except Exception as e:
error_msg = str(e)
if "Could not load model" in error_msg or "Unrecognized" in error_msg:
return f"Ошибка: Модель '{model_name}' не поддерживается для zero-shot классификации аудио. Попробуйте другую модель, например 'laion/clap-htsat-unfused'."
return f"Ошибка: {error_msg}"
@measure_time_and_save("Распознавание речи")
def speech_recognition(audio, model_name):
"""Распознавание речи"""
try:
asr = get_pipeline("automatic-speech-recognition", model_name)
result = asr(audio)
if isinstance(result, list):
result = result[0]
return result['text']
except Exception as e:
return f"Ошибка: {str(e)}"
@measure_time_and_save("Синтез речи")
def speech_synthesis(text, model_name):
"""Синтез речи"""
try:
import numpy as np
import torch
# Проверяем, что текст не пустой
if not text or not text.strip():
raise ValueError("Текст для синтеза не может быть пустым")
# Проверяем, является ли модель SpeechT5
if "speecht5" in model_name.lower():
try:
# Для SpeechT5 нужны speaker_embeddings
# Пробуем использовать pipeline с forward_params
try:
tts = get_pipeline("text-to-speech", model_name)
# Пытаемся загрузить предобученные speaker embeddings из датасета
try:
from datasets import load_dataset
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker_embedding = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
except Exception:
# Если не удалось загрузить, создаем случайный speaker embedding
# Сначала загружаем модель чтобы узнать размерность
from transformers import SpeechT5ForTextToSpeech
temp_model = SpeechT5ForTextToSpeech.from_pretrained(model_name)
speaker_embedding_dim = temp_model.config.speaker_embedding_dim
del temp_model
# Создаем случайный speaker embedding
speaker_embedding = torch.randn(1, speaker_embedding_dim)
speaker_embedding = speaker_embedding / torch.norm(speaker_embedding, dim=1, keepdim=True)
# Используем pipeline с forward_params
result = tts(text, forward_params={"speaker_embeddings": speaker_embedding})
# Обрабатываем результат
if isinstance(result, dict):
audio_data = result.get("audio", result.get("raw", None))
sample_rate = result.get("sampling_rate", result.get("sample_rate", 16000))
if audio_data is None:
raise ValueError("Не удалось извлечь аудио данные из результата pipeline")
# Конвертируем в numpy array если нужно
if isinstance(audio_data, torch.Tensor):
audio_data = audio_data.numpy()
elif not isinstance(audio_data, np.ndarray):
audio_data = np.array(audio_data)
# Убеждаемся, что это 1D массив
if len(audio_data.shape) > 1:
audio_data = audio_data.flatten()
# Нормализуем в float32
if audio_data.dtype != np.float32:
audio_data = audio_data.astype(np.float32)
# Нормализуем если значения выходят за пределы [-1, 1]
max_val = np.abs(audio_data).max()
if max_val > 1.0:
audio_data = audio_data / max_val
return (sample_rate, audio_data)
else:
raise ValueError(f"Неожиданный формат результата от pipeline: {type(result)}")
except Exception as pipeline_error:
# Если pipeline не работает, пробуем напрямую через модель
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
cache_key = f"tts_speecht5_{model_name}"
cached = model_cache.get(cache_key)
if cached is None:
processor = SpeechT5Processor.from_pretrained(model_name)
model = SpeechT5ForTextToSpeech.from_pretrained(model_name)
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
# Пытаемся загрузить предобученные speaker embeddings
try:
from datasets import load_dataset
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
except Exception:
# Если не удалось, создаем случайный
speaker_embedding_dim = model.config.speaker_embedding_dim
speaker_embeddings = torch.randn(1, speaker_embedding_dim)
speaker_embeddings = speaker_embeddings / torch.norm(speaker_embeddings, dim=1, keepdim=True)
cached = (processor, model, vocoder, speaker_embeddings)
model_cache.put(cache_key, cached)
processor, model, vocoder, speaker_embeddings = cached
inputs = processor(text=text, return_tensors="pt")
# Убеждаемся, что speaker_embeddings имеют правильную форму и тип
if not isinstance(speaker_embeddings, torch.Tensor):
speaker_embeddings = torch.tensor(speaker_embeddings)
if speaker_embeddings.device != inputs["input_ids"].device:
speaker_embeddings = speaker_embeddings.to(inputs["input_ids"].device)
with torch.no_grad():
# Вызываем generate_speech с именованным параметром
speech = model.generate_speech(
inputs["input_ids"],
speaker_embeddings=speaker_embeddings,
vocoder=vocoder
)
# Конвертируем в numpy и нормализуем
audio_data = speech.numpy()
# Убеждаемся, что это 1D массив
if len(audio_data.shape) > 1:
audio_data = audio_data.flatten()
# Нормализуем в диапазон [-1, 1] если нужно
if audio_data.dtype != np.float32:
audio_data = audio_data.astype(np.float32)
# Нормализуем если значения выходят за пределы [-1, 1]
max_val = np.abs(audio_data).max()
if max_val > 1.0:
audio_data = audio_data / max_val
sample_rate = 16000
return (sample_rate, audio_data)
except Exception as e:
error_msg = str(e)
if "ImportError" in str(type(e)) or "ModuleNotFoundError" in str(type(e)):
raise Exception(f"Ошибка: Не удалось импортировать необходимые модули для SpeechT5. Убедитесь, что transformers установлен: {error_msg}")
raise Exception(f"Ошибка синтеза речи с SpeechT5: {error_msg}")
# Используем стандартный pipeline для других моделей
tts = get_pipeline("text-to-speech", model_name)
result = tts(text)
# Pipeline может возвращать словарь или кортеж
if isinstance(result, dict):
# Стандартный формат: {"audio": array, "sampling_rate": int}
audio_data = result.get("audio", result.get("raw", None))
sample_rate = result.get("sampling_rate", result.get("sample_rate", 22050))
if audio_data is None:
raise ValueError("Не удалось извлечь аудио данные из результата pipeline")
# Конвертируем в numpy array если нужно
if isinstance(audio_data, torch.Tensor):
audio_data = audio_data.numpy()
elif not isinstance(audio_data, np.ndarray):
audio_data = np.array(audio_data)
# Убеждаемся, что это 1D массив
if len(audio_data.shape) > 1:
audio_data = audio_data.flatten()
# Нормализуем в float32
if audio_data.dtype != np.float32:
audio_data = audio_data.astype(np.float32)
# Нормализуем если значения выходят за пределы [-1, 1]
max_val = np.abs(audio_data).max()
if max_val > 1.0:
audio_data = audio_data / max_val
return (sample_rate, audio_data)
elif isinstance(result, tuple) and len(result) == 2:
# Уже в правильном формате (sample_rate, audio_data)
sample_rate, audio_data = result
# Конвертируем в numpy если нужно
if isinstance(audio_data, torch.Tensor):
audio_data = audio_data.numpy()
elif not isinstance(audio_data, np.ndarray):
audio_data = np.array(audio_data)
# Убеждаемся, что это 1D массив
if len(audio_data.shape) > 1:
audio_data = audio_data.flatten()
# Нормализуем в float32
if audio_data.dtype != np.float32:
audio_data = audio_data.astype(np.float32)
# Нормализуем если значения выходят за пределы [-1, 1]
max_val = np.abs(audio_data).max()
if max_val > 1.0:
audio_data = audio_data / max_val
return (sample_rate, audio_data)
else:
raise ValueError(f"Неожиданный формат результата от pipeline: {type(result)}. Ожидался словарь с ключами 'audio' и 'sampling_rate' или кортеж (sample_rate, audio_data).")
except Exception as e:
error_msg = str(e)
if "speaker_embeddings" in error_msg.lower():
if "speecht5" in model_name.lower():
return f"Ошибка: Модель SpeechT5 требует speaker_embeddings. Они должны генерироваться автоматически, но произошла ошибка: {error_msg}"
return f"Ошибка: Модель '{model_name}' требует speaker_embeddings. Для SpeechT5 они генерируются автоматически, но для других моделей может потребоваться дополнительная настройка."
if "does not appear to have a file named" in error_msg or "Unrecognized model" in error_msg:
return f"Ошибка: Модель '{model_name}' не поддерживается библиотекой transformers для синтеза речи. Попробуйте использовать модель 'microsoft/speecht5_tts'."
if "negative output size" in error_msg.lower() or "input size 0" in error_msg.lower():
return f"Ошибка: Проблема с обработкой текста моделью '{model_name}'. Возможные причины: неподдерживаемый язык, пустой текст после обработки, или проблема с токенизацией. Попробуйте использовать другой текст или модель."
raise Exception(f"Ошибка синтеза речи: {error_msg}")
# ==================== ЗАДАЧИ С ИЗОБРАЖЕНИЯМИ ====================
@measure_time_and_save("Обнаружение объектов")
def object_detection(image, model_name):
"""Обнаружение объектов на изображении"""
try:
detector = get_pipeline("object-detection", model_name)
result = detector(image)
# Создаем копию изображения для визуализации
img_with_boxes = image.copy()
draw = ImageDraw.Draw(img_with_boxes)
# Цвета для разных объектов
colors = ['red', 'blue', 'green', 'yellow', 'orange', 'purple', 'cyan', 'magenta']
output = "Обнаруженные объекты:\n"
for i, item in enumerate(result):
box = item['box']
label = item['label']
score = item['score']
# Обрабатываем различные форматы координат
if isinstance(box, dict):
# Словарь с ключами 'xmin', 'ymin', 'xmax', 'ymax'
xmin = box.get('xmin', box.get('x1', 0))
ymin = box.get('ymin', box.get('y1', 0))
xmax = box.get('xmax', box.get('x2', 0))
ymax = box.get('ymax', box.get('y2', 0))
elif isinstance(box, (list, tuple)) and len(box) >= 4:
# Список [xmin, ymin, xmax, ymax] или [xcenter, ycenter, width, height]
if box[2] > box[0] and box[3] > box[1]:
# Вероятно [xmin, ymin, xmax, ymax]
xmin, ymin, xmax, ymax = box[0], box[1], box[2], box[3]
else:
# Вероятно [xcenter, ycenter, width, height]
xcenter, ycenter, width, height = box[0], box[1], box[2], box[3]
xmin = xcenter - width / 2
ymin = ycenter - height / 2
xmax = xcenter + width / 2
ymax = ycenter + height / 2
else:
# Неизвестный формат, пропускаем
output += f"{label}: уверенность {score:.4f}, координаты {box}\n"
continue
# Проверяем и ограничиваем координаты границами изображения
img_width, img_height = img_with_boxes.size
xmin = max(0, min(xmin, img_width))
ymin = max(0, min(ymin, img_height))
xmax = max(0, min(xmax, img_width))
ymax = max(0, min(ymax, img_height))
# Проверяем, что координаты валидны
if xmax <= xmin or ymax <= ymin:
output += f"{label}: уверенность {score:.4f}, координаты {box} (некорректные)\n"
continue
# Рисуем прямоугольник
color = colors[i % len(colors)]
draw.rectangle([xmin, ymin, xmax, ymax], outline=color, width=3)
# Добавляем текст с меткой и уверенностью
text = f"{label}: {score:.2f}"
try:
# Пытаемся использовать системный шрифт
font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", 16)
except:
try:
font = ImageFont.load_default()
except:
font = None
# Получаем размер текста
if font:
bbox = draw.textbbox((0, 0), text, font=font)
text_width = bbox[2] - bbox[0]
text_height = bbox[3] - bbox[1]
else:
text_width = len(text) * 6
text_height = 12
# Рисуем фон для текста (проверяем границы)
text_y = max(0, ymin - text_height - 4)
text_x_end = min(img_width, xmin + text_width + 4)
draw.rectangle([xmin, text_y, text_x_end, ymin], fill=color)
draw.text((xmin + 2, text_y + 2), text, fill='white', font=font)
output += f"{label}: уверенность {score:.4f}, координаты {box}\n"
return output, img_with_boxes
except Exception as e:
return f"Ошибка: {str(e)}", image
@measure_time_and_save("Сегментация изображений")
def image_segmentation(image, model_name):
"""Сегментация изображения"""
try:
segmenter = get_pipeline("image-segmentation", model_name)
result = segmenter(image)
# Создаем копию изображения для визуализации
img_with_segments = image.copy().convert("RGBA")
# Генерируем цвета для сегментов
np.random.seed(42) # Для воспроизводимости
output = "Сегменты:\n"
overlay = Image.new("RGBA", image.size, (0, 0, 0, 0))
draw = ImageDraw.Draw(overlay)
# Список для хранения информации о сегментах (для добавления текста)
segments_info = []
for i, item in enumerate(result):
label = item['label']
score = item['score']
# Генерируем полупрозрачный цвет для сегмента
color = tuple(np.random.randint(0, 255, 3)) + (128,) # RGBA с прозрачностью
color_rgb = color[:3] # RGB цвет для текста
# Проверяем наличие маски
if 'mask' in item:
mask = item['mask']
# Преобразуем маску в numpy array
if isinstance(mask, Image.Image):
mask_array = np.array(mask)
elif isinstance(mask, np.ndarray):
mask_array = mask
else:
mask_array = np.array(mask)
# Нормализуем маску, если нужно
if mask_array.dtype != np.uint8:
if mask_array.max() <= 1.0:
mask_array = (mask_array * 255).astype(np.uint8)
else:
mask_array = mask_array.astype(np.uint8)
# Находим центр маски для размещения текста
if len(mask_array.shape) == 2: # Grayscale mask
mask_bool = mask_array > 0
elif len(mask_array.shape) == 3 and mask_array.shape[2] == 1:
mask_bool = mask_array[:, :, 0] > 0
else:
if mask_array.shape[2] >= 1:
mask_bool = mask_array[:, :, 0] > 0
else:
mask_bool = np.zeros(mask_array.shape[:2], dtype=bool)
# Вычисляем центр маски
if np.any(mask_bool):
y_coords, x_coords = np.where(mask_bool)
if len(y_coords) > 0 and len(x_coords) > 0:
center_y = int(np.mean(y_coords))
center_x = int(np.mean(x_coords))
# Масштабируем координаты, если маска другого размера
if mask_array.shape[:2] != image.size[::-1]:
scale_y = image.size[1] / mask_array.shape[0]
scale_x = image.size[0] / mask_array.shape[1]
center_y = int(center_y * scale_y)
center_x = int(center_x * scale_x)
segments_info.append({
'label': label,
'score': score,
'center': (center_x, center_y),
'color': color_rgb
})
# Создаем цветную маску
if len(mask_array.shape) == 2: # Grayscale mask
# Создаем RGBA маску
colored_mask = np.zeros((mask_array.shape[0], mask_array.shape[1], 4), dtype=np.uint8)
# Применяем цвет только там, где маска не равна нулю
mask_bool = mask_array > 0
colored_mask[mask_bool, :3] = color[:3]
colored_mask[mask_bool, 3] = 128 # Альфа-канал
elif len(mask_array.shape) == 3 and mask_array.shape[2] == 1:
# Маска с одним каналом
colored_mask = np.zeros((mask_array.shape[0], mask_array.shape[1], 4), dtype=np.uint8)
mask_bool = mask_array[:, :, 0] > 0
colored_mask[mask_bool, :3] = color[:3]
colored_mask[mask_bool, 3] = 128
else:
# Многоканальная маска
colored_mask = np.zeros((mask_array.shape[0], mask_array.shape[1], 4), dtype=np.uint8)
# Используем первый канал как маску
if mask_array.shape[2] >= 1:
mask_bool = mask_array[:, :, 0] > 0
colored_mask[mask_bool, :3] = color[:3]
colored_mask[mask_bool, 3] = 128
# Убеждаемся, что размеры совпадают
if colored_mask.shape[:2] == img_with_segments.size[::-1]:
mask_img = Image.fromarray(colored_mask, mode='RGBA')
overlay = Image.alpha_composite(overlay, mask_img)
elif colored_mask.shape[:2] != overlay.size[::-1]:
# Изменяем размер маски, если нужно
mask_img = Image.fromarray(colored_mask, mode='RGBA')
mask_img = mask_img.resize(overlay.size, Image.Resampling.LANCZOS)
overlay = Image.alpha_composite(overlay, mask_img)
output += f"{label}: уверенность {score:.4f}\n"
# Накладываем overlay на исходное изображение
if overlay.size == img_with_segments.size:
img_with_segments = Image.alpha_composite(img_with_segments, overlay)
# Добавляем текстовые метки с цветами на изображение
draw_final = ImageDraw.Draw(img_with_segments)
# Загружаем шрифт
try:
font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", 18)
except:
try:
font = ImageFont.load_default()
except:
font = None
for seg_info in segments_info:
label = seg_info['label']
score = seg_info['score']
center_x, center_y = seg_info['center']
color_rgb = seg_info['color']
# Формируем текст метки
text = f"{label}: {score:.2f}"
# Получаем размер текста
if font:
bbox = draw_final.textbbox((0, 0), text, font=font)
text_width = bbox[2] - bbox[0]
text_height = bbox[3] - bbox[1]
else:
text_width = len(text) * 7
text_height = 14
# Вычисляем позицию текста (центрируем относительно центра сегмента)
text_x = center_x - text_width // 2
text_y = center_y - text_height // 2
# Ограничиваем координаты границами изображения
img_width, img_height = img_with_segments.size
text_x = max(2, min(text_x, img_width - text_width - 2))
text_y = max(2, min(text_y, img_height - text_height - 2))
# Рисуем фон для текста (полупрозрачный черный для читаемости)
padding = 4
draw_final.rectangle(
[text_x - padding, text_y - padding,
text_x + text_width + padding, text_y + text_height + padding],
fill=(0, 0, 0, 180) # Полупрозрачный черный фон
)
# Рисуем текст цветом сегмента
draw_final.text(
(text_x, text_y),
text,
fill=color_rgb + (255,), # RGB + альфа для RGBA
font=font
)
# Конвертируем обратно в RGB для отображения
img_with_segments = img_with_segments.convert("RGB")
return output, img_with_segments
except Exception as e:
return f"Ошибка: {str(e)}", image
@measure_time_and_save("Описание изображений")
def image_captioning(image, model_name):
"""Описание изображения"""
try:
captioner = get_pipeline("image-to-text", model_name)
result = captioner(image)
if isinstance(result, list):
result = result[0]
return result['generated_text']
except Exception as e:
return f"Ошибка: {str(e)}"
@measure_time_and_save("Визуальный вопрос-ответ")
def visual_qa(image, question, model_name):
"""Визуальный вопрос-ответ"""
try:
vqa = get_pipeline("visual-question-answering", model_name)
result = vqa(image=image, question=question)
if isinstance(result, list):
result = result[0]
return f"Ответ: {result['answer']}"
except Exception as e:
return f"Ошибка: {str(e)}"
@measure_time_and_save("Zero-shot классификация изображений")
def image_zero_shot_classification(image, candidate_labels, model_name):
"""Zero-shot классификация изображений"""
try:
labels = [label.strip() for label in candidate_labels.split(",")]
# Проверяем, является ли модель LAION (проверяем ДО вызова get_pipeline)
model_name_lower = model_name.lower()
if "laion/" in model_name_lower or "laion5b" in model_name_lower or "laion" in model_name_lower:
# Используем OpenCLIP для LAION моделей
try:
import open_clip
except ImportError:
return f"Ошибка: Для работы с LAION моделями требуется библиотека open-clip-torch. Установите её: pip install open-clip-torch"
cache_key = f"clip_laion_{model_name}"
cached = model_cache.get(cache_key)
if cached is None:
# Определяем имя модели и веса для OpenCLIP
if "xlm-roberta-base-ViT-B-32" in model_name or "xlm-roberta-base" in model_name:
clip_model_name = "xlm-roberta-base-ViT-B-32"
pretrained = "laion5b_s13b_b90k"
else:
# Пытаемся извлечь информацию из имени модели
clip_model_name = "xlm-roberta-base-ViT-B-32"
pretrained = "laion5b_s13b_b90k"
model, _, preprocess = open_clip.create_model_and_transforms(
clip_model_name,
pretrained=pretrained
)
tokenizer = open_clip.get_tokenizer(clip_model_name)
model.eval()
cached = (model, preprocess, tokenizer)
model_cache.put(cache_key, cached)
model, preprocess, tokenizer = cached
# Обрабатываем изображение и тексты
image_tensor = preprocess(image).unsqueeze(0)
text_tokens = tokenizer(labels)
with torch.no_grad():
image_features = model.encode_image(image_tensor)
text_features = model.encode_text(text_tokens)
# Нормализуем признаки
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
# Вычисляем косинусное сходство (логиты)
logits_per_image = (image_features @ text_features.T) * 100 # Масштабируем для лучшей точности
probs = logits_per_image.softmax(dim=1)
output = "Результаты классификации:\n"
for label, prob in zip(labels, probs[0]):
output += f"{label}: {prob.item():.4f}\n"
return output
else:
# Используем стандартный pipeline
classifier = get_pipeline("zero-shot-image-classification", model_name)
result = classifier(image, candidate_labels=labels)
output = "Результаты классификации:\n"
# Парсим результат
if isinstance(result, list):
# Результат - список словарей с 'score' и 'label'
for item in result:
if isinstance(item, dict) and 'label' in item and 'score' in item:
output += f"{item['label']}: {item['score']:.4f}\n"
else:
return f"Ошибка: Неожиданный формат элемента в результате: {item}"
elif isinstance(result, dict):
# Результат - словарь с 'labels' и 'scores'
if 'labels' in result and 'scores' in result:
for label, score in zip(result['labels'], result['scores']):
output += f"{label}: {score:.4f}\n"
else:
return f"Ошибка: Неожиданный формат результата от pipeline: {result}. Ожидался словарь с ключами 'labels' и 'scores' или список словарей."
else:
return f"Ошибка: Неожиданный формат результата от pipeline: {type(result)}. Ожидался словарь с ключами 'labels' и 'scores' или список словарей. Получен: {result}"
return output
except Exception as e:
error_msg = str(e)
if "Could not load model" in error_msg or "Unrecognized" in error_msg:
if "laion" in model_name.lower():
return f"Ошибка: Модель '{model_name}' требует библиотеку open-clip-torch. Убедитесь, что она установлена: pip install open-clip-torch"
return f"Ошибка: Модель '{model_name}' не поддерживается для zero-shot классификации изображений. Попробуйте другую модель, например 'openai/clip-vit-base-patch32'."
if "open_clip" in error_msg or "open-clip" in error_msg or "ModuleNotFoundError" in str(type(e)):
return f"Ошибка: Для работы с LAION моделями требуется библиотека open-clip-torch. Установите её: pip install open-clip-torch"
return f"Ошибка: {error_msg}"
# ==================== ФУНКЦИИ ДЛЯ ИСТОРИИ ====================
def get_history_display():
"""Форматирует историю для отображения в таблице"""
if not history:
return []
# Форматируем данные для таблицы
display_data = []
for entry in history:
# Обрезаем длинные результаты для отображения
output_preview = entry['output']
if len(output_preview) > 100:
output_preview = output_preview[:100] + "..."
display_data.append([
entry['task_name'],
entry['model_name'],
f"{entry['execution_time']} сек",
entry['timestamp'],
entry['input_preview'],
output_preview
])
return display_data
def clear_history():
"""Очищает историю"""
global history
history.clear()
return get_history_display(), []
# ==================== GRADIO ИНТЕРФЕЙС ====================
with gr.Blocks(title="Трансформеры Hugging Face", theme=gr.themes.Soft()) as demo:
gr.Markdown("# 🤖 Трансформеры Hugging Face")
gr.Markdown("Выберите вкладку для работы с различными типами трансформеров")
with gr.Tabs():
# ========== ВКЛАДКА 1: ТЕКСТОВЫЕ ТРАНСФОРМЕРЫ ==========
with gr.Tab("📝 Текстовые трансформеры"):
gr.Markdown("## Работа с текстовыми моделями")
with gr.Accordion("Классификатор", open=True):
with gr.Row():
with gr.Column():
text_classifier_input = gr.Textbox(
label="Введите текст для классификации",
placeholder="I love this app!",
value="I love this app!"
)
text_classifier_model = gr.Dropdown(
choices=[
"distilbert-base-uncased-finetuned-sst-2-english",
"cardiffnlp/twitter-roberta-base-sentiment-latest",
"nlptown/bert-base-multilingual-uncased-sentiment"
],
value="distilbert-base-uncased-finetuned-sst-2-english",
label="Выберите модель"
)
text_classifier_btn = gr.Button("Классифицировать", variant="primary")
with gr.Column():
text_classifier_output = gr.Textbox(label="Результат", lines=8)
text_classifier_btn.click(
text_classifier,
inputs=[text_classifier_input, text_classifier_model],
outputs=text_classifier_output
)
with gr.Accordion("Zero-shot классификатор", open=False):
with gr.Row():
with gr.Column():
zs_text_input = gr.Textbox(
label="Введите текст",
placeholder="I just finished reading a great book",
value="I just finished reading a great book"
)
zs_text_labels = gr.Textbox(
label="Кандидаты (через запятую)",
placeholder="positive, negative, neutral",
value="positive, negative, neutral"
)
zs_text_model = gr.Dropdown(
choices=[
"facebook/bart-large-mnli",
"typeform/distilbert-base-uncased-mnli",
"valhalla/distilbart-mnli-12-3"
],
value="facebook/bart-large-mnli",
label="Выберите модель"
)
zs_text_btn = gr.Button("Классифицировать", variant="primary")
with gr.Column():
zs_text_output = gr.Textbox(label="Результат", lines=8)
zs_text_btn.click(
zero_shot_classifier,
inputs=[zs_text_input, zs_text_labels, zs_text_model],
outputs=zs_text_output
)
with gr.Accordion("Генератор текста", open=False):
with gr.Row():
with gr.Column():
text_gen_input = gr.Textbox(
label="Промпт",
placeholder="In the distant future",
value="In the distant future"
)
text_gen_length = gr.Slider(20, 200, value=50, step=10, label="Максимальная длина")
text_gen_model = gr.Dropdown(
choices=["gpt2", "distilgpt2", "EleutherAI/gpt-neo-125M"],
value="gpt2",
label="Выберите модель"
)
text_gen_btn = gr.Button("Сгенерировать", variant="primary")
with gr.Column():
text_gen_output = gr.Textbox(label="Сгенерированный текст", lines=12)
text_gen_btn.click(
text_generator,
inputs=[text_gen_input, text_gen_length, text_gen_model],
outputs=text_gen_output
)
with gr.Accordion("Unmasker (заполнение пропусков)", open=False):
with gr.Row():
with gr.Column():
unmasker_input = gr.Textbox(
label="Текст с [MASK]",
placeholder="I love [MASK] programming",
value="I love [MASK] programming"
)
unmasker_model = gr.Dropdown(
choices=[
"bert-base-uncased",
"distilbert-base-uncased",
"bert-base-multilingual-uncased"
],
value="bert-base-uncased",
label="Выберите модель"
)
unmasker_btn = gr.Button("Заполнить", variant="primary")
with gr.Column():
unmasker_output = gr.Textbox(label="Результат", lines=10)
unmasker_btn.click(
text_unmasker,
inputs=[unmasker_input, unmasker_model],
outputs=unmasker_output
)
with gr.Accordion("NER (Извлечение именованных сущностей)", open=False):
with gr.Row():
with gr.Column():
ner_input = gr.Textbox(
label="Введите текст",
placeholder="My name is John, I work at Microsoft in Seattle",
value="My name is John, I work at Microsoft in Seattle"
)
ner_model = gr.Dropdown(
choices=[
"dslim/bert-base-NER",
"dbmdz/bert-large-cased-finetuned-conll03-english",
"Jean-Baptiste/roberta-large-ner-english"
],
value="dslim/bert-base-NER",
label="Выберите модель"
)
ner_btn = gr.Button("Извлечь сущности", variant="primary")
with gr.Column():
ner_output = gr.Textbox(label="Найденные сущности", lines=10)
ner_btn.click(
ner_extractor,
inputs=[ner_input, ner_model],
outputs=ner_output
)
with gr.Accordion("Question Answering", open=False):
with gr.Row():
with gr.Column():
qa_question = gr.Textbox(
label="Вопрос",
placeholder="What color is the sky?",
value="What color is the sky?"
)
qa_context = gr.Textbox(
label="Контекст",
placeholder="The sky is blue due to light scattering",
value="The sky is blue due to light scattering in the atmosphere",
lines=3
)
qa_model = gr.Dropdown(
choices=[
"distilbert-base-uncased-distilled-squad",
"deepset/roberta-base-squad2",
"bert-large-uncased-whole-word-masking-finetuned-squad"
],
value="distilbert-base-uncased-distilled-squad",
label="Выберите модель"
)
qa_btn = gr.Button("Ответить", variant="primary")
with gr.Column():
qa_output = gr.Textbox(label="Ответ", lines=8)
qa_btn.click(
question_answerer,
inputs=[qa_question, qa_context, qa_model],
outputs=qa_output
)
with gr.Accordion("Суммаризатор", open=False):
with gr.Row():
with gr.Column():
summarizer_input = gr.Textbox(
label="Текст для суммаризации",
placeholder="Enter a long text...",
value="Artificial intelligence is a field of computer science that focuses on creating intelligent machines. Machine learning is a subset of artificial intelligence that enables systems to automatically learn and improve from experience.",
lines=5
)
summarizer_max = gr.Slider(20, 200, value=50, step=10, label="Максимальная длина")
summarizer_min = gr.Slider(10, 100, value=20, step=10, label="Минимальная длина")
summarizer_model = gr.Dropdown(
choices=[
"facebook/bart-large-cnn",
"google/pegasus-xsum",
"t5-small"
],
value="facebook/bart-large-cnn",
label="Выберите модель"
)
summarizer_btn = gr.Button("Суммаризировать", variant="primary")
with gr.Column():
summarizer_output = gr.Textbox(label="Краткое содержание", lines=12)
summarizer_btn.click(
summarizer,
inputs=[summarizer_input, summarizer_max, summarizer_min, summarizer_model],
outputs=summarizer_output
)
with gr.Accordion("Переводчик", open=False):
with gr.Row():
with gr.Column():
translator_input = gr.Textbox(
label="Текст для перевода",
placeholder="Hello, how are you?",
value="Hello, how are you?",
lines=3
)
translator_model = gr.Dropdown(
choices=[
"Helsinki-NLP/opus-mt-en-ru",
"Helsinki-NLP/opus-mt-ru-en",
"facebook/mbart-large-50-many-to-many-mmt"
],
value="Helsinki-NLP/opus-mt-en-ru",
label="Выберите модель"
)
translator_src_lang = gr.Dropdown(
choices=[
"en_XX", "ru_RU", "de_DE", "fr_XX", "es_XX",
"it_IT", "pt_XX", "ja_XX", "ko_KR", "zh_CN",
"ar_AR", "hi_IN", "tr_TR", "vi_VN", "th_TH"
],
value="en_XX",
label="Исходный язык (для mBART)",
visible=False
)
translator_tgt_lang = gr.Dropdown(
choices=[
"en_XX", "ru_RU", "de_DE", "fr_XX", "es_XX",
"it_IT", "pt_XX", "ja_XX", "ko_KR", "zh_CN",
"ar_AR", "hi_IN", "tr_TR", "vi_VN", "th_TH"
],
value="ru_RU",
label="Целевой язык (для mBART)",
visible=False
)
translator_btn = gr.Button("Перевести", variant="primary")
with gr.Column():
translator_output = gr.Textbox(label="Перевод", lines=8)
def update_lang_visibility(model_name):
"""Показывает/скрывает поля языков в зависимости от модели"""
is_mbart = "mbart" in model_name.lower()
return gr.update(visible=is_mbart), gr.update(visible=is_mbart)
translator_model.change(
fn=update_lang_visibility,
inputs=translator_model,
outputs=[translator_src_lang, translator_tgt_lang]
)
translator_btn.click(
translator,
inputs=[translator_input, translator_model, translator_src_lang, translator_tgt_lang],
outputs=translator_output
)
# ========== ВКЛАДКА 2: АУДИО ТРАНСФОРМЕРЫ ==========
with gr.Tab("🎵 Аудио трансформеры"):
gr.Markdown("## Работа с аудио моделями")
with gr.Accordion("Классификатор", open=True):
with gr.Row():
with gr.Column():
audio_classifier_input = gr.Audio(
label="Загрузите аудио файл",
type="filepath"
)
audio_classifier_model = gr.Dropdown(
choices=[
"MIT/ast-finetuned-audioset-10-10-0.4593",
"ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition",
"superb/hubert-base-superb-er"
],
value="MIT/ast-finetuned-audioset-10-10-0.4593",
label="Выберите модель"
)
audio_classifier_btn = gr.Button("Классифицировать", variant="primary")
with gr.Column():
audio_classifier_output = gr.Textbox(label="Результат", lines=6)
audio_classifier_btn.click(
audio_classifier,
inputs=[audio_classifier_input, audio_classifier_model],
outputs=audio_classifier_output
)
with gr.Accordion("Zero-shot классификатор", open=False):
with gr.Row():
with gr.Column():
zs_audio_input = gr.Audio(
label="Загрузите аудио файл",
type="filepath"
)
zs_audio_labels = gr.Textbox(
label="Кандидаты (через запятую)",
placeholder="music, speech, noise",
value="music, speech, noise"
)
zs_audio_model = gr.Dropdown(
choices=[
"laion/clap-htsat-unfused",
"laion/clap-htsat-fused"
],
value="laion/clap-htsat-unfused",
label="Выберите модель"
)
zs_audio_btn = gr.Button("Классифицировать", variant="primary")
with gr.Column():
zs_audio_output = gr.Textbox(label="Результат", lines=6)
zs_audio_btn.click(
audio_zero_shot_classifier,
inputs=[zs_audio_input, zs_audio_labels, zs_audio_model],
outputs=zs_audio_output
)
with gr.Accordion("Распознавание речи", open=False):
with gr.Row():
with gr.Column():
asr_input = gr.Audio(
label="Загрузите аудио файл с речью",
type="filepath"
)
asr_model = gr.Dropdown(
choices=[
"openai/whisper-base",
"facebook/wav2vec2-base-960h",
"jonatasgrosman/wav2vec2-large-xlsr-53-russian"
],
value="openai/whisper-base",
label="Выберите модель"
)
asr_btn = gr.Button("Распознать", variant="primary")
with gr.Column():
asr_output = gr.Textbox(label="Распознанный текст", lines=5)
asr_btn.click(
speech_recognition,
inputs=[asr_input, asr_model],
outputs=asr_output
)
with gr.Accordion("Синтез речи", open=False):
with gr.Row():
with gr.Column():
tts_input = gr.Textbox(
label="Введите текст для синтеза",
placeholder="Hello, this is a speech synthesis test",
value="Hello, this is a speech synthesis test",
lines=3
)
tts_model = gr.Dropdown(
choices=[
"microsoft/speecht5_tts",
"facebook/mms-tts-eng",
"facebook/mms-tts-rus"
],
value="microsoft/speecht5_tts",
label="Выберите модель"
)
tts_btn = gr.Button("Синтезировать", variant="primary")
with gr.Column():
tts_output = gr.Audio(label="Сгенерированное аудио")
tts_btn.click(
speech_synthesis,
inputs=[tts_input, tts_model],
outputs=tts_output
)
# ========== ВКЛАДКА 3: ИЗОБРАЖЕНИЯ ТРАНСФОРМЕРЫ ==========
with gr.Tab("🖼️ Изображения трансформеры"):
gr.Markdown("## Работа с моделями для изображений")
with gr.Accordion("Обнаружение объектов", open=True):
with gr.Row():
with gr.Column():
obj_det_input = gr.Image(
label="Загрузите изображение",
type="pil"
)
obj_det_model = gr.Dropdown(
choices=[
"facebook/detr-resnet-50",
"hustvl/yolos-tiny"
],
value="facebook/detr-resnet-50",
label="Выберите модель"
)
obj_det_btn = gr.Button("Обнаружить объекты", variant="primary")
with gr.Column():
obj_det_output = gr.Textbox(label="Результат", lines=8)
obj_det_image = gr.Image(label="Изображение с результатами", type="pil")
obj_det_btn.click(
object_detection,
inputs=[obj_det_input, obj_det_model],
outputs=[obj_det_output, obj_det_image]
)
with gr.Accordion("Сегментация изображений", open=False):
with gr.Row():
with gr.Column():
seg_input = gr.Image(
label="Загрузите изображение",
type="pil"
)
seg_model = gr.Dropdown(
choices=[
"facebook/detr-resnet-50-panoptic",
"facebook/maskformer-swin-base-ade"
],
value="facebook/detr-resnet-50-panoptic",
label="Выберите модель"
)
seg_btn = gr.Button("Сегментировать", variant="primary")
with gr.Column():
seg_output = gr.Textbox(label="Результат", lines=8)
seg_image = gr.Image(label="Изображение с результатами", type="pil")
seg_btn.click(
image_segmentation,
inputs=[seg_input, seg_model],
outputs=[seg_output, seg_image]
)
with gr.Accordion("Описание изображений", open=False):
with gr.Row():
with gr.Column():
caption_input = gr.Image(
label="Загрузите изображение",
type="pil"
)
caption_model = gr.Dropdown(
choices=[
"nlpconnect/vit-gpt2-image-captioning",
"Salesforce/blip-image-captioning-base",
"microsoft/git-base"
],
value="nlpconnect/vit-gpt2-image-captioning",
label="Выберите модель"
)
caption_btn = gr.Button("Сгенерировать описание", variant="primary")
with gr.Column():
caption_output = gr.Textbox(label="Описание изображения", lines=3)
caption_btn.click(
image_captioning,
inputs=[caption_input, caption_model],
outputs=caption_output
)
with gr.Accordion("Визуальный вопрос-ответ", open=False):
with gr.Row():
with gr.Column():
vqa_image = gr.Image(
label="Загрузите изображение",
type="pil"
)
vqa_question = gr.Textbox(
label="Вопрос",
placeholder="What is in the image?",
value="What is in the image?"
)
vqa_model = gr.Dropdown(
choices=[
"dandelin/vilt-b32-finetuned-vqa",
"Salesforce/blip-vqa-base"
],
value="dandelin/vilt-b32-finetuned-vqa",
label="Выберите модель"
)
vqa_btn = gr.Button("Ответить", variant="primary")
with gr.Column():
vqa_output = gr.Textbox(label="Ответ", lines=3)
vqa_btn.click(
visual_qa,
inputs=[vqa_image, vqa_question, vqa_model],
outputs=vqa_output
)
with gr.Accordion("Zero-shot классификация", open=False):
with gr.Row():
with gr.Column():
zs_image_input = gr.Image(
label="Загрузите изображение",
type="pil"
)
zs_image_labels = gr.Textbox(
label="Кандидаты (через запятую)",
placeholder="cat, dog, bird",
value="cat, dog, bird"
)
zs_image_model = gr.Dropdown(
choices=[
"openai/clip-vit-base-patch32",
"laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k"
],
value="openai/clip-vit-base-patch32",
label="Выберите модель"
)
zs_image_btn = gr.Button("Классифицировать", variant="primary")
with gr.Column():
zs_image_output = gr.Textbox(label="Результат", lines=6)
zs_image_btn.click(
image_zero_shot_classification,
inputs=[zs_image_input, zs_image_labels, zs_image_model],
outputs=zs_image_output
)
# ========== ВКЛАДКА 4: ИСТОРИЯ ==========
with gr.Tab("📜 История"):
gr.Markdown("## История выполнения моделей")
gr.Markdown("Здесь отображается история всех выполненных операций с моделями. История автоматически обновляется при каждом вызове.")
with gr.Row():
history_clear_btn = gr.Button("Очистить историю", variant="stop")
history_refresh_btn = gr.Button("Обновить", variant="secondary")
history_table = gr.Dataframe(
label="История выполнения",
headers=["Задача", "Модель", "Время выполнения", "Дата/Время", "Входные данные", "Результат"],
value=get_history_display(),
interactive=False,
wrap=True
)
history_json = gr.JSON(
label="Полная история (JSON)",
value=history if history else []
)
def update_history_display():
"""Обновляет отображение истории"""
display_data = get_history_display()
json_data = history if history else []
return display_data, json_data
history_refresh_btn.click(
fn=update_history_display,
outputs=[history_table, history_json]
)
history_clear_btn.click(
fn=clear_history,
outputs=[history_table, history_json]
)
if __name__ == "__main__":
demo.launch(ssr_mode=False)