aciang committed
Commit 81dfb3b · verified · 1 Parent(s): 4861760

fix: remove concurrency_count; add CPU fallback + cache

Files changed (1): app.py (+44 -71)
app.py CHANGED
@@ -1,104 +1,77 @@
 
 import os, torch, gradio as gr
-from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, BitsAndBytesConfig
-
-TITLE = "LanguageBridge — Multimodal Chatbot (Mistral-7B)"
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TextIteratorStreamer
+os.environ.setdefault("HF_HOME", "/data/.cache")  # HF cache location (see note below on import order)
+os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")  # only effective if hf_transfer is installed
 MODEL_ID = os.getenv("MODEL_ID", "aciang/mistral7b-tk-sft-20251019-merged")
-
-SYSTEM_PROMPT = (
-    "你是『語言橋』學習助教。規則:"
-    "1) 嚴謹、分段、先重點後細節;"
-    "2) 若為數學/規則題:先列步驟,再給最終答案;"
-    "3) 若資訊不足,請明確指出缺口,勿捏造;"
-    "4) 優先以繁體中文回答。"
-)
+TITLE = "LanguageBridge — Multimodal Chatbot (Mistral-7B)"
+SYSTEM_PROMPT = "你是教學助教。先讀【任務】,按【格式】作答;資料不足先列缺口,勿猜測。"  # zh-TW; currently unused: build_prompt embeds its own header
 
 def load_llm():
-    # prefer 4-bit; fall back to fp16/CPU on failure
-    kwargs = dict(trust_remote_code=False)
-    bnb = BitsAndBytesConfig(
-        load_in_4bit=True, bnb_4bit_quant_type="nf4",
-        bnb_4bit_use_double_quant=True,
-        bnb_4bit_compute_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16
-    )
+    has_cuda = torch.cuda.is_available()
+    kwargs = dict(trust_remote_code=False, low_cpu_mem_usage=True)
     try:
-        model = AutoModelForCausalLM.from_pretrained(MODEL_ID, quantization_config=bnb, device_map="auto", **kwargs)
+        if has_cuda:  # 4-bit NF4 quantization on GPU
+            bnb = BitsAndBytesConfig(
+                load_in_4bit=True, bnb_4bit_use_double_quant=True,
+                bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
+            )
+            model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto", quantization_config=bnb, **kwargs)
+        else:
+            print("[no CUDA] using CPU fp32")
+            model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="cpu", torch_dtype=torch.float32, **kwargs)
     except Exception as e:
-        print("[4-bit failed] -> fp16/CPU fallback:", e)
+        print("[loader fallback fp16/cpu]:", e)
         model = AutoModelForCausalLM.from_pretrained(
             MODEL_ID,
-            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-            device_map="auto" if torch.cuda.is_available() else None,
+            device_map="auto" if has_cuda else "cpu",
+            torch_dtype=torch.float16 if has_cuda else torch.float32,
             **kwargs
         )
     tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
     if tok.pad_token is None: tok.pad_token = tok.eos_token
     tok.padding_side = "left"
-    if torch.cuda.is_available():
-        torch.backends.cuda.matmul.allow_tf32 = True
     model.config.use_cache = True
     return tok, model
 
 tokenizer, llm = load_llm(); llm.eval()
 
-def build_prompt(context, question, longform):
-    head = SYSTEM_PROMPT
-    if context.strip():
-        head += f"\n\n[上下文]\n{context.strip()}"
-    ask = f"\n\n[問題]\n{question.strip()}\n"
-    tail = "\n請以條列步驟與小結回覆;若可計算,先算再答。"
-    if longform:
-        tail += "\n(長文模式)請分段、標題化、最後給出『摘要重點』。"
-    return head + ask + tail
+def build_prompt(task, ctx=None):
+    head = "你是教學助教。先讀任務,依:1) 摘要要點;2) 逐步推理;3) 結論條列。\n\n"  # zh-TW: summarize key points, reason step by step, bullet the conclusions
+    if ctx:
+        ctx = ctx[-6000:]  # keep only the trailing ~6000 chars to bound prompt length
+        return f"{head}【參考上下文】\n{ctx}\n\n【使用者問題】\n{task}\n\n【回答】"
+    return f"{head}【使用者問題】\n{task}\n\n【回答】"
 
 @torch.inference_mode()
-def stream_answer(context, question, longform, mx, temp, top_p, rep):
-    prompt = build_prompt(context, question, longform)
+def stream_answer(task, context, mx=256, temp=0.15, top_p=0.9):
+    prompt = build_prompt(task, context.strip() or None)
     inputs = tokenizer(prompt, return_tensors="pt").to(llm.device)
     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-    gen = dict(
-        **inputs, streamer=streamer, max_new_tokens=int(mx),
-        temperature=float(temp), top_p=float(top_p),
-        repetition_penalty=float(rep),
-        do_sample=True if float(temp)>0 else False,
-        eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id
-    )
-    import threading
-    t = threading.Thread(target=llm.generate, kwargs=gen); t.start()
+    kwargs = dict(**inputs, streamer=streamer, max_new_tokens=int(mx),
+                  temperature=float(temp), top_p=float(top_p),
+                  do_sample=float(temp) > 0,
+                  eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id)
+    import threading; threading.Thread(target=llm.generate, kwargs=kwargs).start()  # generate on a worker thread; the streamer yields tokens as they arrive
     buf = ""
     for tok in streamer:
         buf += tok
         yield buf
 
-def warmup():
-    try:
-        _ = list(stream_answer("", "簡述本系統的用途。", False, 96, 0.2, 0.9, 1.05))[-1]
-        print("[warmup] done")
-    except Exception as e:
-        print("[warmup] skip:", e)
-
 with gr.Blocks(title=TITLE, theme="soft") as demo:
-    gr.Markdown(f"## {TITLE}\n模型:`{MODEL_ID}`|已啟用:上下文欄位、長文模式、流式輸出、暖機")
-    with gr.Row():
-        ctx = gr.Textbox(label="上下文(長文,可空白)", placeholder="選填的背景內容/段落/資料摘錄", lines=6)
-    with gr.Row():
-        q = gr.Textbox(label="問題 / 指令", placeholder="請清楚描述你的問題", lines=3)
-    with gr.Row():
-        longf = gr.Checkbox(label="長文模式(章節化 + 摘要)", value=True)
-    with gr.Row():
-        mx = gr.Slider(128, 1024, value=512, step=32, label="max_new_tokens")
-        temp = gr.Slider(0.0, 0.8, value=0.2, step=0.05, label="temperature")
-        top = gr.Slider(0.6, 1.0, value=0.9, step=0.01, label="top_p")
-        rep = gr.Slider(1.0, 1.3, value=1.05, step=0.01, label="repetition_penalty")
-    go = gr.Button("送出 🚀", variant="primary")
-    out = gr.Textbox(label="輸出(流式)", lines=14)
-    clr = gr.Button("清除")
-
-    go.click(stream_answer, inputs=[ctx,q,longf,mx,temp,top,rep], outputs=out)
+    gr.Markdown(f"## {TITLE}|模型:`{MODEL_ID}`(流式)")
+    q = gr.Textbox(label="你的問題 / 指令", lines=5, placeholder="可貼長文;我會先摘要→推理→結論")
+    ctx = gr.Textbox(label="(可選)上下文", lines=6)
+    mx = gr.Slider(64, 512, value=256, step=32, label="max_new_tokens")
+    temp = gr.Slider(0.0, 0.8, value=0.15, step=0.05, label="temperature")
+    top = gr.Slider(0.6, 1.0, value=0.9, step=0.01, label="top_p")
+    go = gr.Button("送出 🚀", variant="primary")
+    out = gr.Textbox(label="輸出(流式)", lines=14)
+    clr = gr.Button("清除")
+    go.click(stream_answer, inputs=[q, ctx, mx, temp, top], outputs=out)
     clr.click(lambda:"", outputs=out)
-
-demo.queue(concurrency_count=4, max_size=32, api_open=False)
-warmup()
+# fix: don't use the legacy concurrency_count parameter (removed in Gradio 4)
+demo.queue(max_size=32, status_update_rate=1, api_open=False)
 
 if __name__ == "__main__":
     demo.launch(share=False, server_name="0.0.0.0", server_port=7860, show_error=True)
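Note on the queue fix: Gradio 4 removed queue(concurrency_count=...); concurrency is now configured per event via concurrency_limit, or queue-wide via default_concurrency_limit. A minimal sketch (assuming Gradio 4.x) of restoring the previous four-worker behavior:

# either set a queue-wide default...
demo.queue(max_size=32, default_concurrency_limit=4, api_open=False)
# ...or cap only the generation event and leave the rest at the default:
go.click(stream_answer, inputs=[q, ctx, mx, temp, top], outputs=out, concurrency_limit=4)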
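Note on the cache lines: recent huggingface_hub versions resolve HF_HOME into module-level constants at import time, and both gradio and transformers import the hub, so setting the variable after those imports may have no effect. A sketch of an ordering that should make the /data/.cache setting stick (assuming a writable /data volume):

import os
os.environ.setdefault("HF_HOME", "/data/.cache")         # before any import that pulls in huggingface_hub
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")  # no-op unless hf_transfer is installed
import torch, gradio as gr                               # gradio imports huggingface_hub
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TextIteratorStreamer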
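Since the old warmup() call is gone, a hypothetical smoke test for the streaming path, using only names defined in the new app.py (run once after startup to page the model in before the first request):

final = ""
for partial in stream_answer("簡述本系統的用途。", ""):  # "Briefly describe what this system does."
    final = partial
print("[smoke]", final[:80])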