# (HuggingFace Spaces page residue from the paste: "Spaces: Running Running")
# Standard library
import random
import re

# Third-party
import gradio as gr
import requests
import torch
from PIL import Image
from torchvision import models, transforms
from transformers import AutoModelForCausalLM, AutoTokenizer
# ==========================================
# 1. VISION MODEL (~20MB)
# ==========================================
# SqueezeNet 1.1 pretrained on ImageNet, inference-only (eval mode).
vision_model = models.squeezenet1_1(weights=models.SqueezeNet1_1_Weights.IMAGENET1K_V1).eval()

# Human-readable names for the 1000 ImageNet classes, one per line.
LABELS_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
# timeout so a slow/unreachable host can't hang the app at import time
labels = requests.get(LABELS_URL, timeout=30).text.splitlines()

# Standard ImageNet preprocessing: resize -> center-crop 224 -> normalize.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# ==========================================
# 2. LANGUAGE MODEL (340MB size limit)
# ==========================================
MODEL_PATH = "./"  # fine-tuned weights live alongside this script
TOKENIZER_NAME = "sberbank-ai/rugpt3small_based_on_gpt2"

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
# GPT-2-family tokenizers usually have no pad token; `generate` below needs
# a valid pad_token_id, so fall back to EOS.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    dtype=torch.float32,
    tie_word_embeddings=False
).to("cpu")
# NOTE(review): shrinking max_position_embeddings after load only updates the
# config, not the position-embedding matrix; the 128-token cap is actually
# enforced in predict(). Confirm this line is intentional.
model.config.max_position_embeddings = 128
# ==========================================
# 3. LOGIC
# ==========================================
def predict(image, message, history):
    """Answer `message`, optionally conditioned on what SqueezeNet sees in `image`.

    Args:
        image: numpy HxWx3 RGB array from gr.Image, or None.
        message: user text from the textbox.
        history: chat history as a list of {"role", "content"} dicts
            (may be None/empty on the first call).

    Returns:
        The updated history list (Gradio "messages" format).
    """
    history = history or []  # Gradio can hand us None before the first turn

    # --- Vision: classify the image into a one-word caption ---
    vision_info = "ничего не вижу"
    if image is not None:
        try:
            pil_img = Image.fromarray(image.astype('uint8'), 'RGB')
            input_tensor = preprocess(pil_img).unsqueeze(0)
            with torch.no_grad():
                output = vision_model(input_tensor)
            _, index = torch.max(output, 1)
            detected = labels[index[0]].replace("_", " ")
            vision_info = f"вижу {detected}"
        except Exception:
            # best-effort: any vision failure degrades to "fog", never crashes chat
            vision_info = "туман"

    # --- Build the prompt ---
    prompt = f"User: ({vision_info}) {message}\nBot:"
    inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
    curr_len = inputs.input_ids.shape[1]

    # Cap the total sequence at 128 tokens (prompt + generation + 1 spare).
    max_to_gen = 128 - curr_len - 1
    if max_to_gen <= 5:
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": "Слишком длинно, не влезаю в 128!"})
        return history

    with torch.no_grad():
        output_tokens = model.generate(
            **inputs,
            max_new_tokens=max_to_gen,
            do_sample=True,
            temperature=0.35,
            repetition_penalty=1.8,
            # EOS fallback in case the tokenizer has no pad token configured
            pad_token_id=(tokenizer.pad_token_id
                          if tokenizer.pad_token_id is not None
                          else tokenizer.eos_token_id),
        )

    # Decode only the newly generated tail and cut at the first turn marker.
    answer = tokenizer.decode(output_tokens[0][curr_len:], skip_special_tokens=True).strip()
    answer = re.split(r'User:|Bot:|\n', answer)[0].strip()
    if not answer:
        answer = "..."

    # Return the updated message list (Gradio "messages" chat format).
    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": answer})
    return history
# ==========================================
# 4. INTERFACE (GRADIO 6.0)
# ==========================================
with gr.Blocks() as demo:
    gr.Markdown("# 🍌 BananaVision Lite")
    with gr.Row():
        img_input = gr.Image(label="Глаза")
        # NOTE(review): predict() returns role/content dicts, i.e. the
        # "messages" format — assumes this Gradio version defaults Chatbot
        # to messages mode; confirm, otherwise pass type="messages".
        chatbot = gr.Chatbot(label="Чат")
    msg = gr.Textbox(placeholder="Чё там на картинке?")
    btn = gr.Button("Спросить")

    # Both the button and Enter in the textbox trigger the same handler.
    btn.click(predict, [img_input, msg, chatbot], [chatbot])
    msg.submit(predict, [img_input, msg, chatbot], [chatbot])

# NOTE(review): theme passed to launch() per the original comment — verify
# this Gradio version accepts it there (older versions take theme in gr.Blocks()).
demo.launch(theme=gr.themes.Default(primary_hue="yellow"))