# app.py — clicklezGPT Space (author: DINGDINGBELLS, rev 009ae89 verified)
import gradio as gr
import torch
import re
import random
import requests
from transformers import AutoModelForCausalLM, AutoTokenizer
from torchvision import models, transforms
from PIL import Image
# ==========================================
# 1. VISION MODEL (~20 MB)
# ==========================================
# SqueezeNet 1.1 pretrained on ImageNet, inference-only (eval mode, no grads
# are taken in predict()).
vision_model = models.squeezenet1_1(weights=models.SqueezeNet1_1_Weights.IMAGENET1K_V1).eval()

# Human-readable ImageNet class names, fetched once at startup.
LABELS_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
# timeout + raise_for_status: fail fast at startup instead of hanging forever
# on a dead connection or silently splitting an HTML error page into "labels".
_labels_resp = requests.get(LABELS_URL, timeout=10)
_labels_resp.raise_for_status()
labels = _labels_resp.text.splitlines()

# Standard ImageNet preprocessing (resize -> center crop 224 -> normalize with
# the canonical ImageNet mean/std).
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# ==========================================
# 2. LANGUAGE MODEL ("your brains", 340 MB limit)
# ==========================================
# Weights are loaded from the Space's repo root; only the tokenizer comes
# from the Hub (the original rugpt3small tokenizer).
MODEL_PATH = "./"
TOKENIZER_NAME = "sberbank-ai/rugpt3small_based_on_gpt2"
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
# NOTE(review): the `dtype=` kwarg is only accepted by recent transformers
# releases (older ones use `torch_dtype=`) — confirm the pinned version.
model = AutoModelForCausalLM.from_pretrained(
MODEL_PATH,
dtype=torch.float32,
tie_word_embeddings=False
).to("cpu")
# NOTE(review): this only overwrites the config attribute after loading; it
# presumably will not resize the already-loaded position-embedding matrix.
# predict() separately enforces the 128-token budget — verify this line is
# actually needed.
model.config.max_position_embeddings = 128
# ==========================================
# 3. CHAT LOGIC
# ==========================================
def predict(image, message, history):
    """Answer *message* in the chat, optionally conditioned on *image*.

    Parameters
    ----------
    image : numpy.ndarray | None
        RGB frame from the Gradio image input (assumed H x W x 3 —
        converted to uint8 before PIL wrapping).
    message : str
        The user's text prompt.
    history : list[dict] | None
        Chat history as ``{"role", "content"}`` dicts; extended in place.

    Returns
    -------
    list[dict]
        The updated history including the new user/assistant turns.
    """
    # Defensive: Gradio may hand over None before the first exchange.
    if history is None:
        history = []

    # --- Vision: classify the image with SqueezeNet over ImageNet labels ---
    vision_info = "ничего не вижу"
    if image is not None:
        try:
            pil_img = Image.fromarray(image.astype('uint8'), 'RGB')
            input_tensor = preprocess(pil_img).unsqueeze(0)
            with torch.no_grad():
                output = vision_model(input_tensor)
            _, index = torch.max(output, 1)
            detected = labels[index[0]].replace("_", " ")
            vision_info = f"вижу {detected}"
        except Exception:
            # Best-effort vision: any failure degrades to "fog", never crashes.
            vision_info = "туман"

    # --- Build the prompt and tokenize ---
    prompt = f"User: ({vision_info}) {message}\nBot:"
    inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
    curr_len = inputs.input_ids.shape[1]

    # Keep prompt + generated tokens within the 128-position budget.
    max_to_gen = 128 - curr_len - 1
    if max_to_gen <= 5:
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": "Слишком длинно, не влезаю в 128!"})
        return history

    # GPT-2-family tokenizers usually define no pad token; fall back to EOS so
    # generate() does not receive pad_token_id=None (which triggers warnings
    # or errors depending on the transformers version).
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    with torch.no_grad():
        output_tokens = model.generate(
            **inputs,
            max_new_tokens=max_to_gen,
            do_sample=True,
            temperature=0.35,
            repetition_penalty=1.8,
            pad_token_id=pad_id,
        )

    # Decode only the newly generated tail and cut at the first speaker tag
    # or newline, so the bot cannot impersonate the user.
    answer = tokenizer.decode(output_tokens[0][curr_len:], skip_special_tokens=True).strip()
    answer = re.split(r'User:|Bot:|\n', answer)[0].strip()
    if not answer:
        answer = "..."

    # Gradio "messages"-format chatbot: return the updated message list.
    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": answer})
    return history
# ==========================================
# 4. UI (GRADIO 6.0)
# ==========================================
with gr.Blocks() as demo:
    gr.Markdown("# 🍌 BananaVision Lite")
    with gr.Row():
        img_input = gr.Image(label="Глаза")
        chatbot = gr.Chatbot(label="Чат")  # WITHOUT type="messages" (author relies on it being the default — confirm for the pinned Gradio)
    msg = gr.Textbox(placeholder="Чё там на картинке?")
    btn = gr.Button("Спросить")
    # Button click and Enter in the textbox run the same handler; the chatbot
    # component doubles as the conversation state (passed in and written back).
    btn.click(predict, [img_input, msg, chatbot], [chatbot])
    msg.submit(predict, [img_input, msg, chatbot], [chatbot])
# NOTE(review): the theme is passed to launch() here; in most Gradio versions
# theme= belongs to gr.Blocks() instead — verify against the pinned version.
demo.launch(theme=gr.themes.Default(primary_hue="yellow"))