DINGDINGBELLS committed on
Commit
44106b5
·
verified ·
1 Parent(s): 828a8c2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -47
app.py CHANGED
@@ -2,22 +2,20 @@ import gradio as gr
2
  import torch
3
  import re
4
  import random
 
5
  from transformers import AutoModelForCausalLM, AutoTokenizer
6
  from torchvision import models, transforms
7
  from PIL import Image
8
- import requests
9
 
10
  # ==========================================
11
- # 1. ЗАГРУЗКА ЗРЕНИЯ (ImageNet Classifier)
12
  # ==========================================
13
- # SqueezeNet — весит копейки, работает быстро
14
  vision_model = models.squeezenet1_1(weights=models.SqueezeNet1_1_Weights.IMAGENET1K_V1).eval()
15
 
16
- # Подгружаем названия категорий
17
  LABELS_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
18
  labels = requests.get(LABELS_URL).text.splitlines()
19
 
20
- # Подготовка картинки
21
  preprocess = transforms.Compose([
22
  transforms.Resize(256),
23
  transforms.CenterCrop(224),
@@ -26,87 +24,82 @@ preprocess = transforms.Compose([
26
  ])
27
 
28
  # ==========================================
29
- # 2. ТВОИ МОЗГИ (200M Model)
30
  # ==========================================
31
  MODEL_PATH = "./"
32
  TOKENIZER_NAME = "sberbank-ai/rugpt3small_based_on_gpt2"
33
 
 
34
  tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
 
 
35
  model = AutoModelForCausalLM.from_pretrained(
36
  MODEL_PATH,
37
- torch_dtype=torch.float32,
38
- device_map="cpu",
39
  tie_word_embeddings=False
40
- )
 
41
  model.config.max_position_embeddings = 128
42
 
43
  # ==========================================
44
- # 3. ФУНКЦИЯ ПРЕДСКАЗАНИЯ
45
  # ==========================================
46
  def predict(image, message, history):
47
- vision_info = ""
48
 
49
- # Если закинули картинку — распознаем
50
  if image is not None:
51
- pil_img = Image.fromarray(image.astype('uint8'), 'RGB')
52
- input_tensor = preprocess(pil_img).unsqueeze(0)
53
- with torch.no_grad():
54
- output = vision_model(input_tensor)
55
-
56
- # Берем самый вероятный объект
57
- _, index = torch.max(output, 1)
58
- detected = labels[index[0]].replace("_", " ")
59
- vision_info = f"Ты видишь перед собой: {detected}."
60
-
61
- # Собираем промпт. Впихиваем зрение в начало, чтобы модель "прозрела"
62
- # Формат: User: (Вижу: банан) Чё это? \n Bot:
 
63
  prompt = f"User: ({vision_info}) {message}\nBot:"
64
 
65
- inputs = tokenizer(prompt, return_tensors="pt")
66
  curr_len = inputs.input_ids.shape[1]
67
 
68
- # Лимит 128
69
  max_to_gen = 128 - curr_len - 1
70
  if max_to_gen <= 2:
71
- return history + [{"role": "user", "content": message}, {"role": "assistant", "content": "Память забита!"}]
72
 
73
  with torch.no_grad():
74
  output_tokens = model.generate(
75
  **inputs,
76
- max_new_tokens=max_to_gen,
77
  do_sample=True,
78
- temperature=0.35, # Твоя ТЕМПЕРАТУРА
79
  repetition_penalty=1.8,
80
  pad_token_id=tokenizer.pad_token_id
81
  )
82
 
83
- raw_answer = tokenizer.decode(output_tokens[0][curr_len:], skip_special_tokens=True).strip()
84
- answer = re.split(r'User:|Bot:|\n', raw_answer)[0].strip()
85
 
86
- if not answer: answer = "Ясно."
87
 
88
- # Формат Gradio 6.0
89
  history.append({"role": "user", "content": message})
90
  history.append({"role": "assistant", "content": answer})
91
  return history
92
 
93
  # ==========================================
94
- # 4. ИНТЕРФЕЙС (DARK-YELLOW STYLE)
95
  # ==========================================
96
- with gr.Blocks(theme=gr.themes.Default(primary_hue="yellow", secondary_hue="neutral").set(
97
- body_background_fill="#000000",
98
- block_background_fill="#111111",
99
- input_background_fill="#222222"
100
- )) as demo:
101
- gr.Markdown("# 🍌 **BananaVision Lite** (Limit: 340MB)")
102
-
103
  with gr.Row():
104
- with gr.Column(scale=1):
105
- img_input = gr.Image(label="Глаза бота (Camera/Upload)")
106
- with gr.Column(scale=2):
107
- chatbot = gr.Chatbot(type="messages", label="Чат с Хамом")
108
- msg = gr.Textbox(placeholder="Спроси чё-нибудь про картинку...")
109
- btn = gr.Button("Отправить", variant="primary")
110
 
111
  btn.click(predict, [img_input, msg, chatbot], [chatbot])
112
  msg.submit(predict, [img_input, msg, chatbot], [chatbot])
 
2
  import torch
3
  import re
4
  import random
5
+ import requests
6
  from transformers import AutoModelForCausalLM, AutoTokenizer
7
  from torchvision import models, transforms
8
  from PIL import Image
 
9
 
# ==========================================
# 1. VISION MODEL (~20MB)
# ==========================================
# SqueezeNet 1.1 with the ImageNet-1k V1 weights, put into eval mode because
# it is only used for inference.
print("--- Загрузка SqueezeNet ---")
vision_model = models.squeezenet1_1(weights=models.SqueezeNet1_1_Weights.IMAGENET1K_V1).eval()

# Class names for the classifier output; the downloaded text is split into
# one entry per line and indexed by the predicted class id.
LABELS_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
# The timeout keeps startup from hanging forever on a stalled connection, and
# raise_for_status() prevents an HTTP error page from being parsed as labels.
_labels_response = requests.get(LABELS_URL, timeout=30)
_labels_response.raise_for_status()
labels = _labels_response.text.splitlines()
18
 
 
19
  preprocess = transforms.Compose([
20
  transforms.Resize(256),
21
  transforms.CenterCrop(224),
 
24
  ])
25
 
# ==========================================
# 2. YOUR BRAINS: the language model (340MB limit)
# ==========================================
MODEL_PATH = "./"
TOKENIZER_NAME = "sberbank-ai/rugpt3small_based_on_gpt2"

print("--- Загрузка твоей модели ---")
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)

# device_map is deliberately not passed so that accelerate is not required;
# the dtype is pinned explicitly and the model is moved to the CPU by hand.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    dtype=torch.float32,
    tie_word_embeddings=False,
)
model = model.to("cpu")

# Record the 128-token context limit on the loaded config.
model.config.max_position_embeddings = 128
43
 
# ==========================================
# 3. LOGIC
# ==========================================
def predict(image, message, history):
    """Answer ``message``, optionally grounded in what the classifier sees.

    The top ImageNet label for ``image`` (if any) is folded into a
    ``User: (...) ...\\nBot:`` prompt, the language model samples a reply,
    and the exchange is appended to the chat history.

    Args:
        image: Picture from the image input, or ``None`` when none was
            supplied. It is converted with ``.astype``, so a numpy array is
            expected (TODO confirm against the ``gr.Image`` ``type`` setting).
        message: The user's text.
        history: Chat history as a list of ``{"role": ..., "content": ...}``
            dicts; ``None`` is treated as an empty history.

    Returns:
        The history with the user message and the assistant reply appended.
        When the prompt leaves no room to generate, a new list is returned
        with the user message and a fixed "too much info" reply instead.
    """
    if history is None:
        history = []

    context_limit = 128  # cap on prompt tokens + generated tokens
    vision_info = "ничего не вижу"

    if image is not None:
        try:
            # convert("RGB") makes grayscale or RGBA arrays end up with the
            # three channels the preprocessing pipeline is fed.
            pil_img = Image.fromarray(image.astype("uint8")).convert("RGB")
            input_tensor = preprocess(pil_img).unsqueeze(0)
            with torch.no_grad():
                output = vision_model(input_tensor)
            best_class = int(output.argmax(dim=1)[0])
            detected = labels[best_class].replace("_", " ")
            vision_info = f"вижу {detected}"
        except Exception as exc:
            # Vision is best effort: keep chatting, but leave a trace in the
            # logs instead of hiding the failure completely.
            print(f"Vision step failed: {exc!r}")
            vision_info = "туман"

    # Prompt layout expected by the model: the vision hint goes in parentheses
    # right before the user's text.
    prompt = f"User: ({vision_info}) {message}\nBot:"

    inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
    curr_len = inputs.input_ids.shape[1]

    # Whatever the prompt does not use is the generation budget; one position
    # is left spare.
    max_to_gen = context_limit - curr_len - 1
    if max_to_gen <= 2:
        # Keep the user's message in the chat so the refusal has context.
        return history + [
            {"role": "user", "content": message},
            {"role": "assistant", "content": "Слишком много инфы, я запутался!"},
        ]

    with torch.no_grad():
        output_tokens = model.generate(
            **inputs,
            # Was `max_new_tokens=max_new_tokens`, an undefined name that
            # raised NameError on every call.
            max_new_tokens=max_to_gen,
            do_sample=True,
            temperature=0.25,
            repetition_penalty=1.8,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Decode only the newly generated tokens, then cut the reply at the first
    # turn marker or newline so the bot does not continue the dialogue alone.
    answer = tokenizer.decode(output_tokens[0][curr_len:], skip_special_tokens=True).strip()
    answer = re.split(r'User:|Bot:|\n', answer)[0].strip()

    if not answer:
        answer = "..."

    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": answer})
    return history
91
 
92
  # ==========================================
93
+ # 4. ИНТЕРФЕЙС
94
  # ==========================================
95
+ with gr.Blocks(theme=gr.themes.Default(primary_hue="yellow")) as demo:
96
+ gr.Markdown("# 🍌 BananaVision Lite")
 
 
 
 
 
97
  with gr.Row():
98
+ img_input = gr.Image(label="Глаза")
99
+ chatbot = gr.Chatbot(type="messages", label="Чат")
100
+
101
+ msg = gr.Textbox(placeholder="Чё там на картинке?")
102
+ btn = gr.Button("Спросить")
 
103
 
104
  btn.click(predict, [img_input, msg, chatbot], [chatbot])
105
  msg.submit(predict, [img_input, msg, chatbot], [chatbot])