DINGDINGBELLS commited on
Commit
f988f4f
·
verified ·
1 Parent(s): a99daad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -13
app.py CHANGED
@@ -2,36 +2,46 @@ import gradio as gr
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
4
  from threading import Thread
 
 
 
 
5
 
6
  MODEL_ID = "."
7
 
 
 
 
8
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 
 
9
  model = AutoModelForCausalLM.from_pretrained(
10
  MODEL_ID,
11
- torch_dtype=torch.float32,
12
  low_cpu_mem_usage=True,
13
- device_map="cpu"
 
14
  )
15
 
16
  def predict(message, history):
17
- # НИКАКОГО системного промпта в начале.
18
- # Сразу начинаем со структуры диалога.
19
- prompt = ""
 
20
  for msg in history:
21
- # Просто переносим роли, которые понимает модель (User/Bot)
22
  role = "User" if msg["role"] == "user" else "Bot"
23
  prompt += f"{role}: {msg['content']}\n"
24
 
25
- # Добавляем текущий ввод
26
  prompt += f"User: {message}\nBot:"
27
 
28
- inputs = tokenizer(prompt, return_tensors="pt")
29
- streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
30
 
31
  generate_kwargs = dict(
32
  **inputs,
33
  streamer=streamer,
34
- max_new_tokens=100,
35
  do_sample=True,
36
  temperature=0.7,
37
  repetition_penalty=1.2,
@@ -43,13 +53,18 @@ def predict(message, history):
43
 
44
  partial_message = ""
45
  for new_token in streamer:
46
- # Если модель в порыве шизы начнет писать за "User:", обрезаем
47
  if "User:" in new_token:
48
  break
49
  partial_message += new_token
50
  yield partial_message
51
 
52
- demo = gr.ChatInterface(predict, type="messages")
 
 
 
 
 
53
 
54
  if __name__ == "__main__":
55
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
4
  from threading import Thread
5
+ import os
6
+
7
+ # Папка для экстренного сброса весов, если RAM все равно будет не хватать
8
+ os.makedirs("offload", exist_ok=True)
9
 
10
  MODEL_ID = "."
11
 
12
+ print("🍌 BananaGPT: Загрузка в float16...")
13
+
14
+ # Загружаем токенизатор
15
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
16
+
17
+ # Загружаем модель: float16 режет потребление памяти в 2 раза
18
  model = AutoModelForCausalLM.from_pretrained(
19
  MODEL_ID,
20
+ torch_dtype=torch.float16,
21
  low_cpu_mem_usage=True,
22
+ device_map="auto",
23
+ offload_folder="offload"
24
  )
25
 
26
  def predict(message, history):
27
+ # ПУСТОЙ промпт (никаких системных инструкций, как ты и просил)
28
+ prompt = ""
29
+
30
+ # СТРУКТУРА: переносим историю диалога
31
  for msg in history:
 
32
  role = "User" if msg["role"] == "user" else "Bot"
33
  prompt += f"{role}: {msg['content']}\n"
34
 
35
+ # ЗАПРОС: добавляем текущее сообщение
36
  prompt += f"User: {message}\nBot:"
37
 
38
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
39
+ streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
40
 
41
  generate_kwargs = dict(
42
  **inputs,
43
  streamer=streamer,
44
+ max_new_tokens=128,
45
  do_sample=True,
46
  temperature=0.7,
47
  repetition_penalty=1.2,
 
53
 
54
  partial_message = ""
55
  for new_token in streamer:
 
56
  if "User:" in new_token:
57
  break
58
  partial_message += new_token
59
  yield partial_message
60
 
61
+ # Интерфейс
62
+ demo = gr.ChatInterface(
63
+ predict,
64
+ type="messages",
65
+ title="BananaGPT (float16)"
66
+ )
67
 
68
  if __name__ == "__main__":
69
+ # Запуск на порту 7860 для HF Spaces
70
+ demo.queue().launch(server_name="0.0.0.0", server_port=7860)