DINGDINGBELLS commited on
Commit
009ae89
·
verified ·
1 Parent(s): 44106b5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -15
app.py CHANGED
@@ -10,7 +10,6 @@ from PIL import Image
10
  # ==========================================
11
  # 1. ЗАГРУЗКА ЗРЕНИЯ (~20MB)
12
  # ==========================================
13
- print("--- Загрузка SqueezeNet ---")
14
  vision_model = models.squeezenet1_1(weights=models.SqueezeNet1_1_Weights.IMAGENET1K_V1).eval()
15
 
16
  LABELS_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
@@ -29,15 +28,12 @@ preprocess = transforms.Compose([
29
  MODEL_PATH = "./"
30
  TOKENIZER_NAME = "sberbank-ai/rugpt3small_based_on_gpt2"
31
 
32
- print("--- Загрузка твоей модели ---")
33
  tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
34
-
35
- # Убираем device_map, чтобы не требовать accelerate, и фиксим dtype
36
  model = AutoModelForCausalLM.from_pretrained(
37
  MODEL_PATH,
38
  dtype=torch.float32,
39
  tie_word_embeddings=False
40
- ).to("cpu") # Явно отправляем на CPU
41
 
42
  model.config.max_position_embeddings = 128
43
 
@@ -49,7 +45,6 @@ def predict(image, message, history):
49
 
50
  if image is not None:
51
  try:
52
- # Gradio может давать массив numpy, переводим в PIL
53
  pil_img = Image.fromarray(image.astype('uint8'), 'RGB')
54
  input_tensor = preprocess(pil_img).unsqueeze(0)
55
  with torch.no_grad():
@@ -60,22 +55,25 @@ def predict(image, message, history):
60
  except Exception:
61
  vision_info = "туман"
62
 
63
- # Промпт под твою структуру
64
  prompt = f"User: ({vision_info}) {message}\nBot:"
65
 
66
  inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
67
  curr_len = inputs.input_ids.shape[1]
68
 
 
69
  max_to_gen = 128 - curr_len - 1
70
- if max_to_gen <= 2:
71
- return history + [{"role": "assistant", "content": "Слишком много инфы, я запутался!"}]
 
 
72
 
73
  with torch.no_grad():
74
  output_tokens = model.generate(
75
  **inputs,
76
- max_new_tokens=max_new_tokens,
77
  do_sample=True,
78
- temperature=0.25,
79
  repetition_penalty=1.8,
80
  pad_token_id=tokenizer.pad_token_id
81
  )
@@ -85,18 +83,19 @@ def predict(image, message, history):
85
 
86
  if not answer: answer = "..."
87
 
 
88
  history.append({"role": "user", "content": message})
89
  history.append({"role": "assistant", "content": answer})
90
  return history
91
 
92
  # ==========================================
93
- # 4. ИНТЕРФЕЙС
94
  # ==========================================
95
- with gr.Blocks(theme=gr.themes.Default(primary_hue="yellow")) as demo:
96
  gr.Markdown("# 🍌 BananaVision Lite")
97
  with gr.Row():
98
  img_input = gr.Image(label="Глаза")
99
- chatbot = gr.Chatbot(type="messages", label="Чат")
100
 
101
  msg = gr.Textbox(placeholder="Чё там на картинке?")
102
  btn = gr.Button("Спросить")
@@ -104,4 +103,5 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue="yellow")) as demo:
104
  btn.click(predict, [img_input, msg, chatbot], [chatbot])
105
  msg.submit(predict, [img_input, msg, chatbot], [chatbot])
106
 
107
- demo.launch()
 
 
10
  # ==========================================
11
  # 1. ЗАГРУЗКА ЗРЕНИЯ (~20MB)
12
  # ==========================================
 
13
  vision_model = models.squeezenet1_1(weights=models.SqueezeNet1_1_Weights.IMAGENET1K_V1).eval()
14
 
15
  LABELS_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
 
28
  MODEL_PATH = "./"
29
  TOKENIZER_NAME = "sberbank-ai/rugpt3small_based_on_gpt2"
30
 
 
31
  tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
 
 
32
  model = AutoModelForCausalLM.from_pretrained(
33
  MODEL_PATH,
34
  dtype=torch.float32,
35
  tie_word_embeddings=False
36
+ ).to("cpu")
37
 
38
  model.config.max_position_embeddings = 128
39
 
 
45
 
46
  if image is not None:
47
  try:
 
48
  pil_img = Image.fromarray(image.astype('uint8'), 'RGB')
49
  input_tensor = preprocess(pil_img).unsqueeze(0)
50
  with torch.no_grad():
 
55
  except Exception:
56
  vision_info = "туман"
57
 
58
+ # Собираем промпт
59
  prompt = f"User: ({vision_info}) {message}\nBot:"
60
 
61
  inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
62
  curr_len = inputs.input_ids.shape[1]
63
 
64
+ # Лимит до 128 токенов
65
  max_to_gen = 128 - curr_len - 1
66
+ if max_to_gen <= 5:
67
+ history.append({"role": "user", "content": message})
68
+ history.append({"role": "assistant", "content": "Слишком длинно, не влезаю в 128!"})
69
+ return history
70
 
71
  with torch.no_grad():
72
  output_tokens = model.generate(
73
  **inputs,
74
+ max_new_tokens=max_to_gen,
75
  do_sample=True,
76
+ temperature=0.35,
77
  repetition_penalty=1.8,
78
  pad_token_id=tokenizer.pad_token_id
79
  )
 
83
 
84
  if not answer: answer = "..."
85
 
86
+ # В Gradio 6.0 возвращаем обновленный список сообщений
87
  history.append({"role": "user", "content": message})
88
  history.append({"role": "assistant", "content": answer})
89
  return history
90
 
91
  # ==========================================
92
+ # 4. ИНТЕРФЕЙС (GRADIO 6.0)
93
  # ==========================================
94
+ with gr.Blocks() as demo:
95
  gr.Markdown("# 🍌 BananaVision Lite")
96
  with gr.Row():
97
  img_input = gr.Image(label="Глаза")
98
+ chatbot = gr.Chatbot(label="Чат") # БЕЗ type="messages"
99
 
100
  msg = gr.Textbox(placeholder="Чё там на картинке?")
101
  btn = gr.Button("Спросить")
 
103
  btn.click(predict, [img_input, msg, chatbot], [chatbot])
104
  msg.submit(predict, [img_input, msg, chatbot], [chatbot])
105
 
106
+ # Тема передается здесь
107
+ demo.launch(theme=gr.themes.Default(primary_hue="yellow"))