import gradio as gr
import torch
import re
import requests
from transformers import AutoModelForCausalLM, AutoTokenizer
from torchvision import models, transforms
from PIL import Image

# ==========================================
# 1. VISION LOADING (~20MB)
# ==========================================
vision_model = models.squeezenet1_1(weights=models.SqueezeNet1_1_Weights.IMAGENET1K_V1).eval()
LABELS_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
labels = requests.get(LABELS_URL).text.splitlines()

# Standard ImageNet preprocessing for SqueezeNet
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# ==========================================
# 2. YOUR BRAINS (340MB limit)
# ==========================================
MODEL_PATH = "./"
TOKENIZER_NAME = "sberbank-ai/rugpt3small_based_on_gpt2"

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    dtype=torch.float32,
    tie_word_embeddings=False,
).to("cpu")
model.config.max_position_embeddings = 128

# ==========================================
# 3. LOGIC
# ==========================================
def predict(image, message, history):
    vision_info = "ничего не вижу"  # "I see nothing"
    if image is not None:
        try:
            pil_img = Image.fromarray(image.astype("uint8"), "RGB")
            input_tensor = preprocess(pil_img).unsqueeze(0)
            with torch.no_grad():
                output = vision_model(input_tensor)
            _, index = torch.max(output, 1)
            detected = labels[index[0]].replace("_", " ")
            vision_info = f"вижу {detected}"  # "I see <detected label>"
        except Exception:
            vision_info = "туман"  # "fog"

    # Build the prompt
    prompt = f"User: ({vision_info}) {message}\nBot:"
    inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
    curr_len = inputs.input_ids.shape[1]

    # Keep the whole sequence within the 128-token limit
    max_to_gen = 128 - curr_len - 1
    if max_to_gen <= 5:
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": "Слишком длинно, не влезаю в 128!"})  # "Too long, I can't fit into 128!"
        return history

    with torch.no_grad():
        output_tokens = model.generate(
            **inputs,
            max_new_tokens=max_to_gen,
            do_sample=True,
            temperature=0.35,
            repetition_penalty=1.8,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Decode only the newly generated tokens and cut off at the next turn marker
    answer = tokenizer.decode(output_tokens[0][curr_len:], skip_special_tokens=True).strip()
    answer = re.split(r"User:|Bot:|\n", answer)[0].strip()
    if not answer:
        answer = "..."

    # In Gradio 6.0 we return the updated list of messages
    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": answer})
    return history

# ==========================================
# 4. INTERFACE (GRADIO 6.0)
# ==========================================
with gr.Blocks() as demo:
    gr.Markdown("# 🍌 BananaVision Lite")
    with gr.Row():
        img_input = gr.Image(label="Глаза")  # "Eyes"
        chatbot = gr.Chatbot(label="Чат")  # "Chat"; no explicit type="messages"
    msg = gr.Textbox(placeholder="Чё там на картинке?")  # "So what's in the picture?"
    btn = gr.Button("Спросить")  # "Ask"

    btn.click(predict, [img_input, msg, chatbot], [chatbot])
    msg.submit(predict, [img_input, msg, chatbot], [chatbot])

# The theme is passed here
demo.launch(theme=gr.themes.Default(primary_hue="yellow"))