import gradio as gr
import torch
import re
import random
import requests
from transformers import AutoModelForCausalLM, AutoTokenizer
from torchvision import models, transforms
from PIL import Image
# ==========================================
# 1. VISION MODEL LOADING (~20MB)
# ==========================================
# SqueezeNet 1.1 pretrained on ImageNet; eval() disables dropout/batch-norm updates.
vision_model = models.squeezenet1_1(weights=models.SqueezeNet1_1_Weights.IMAGENET1K_V1).eval()

# Human-readable ImageNet class names, one per line, indexed by class id.
LABELS_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
# timeout: without it requests can block forever and wedge app startup;
# raise_for_status: surface HTTP errors instead of silently splitting an
# error page into bogus "labels".
_labels_resp = requests.get(LABELS_URL, timeout=30)
_labels_resp.raise_for_status()
labels = _labels_resp.text.splitlines()

# Standard ImageNet preprocessing: resize, center-crop to 224x224, normalize.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# ==========================================
# 2. LANGUAGE MODEL (340MB size limit)
# ==========================================
# Fine-tuned weights live alongside this script; only the tokenizer is
# pulled from the Hub.
MODEL_PATH = "./"
TOKENIZER_NAME = "sberbank-ai/rugpt3small_based_on_gpt2"
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    # NOTE(review): `dtype=` is the newer transformers spelling of
    # `torch_dtype=` — confirm the pinned transformers version accepts it.
    dtype=torch.float32,
    # presumably set to match a checkpoint saved with untied embeddings —
    # TODO confirm against the training config.
    tie_word_embeddings=False
).to("cpu")
# NOTE(review): mutating the config after load changes only the recorded
# limit, not the learned position-embedding table — verify this is the
# intended effect (predict() separately enforces the 128-token cap).
model.config.max_position_embeddings = 128
# ==========================================
# 3. ЛОГИКА
# ==========================================
def predict(image, message, history):
    """Answer *message* in the chat, optionally grounded in *image*.

    Args:
        image: RGB numpy array from gr.Image, or None when no image is set.
        message: user text from the textbox.
        history: chat history as a list of {"role", "content"} dicts
            (Gradio "messages" format); may be None on the very first call.

    Returns:
        The updated history list with the new user/assistant turns appended.
    """
    # Guard: Gradio may pass None before the chatbot has any state;
    # the original code crashed on .append() in that case.
    if history is None:
        history = []

    # --- vision: classify the image and verbalize the top-1 label --------
    vision_info = "ничего не вижу"
    if image is not None:
        try:
            pil_img = Image.fromarray(image.astype('uint8'), 'RGB')
            input_tensor = preprocess(pil_img).unsqueeze(0)
            with torch.no_grad():
                output = vision_model(input_tensor)
            _, index = torch.max(output, 1)
            detected = labels[index[0]].replace("_", " ")
            vision_info = f"вижу {detected}"
        except Exception:
            # Deliberate best-effort: any vision failure degrades the
            # description instead of crashing the chat turn.
            vision_info = "туман"

    # --- build the prompt and tokenize ------------------------------------
    prompt = f"User: ({vision_info}) {message}\nBot:"
    inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
    curr_len = inputs.input_ids.shape[1]

    # Hard cap of 128 tokens total (prompt + generated continuation).
    max_to_gen = 128 - curr_len - 1
    if max_to_gen <= 5:
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": "Слишком длинно, не влезаю в 128!"})
        return history

    # GPT-2-family tokenizers usually define no pad token; passing None
    # triggers a warning/fallback inside generate, so fall back to EOS here.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    with torch.no_grad():
        output_tokens = model.generate(
            **inputs,
            max_new_tokens=max_to_gen,
            do_sample=True,
            temperature=0.35,
            repetition_penalty=1.8,
            pad_token_id=pad_id
        )

    # Decode only the continuation, then keep the first bot line and drop
    # any hallucinated extra "User:/Bot:" turns.
    answer = tokenizer.decode(output_tokens[0][curr_len:], skip_special_tokens=True).strip()
    answer = re.split(r'User:|Bot:|\n', answer)[0].strip()
    if not answer:
        answer = "..."

    # Gradio 6.0 message format: return the updated list of message dicts.
    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": answer})
    return history
# ==========================================
# 4. INTERFACE (GRADIO 6.0)
# ==========================================
with gr.Blocks() as demo:
    gr.Markdown("# 🍌 BananaVision Lite")

    with gr.Row():
        img_input = gr.Image(label="Глаза")
        chatbot = gr.Chatbot(label="Чат")  # no type="messages" (6.0 default)

    msg = gr.Textbox(placeholder="Чё там на картинке?")
    btn = gr.Button("Спросить")

    # Button click and textbox submit share the same handler wiring.
    for event in (btn.click, msg.submit):
        event(predict, [img_input, msg, chatbot], [chatbot])

# NOTE(review): theme is passed to launch() here — confirm the installed
# Gradio version accepts it there rather than in the Blocks constructor.
demo.launch(theme=gr.themes.Default(primary_hue="yellow"))