aciang committed
Commit 4861760 · verified · 1 Parent(s): d768453

update app.py (context+longform+warmup+4bit fallback)

Files changed (1)
  1. app.py +73 -55
app.py CHANGED
@@ -1,86 +1,104 @@
 
- import os, time, torch, gradio as gr
- os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")  # speed up the first model download
-
  from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, BitsAndBytesConfig
 
- TITLE = os.getenv("SPACE_TITLE", "LanguageBridge — Multimodal Chatbot (Mistral-7B)")
  MODEL_ID = os.getenv("MODEL_ID", "aciang/mistral7b-tk-sft-20251019-merged")
 
  SYSTEM_PROMPT = (
-     "你是『語言橋』助教。回答原則:條列、準確、可重現步驟;不足處要誠實說明。"
  )
 
- _tok, _llm = None, None
  def load_llm():
-     global _tok, _llm
-     if _llm is not None:
-         return _tok, _llm
-     # try 4-bit first (fall back automatically on failure)
      bnb = BitsAndBytesConfig(
          load_in_4bit=True, bnb_4bit_quant_type="nf4",
          bnb_4bit_use_double_quant=True,
          bnb_4bit_compute_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16
      )
-     kwargs = dict(device_map="auto", trust_remote_code=False, quantization_config=bnb)
      try:
-         _llm = AutoModelForCausalLM.from_pretrained(MODEL_ID, **kwargs)
      except Exception as e:
-         print("[4-bit failed] fallback:", e)
-         _llm = AutoModelForCausalLM.from_pretrained(
              MODEL_ID,
-             torch_dtype=(torch.float16 if torch.cuda.is_available() else torch.float32),
-             device_map=("auto" if torch.cuda.is_available() else None),
-             trust_remote_code=False
          )
-     _tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
-     if _tok.pad_token is None: _tok.pad_token = _tok.eos_token
-     _tok.padding_side = "left"
-     if torch.cuda.is_available(): torch.backends.cuda.matmul.allow_tf32 = True
-     _llm.config.use_cache = True
-     return _tok, _llm
-
- def format_prompt(user_text: str) -> str:
-     return f"{SYSTEM_PROMPT}\n\n使用者:{user_text}\n助教:"
 
- @torch.inference_mode()
- def generate(user_text, mx=256, temp=0.6, top_p=0.95):
-     global _tok, _llm
-     if _llm is None:
-         yield "(正在載入模型,首次需要數十秒到數分鐘,請稍候…)"
-         _tok, _llm = load_llm()
-         yield "(模型載入完成,開始回應…)"
 
-     prompt = format_prompt(user_text)
-     inputs = _tok(prompt, return_tensors="pt").to(_llm.device)
 
-     streamer = TextIteratorStreamer(_tok, skip_prompt=True, skip_special_tokens=True)
-     gen = dict(**inputs, streamer=streamer, max_new_tokens=int(mx),
-                temperature=float(temp), top_p=float(top_p),
-                do_sample=True, eos_token_id=_tok.eos_token_id, pad_token_id=_tok.pad_token_id)
      import threading
-     t = threading.Thread(target=_llm.generate, kwargs=gen); t.start()
-
-     buf = ""
      for tok in streamer:
          buf += tok
          yield buf
 
- with gr.Blocks(title=TITLE, fill_height=True) as demo:
-     gr.Markdown(f"## {TITLE}\n模型:`{MODEL_ID}`(延遲載入)")
-     chat_in = gr.Textbox(label="你的問題 / 指令", placeholder="輸入文字…", lines=4)
      with gr.Row():
-         mx = gr.Slider(64, 1024, value=256, step=32, label="max_new_tokens")
-         temp = gr.Slider(0.1, 1.0, value=0.6, step=0.05, label="temperature")
-         top = gr.Slider(0.5, 1.0, value=0.95, step=0.01, label="top_p")
-     go = gr.Button("送出 🚀", variant="primary")
-     out = gr.Textbox(label="輸出(流式)", lines=18)
-     clr = gr.Button("清除")
 
-     go.click(generate, inputs=[chat_in, mx, temp, top], outputs=out)
-     clr.click(lambda: "", outputs=out)
 
- demo.queue(max_size=32, api_open=False)
 
  if __name__ == "__main__":
-     demo.launch(share=False, show_error=True)
 
 
+ import os, torch, gradio as gr
  from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, BitsAndBytesConfig
 
+ TITLE = "LanguageBridge — Multimodal Chatbot (Mistral-7B)"
  MODEL_ID = os.getenv("MODEL_ID", "aciang/mistral7b-tk-sft-20251019-merged")
 
  SYSTEM_PROMPT = (
+     "你是『語言橋』學習助教。規則:"
+     "1) 嚴謹、分段、先重點後細節;"
+     "2) 若為數學/規則題:先列步驟,再給最終答案;"
+     "3) 若資訊不足,請明確指出缺口,勿捏造;"
+     "4) 優先以繁體中文回答。"
  )
 
  def load_llm():
+     # try 4-bit first; fall back to fp16/CPU on failure
+     kwargs = dict(trust_remote_code=False)
      bnb = BitsAndBytesConfig(
          load_in_4bit=True, bnb_4bit_quant_type="nf4",
          bnb_4bit_use_double_quant=True,
          bnb_4bit_compute_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16
      )
      try:
+         model = AutoModelForCausalLM.from_pretrained(MODEL_ID, quantization_config=bnb, device_map="auto", **kwargs)
      except Exception as e:
+         print("[4-bit failed] -> fp16/CPU fallback:", e)
+         model = AutoModelForCausalLM.from_pretrained(
              MODEL_ID,
+             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+             device_map="auto" if torch.cuda.is_available() else None,
+             **kwargs
          )
+     tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
+     if tok.pad_token is None: tok.pad_token = tok.eos_token
+     tok.padding_side = "left"
+     if torch.cuda.is_available():
+         torch.backends.cuda.matmul.allow_tf32 = True
+     model.config.use_cache = True
+     return tok, model
 
+ tokenizer, llm = load_llm(); llm.eval()
 
+ def build_prompt(context, question, longform):
+     head = SYSTEM_PROMPT
+     if context.strip():
+         head += f"\n\n[上下文]\n{context.strip()}"
+     ask = f"\n\n[問題]\n{question.strip()}\n"
+     tail = "\n請以條列步驟與小結回覆;若可計算,先算再答。"
+     if longform:
+         tail += "\n(長文模式)請分段、標題化、最後給出『摘要重點』。"
+     return head + ask + tail
 
+ @torch.inference_mode()
+ def stream_answer(context, question, longform, mx, temp, top_p, rep):
+     prompt = build_prompt(context, question, longform)
+     inputs = tokenizer(prompt, return_tensors="pt").to(llm.device)
+     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+     gen = dict(
+         **inputs, streamer=streamer, max_new_tokens=int(mx),
+         temperature=float(temp), top_p=float(top_p),
+         repetition_penalty=float(rep),
+         do_sample=True if float(temp) > 0 else False,
+         eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id
+     )
      import threading
+     t = threading.Thread(target=llm.generate, kwargs=gen); t.start()
+     buf = ""
      for tok in streamer:
          buf += tok
          yield buf
 
+ def warmup():
+     try:
+         _ = list(stream_answer("", "簡述本系統的用途。", False, 96, 0.2, 0.9, 1.05))[-1]
+         print("[warmup] done")
+     except Exception as e:
+         print("[warmup] skip:", e)
+
+ with gr.Blocks(title=TITLE, theme="soft") as demo:
+     gr.Markdown(f"## {TITLE}\n模型:`{MODEL_ID}`|已啟用:上下文欄位、長文模式、流式輸出、暖機")
+     with gr.Row():
+         ctx = gr.Textbox(label="上下文(長文,可空白)", placeholder="選填的背景內容/段落/資料摘錄", lines=6)
+     with gr.Row():
+         q = gr.Textbox(label="問題 / 指令", placeholder="請清楚描述你的問題", lines=3)
+     with gr.Row():
+         longf = gr.Checkbox(label="長文模式(章節化 + 摘要)", value=True)
      with gr.Row():
+         mx = gr.Slider(128, 1024, value=512, step=32, label="max_new_tokens")
+         temp = gr.Slider(0.0, 0.8, value=0.2, step=0.05, label="temperature")
+         top = gr.Slider(0.6, 1.0, value=0.9, step=0.01, label="top_p")
+         rep = gr.Slider(1.0, 1.3, value=1.05, step=0.01, label="repetition_penalty")
+     go = gr.Button("送出 🚀", variant="primary")
+     out = gr.Textbox(label="輸出(流式)", lines=14)
+     clr = gr.Button("清除")
 
+     go.click(stream_answer, inputs=[ctx, q, longf, mx, temp, top, rep], outputs=out)
+     clr.click(lambda: "", outputs=out)
 
+ demo.queue(concurrency_count=4, max_size=32, api_open=False)
+ warmup()
 
  if __name__ == "__main__":
+     demo.launch(share=False, server_name="0.0.0.0", server_port=7860, show_error=True)
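
Note on the new queue call: concurrency_count is a Gradio 3.x parameter; Gradio 4 and later removed it from Blocks.queue() and reject it, and this commit does not show a pinned Gradio version. A minimal sketch of the equivalent call, assuming the Space runs Gradio 4 or newer (default_concurrency_limit is that newer API, not part of this commit):

    # Gradio 4+: per-queue concurrency is set via default_concurrency_limit;
    # passing concurrency_count here raises an error on those versions.
    demo.queue(default_concurrency_limit=4, max_size=32, api_open=False)

The 4-bit load path also assumes bitsandbytes is installed, and device_map="auto" requires accelerate; neither dependency is touched by this commit.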