File size: 3,772 Bytes
2b6f383
dc388a7
 
 
44106b5
dc388a7
023664a
 
f988f4f
2b6f383
44106b5
023664a
 
 
 
 
 
 
 
 
 
 
 
 
 
44106b5
2b6f383
1ebafa7
 
73501d2
1ebafa7
20736ba
cde2cb5
44106b5
023664a
009ae89
44106b5
dc388a7
f2aaed6
2b6f383
44106b5
2b6f383
023664a
44106b5
023664a
 
44106b5
 
 
 
 
 
 
 
 
 
 
009ae89
023664a
1ebafa7
44106b5
1ebafa7
023664a
009ae89
1ebafa7
009ae89
 
 
 
da0f0d3
023664a
 
 
009ae89
023664a
009ae89
023664a
 
 
73501d2
44106b5
 
73501d2
44106b5
36299ab
009ae89
023664a
 
 
36299ab
2b6f383
009ae89
2b6f383
009ae89
44106b5
023664a
44106b5
009ae89
44106b5
 
 
da0f0d3
023664a
 
dea6353
009ae89
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import gradio as gr
import torch
import re
import random
import requests
from transformers import AutoModelForCausalLM, AutoTokenizer
from torchvision import models, transforms
from PIL import Image

# ==========================================
# 1. VISION MODEL LOADING (~20MB)
# ==========================================
# SqueezeNet 1.1 pretrained on ImageNet, frozen in eval mode for inference.
vision_model = models.squeezenet1_1(weights=models.SqueezeNet1_1_Weights.IMAGENET1K_V1).eval()

# Human-readable ImageNet class names, one per line, index-aligned with the
# model's 1000 output logits.
LABELS_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
# timeout added: requests.get() without one can hang module import forever.
labels = requests.get(LABELS_URL, timeout=10).text.splitlines()

# Standard ImageNet eval preprocessing: resize -> center crop 224 ->
# tensor -> per-channel normalization with the ImageNet mean/std.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# ==========================================
# 2. LANGUAGE MODEL (340MB limit)
# ==========================================
# Weights are loaded from the current directory; only the tokenizer is
# fetched from the hub (ruGPT-3 small, a GPT-2-family Russian model).
MODEL_PATH = "./" 
TOKENIZER_NAME = "sberbank-ai/rugpt3small_based_on_gpt2"

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
# NOTE(review): `dtype=` vs the older `torch_dtype=` kwarg is
# transformers-version-dependent — confirm against the pinned version.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    dtype=torch.float32, 
    tie_word_embeddings=False
).to("cpu")

# NOTE(review): this only mutates the config value; it does not shrink the
# learned position-embedding matrix. predict() enforces the 128-token budget
# itself, so verify this line is actually needed.
model.config.max_position_embeddings = 128

# ==========================================
# 3. LOGIC
# ==========================================
def predict(image, message, history):
    """Answer *message*, optionally grounded in what the vision model sees.

    Args:
        image: HxWx3 numpy array from gr.Image (assumed uint8-compatible —
            TODO confirm gr.Image type), or None when no image is supplied.
        message: user text.
        history: chat transcript as a list of {"role", "content"} dicts
            (Gradio "messages" format); may be None before the first turn.

    Returns:
        The updated history list with the new user/assistant turns appended.
    """
    # Gradio can hand over None before the first turn — guard it.
    if history is None:
        history = []

    vision_info = "ничего не вижу"

    if image is not None:
        try:
            pil_img = Image.fromarray(image.astype('uint8'), 'RGB')
            input_tensor = preprocess(pil_img).unsqueeze(0)
            with torch.no_grad():
                output = vision_model(input_tensor)
            _, index = torch.max(output, 1)
            detected = labels[index[0]].replace("_", " ")
            vision_info = f"вижу {detected}"
        except Exception:
            # Best-effort: a vision failure must never kill the chat turn.
            vision_info = "туман"

    # Build the prompt in the User/Bot transcript format the LM expects.
    prompt = f"User: ({vision_info}) {message}\nBot:"

    inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
    curr_len = inputs.input_ids.shape[1]

    # Stay within the 128-token position budget (prompt + reply + 1 spare).
    max_to_gen = 128 - curr_len - 1
    if max_to_gen <= 5:
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": "Слишком длинно, не влезаю в 128!"})
        return history

    # GPT-2-family tokenizers may have no pad token; fall back to EOS so
    # generate() doesn't warn or fail. Explicit `is None` check: a pad id
    # of 0 is valid and must not be overridden.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    with torch.no_grad():
        output_tokens = model.generate(
            **inputs,
            max_new_tokens=max_to_gen,
            do_sample=True,
            temperature=0.35,
            repetition_penalty=1.8,
            pad_token_id=pad_id
        )

    # Keep only the newly generated tail, then cut at the first speaker
    # change or newline so the bot can't impersonate the user.
    answer = tokenizer.decode(output_tokens[0][curr_len:], skip_special_tokens=True).strip()
    answer = re.split(r'User:|Bot:|\n', answer)[0].strip()

    if not answer:
        answer = "..."

    # Gradio "messages" format: return the updated message list.
    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": answer})
    return history

# ==========================================
# 4. INTERFACE (GRADIO 6.0)
# ==========================================
# Component creation order determines the on-screen layout.
with gr.Blocks() as demo:
    gr.Markdown("# 🍌 BananaVision Lite")
    with gr.Row():
        img_input = gr.Image(label="Глаза")
        chatbot = gr.Chatbot(label="Чат") # no explicit type= — predict() appends role/content dicts, so this relies on "messages" being the default; TODO confirm for the installed Gradio version
    
    msg = gr.Textbox(placeholder="Чё там на картинке?")
    btn = gr.Button("Спросить")

    # Button click and Enter in the textbox both route through predict(),
    # with the chatbot passed in as history and receiving the updated list.
    btn.click(predict, [img_input, msg, chatbot], [chatbot])
    msg.submit(predict, [img_input, msg, chatbot], [chatbot])

# NOTE(review): theme is passed to launch() rather than gr.Blocks() — the
# header claims Gradio 6.0 behavior; verify against the installed version.
demo.launch(theme=gr.themes.Default(primary_hue="yellow"))