aciang committed (verified)
Commit 2770efa · 1 parent: d9481ec

Hotfix: 4-bit fallback + hf_transfer + stable cache
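Note on `HF_HUB_ENABLE_HF_TRANSFER`: the flag only speeds up downloads if the `hf_transfer` package is installed alongside `huggingface_hub` (e.g. listed in the Space's requirements.txt); with the flag set but the package missing, downloads raise an error. A minimal sketch of a guard, not part of this commit:

import importlib.util, os

# Enable the Rust-based downloader only when hf_transfer is actually importable;
# otherwise leave the flag unset so huggingface_hub uses its default downloader.
if importlib.util.find_spec("hf_transfer") is not None:
    os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
else:
    print("[setup] hf_transfer not installed; using the default downloader")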

Files changed (1)
  1. app.py +86 -119
app.py CHANGED
@@ -1,132 +1,99 @@
 
- import os, time, torch, gradio as gr
- from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, BitsAndBytesConfig
-
- TITLE = os.getenv("SPACE_TITLE", "LanguageBridge — Multimodal Chatbot (Mistral-7B)")
- MODEL_ID = os.getenv("MODEL_ID", "aciang/mistral7b-tk-sft-20251019-merged")
-
- SYSTEM_PROMPT = (
-     "你是語言橋助教。原則:1) 先條列必要重點;2) 再給最終結論;3) 嚴禁瞎掰,不足就說明。"
- )
-
- def load_llm():
-     # Prefer 4-bit to save GPU memory
-     bnb = BitsAndBytesConfig(
-         load_in_4bit=True, bnb_4bit_quant_type="nf4",
-         bnb_4bit_use_double_quant=True,
-         bnb_4bit_compute_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16
-     )
-     kwargs = dict(device_map="auto", quantization_config=bnb, trust_remote_code=False)
      try:
-         model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **kwargs)
      except Exception as e:
-         print("[4-bit failed] → fallback:", e)
-         kwargs.pop("quantization_config", None)
-         model = AutoModelForCausalLM.from_pretrained(
-             MODEL_ID,
-             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-             device_map="auto" if torch.cuda.is_available() else None,
-             trust_remote_code=False
          )
-     tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
-     if tok.pad_token is None:
-         tok.pad_token = tok.eos_token
      tok.padding_side = "left"
-     if torch.cuda.is_available():
-         torch.backends.cuda.matmul.allow_tf32 = True
-     model.config.use_cache = True
-     return tok, model
 
- tokenizer, llm = load_llm(); llm.eval()
 
- def build_prompt(user_text:str)->str:
-     return f"{SYSTEM_PROMPT}\n\n使用者:{user_text}\n助教:"
 
  @torch.inference_mode()
- def stream_answer(history, text, mx=256, temp=0.2, top_p=0.95):
-     # history is a list[{"role":"user|assistant","content":"..."}]
-     user_text = text or ""
-     prompt = build_prompt(user_text)
      inputs = tokenizer(prompt, return_tensors="pt").to(llm.device)
-
-     from transformers import TextIteratorStreamer
      streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-     gen_kwargs = dict(
-         **inputs, streamer=streamer, max_new_tokens=int(mx),
-         temperature=float(temp), top_p=float(top_p),
-         do_sample=True if float(temp)>0 else False,
-         eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id
-     )
-
-     import threading
-     t = threading.Thread(target=llm.generate, kwargs=gen_kwargs); t.start()
-     partial = ""
-     for piece in streamer:
-         partial += piece
-         yield partial
-
- def warmup():
-     try:
-         _ = list(stream_answer([], "π 約為多少?", mx=32))[-1]
-         print("[warmup] done")
-     except Exception as e:
-         print("[warmup] skip:", e)
-
- with gr.Blocks(title=TITLE, theme="soft") as demo:
-     gr.Markdown(f"## {TITLE}\n模型:`{MODEL_ID}`|已修正訊息格式;預設短答低延遲(流式 + 暖機)")
-
-     chat = gr.Chatbot(label="Chatbot", type="messages", height=420, show_copy_button=True)
-     text = gr.Textbox(label="你的問題 / 指令", placeholder="請輸入文字…", lines=3)
-     mx = gr.Slider(64, 1024, value=256, step=32, label="max_new_tokens")
-     temp = gr.Slider(0.0, 1.0, value=0.2, step=0.05, label="temperature")
-     top = gr.Slider(0.6, 1.0, value=0.95, step=0.01, label="top_p")
-     go = gr.Button("送出 🚀", variant="primary")
-     clr = gr.Button("清除")
-
-     def respond(history, text, mx, temp, top):
-         history = history or []
-         if text and text.strip():
-             history.append({"role":"user","content":text})
-         # streaming generation
-         stream = stream_answer(history, text, mx, temp, top)
-         out = ""
-         for chunk in stream:
-             out = chunk
-             yield history + [{"role":"assistant","content":out}], ""
-         history.append({"role":"assistant","content":out})
-         yield history, ""
-
-     go.click(respond, inputs=[chat, text, mx, temp, top], outputs=[chat, text])
-     clr.click(lambda: ([], ""), outputs=[chat, text])
-
-     # Do not pass legacy queue() arguments (avoids a runtime error)
-     demo.queue()
-     warmup()
 
  if __name__ == "__main__":
-     demo.launch(share=False, server_name="0.0.0.0", server_port=7860, show_error=True)
-
-
- def _coerce_messages(history):
-     """Ensure messages are in the [{'role','content'}] shape."""
-     fixed = []
-     for r, m in (history or []):
-         if isinstance(m, dict) and 'role' in m and 'content' in m:
-             fixed.append((r, m))
-         elif isinstance(m, str):
-             fixed.append((r, {"role": r, "content": m}))
-         else:
-             # most conservative fallback
-             fixed.append((r, {"role": r, "content": str(m)}))
-     return fixed
-
- def respond(history, text, image, audio, mx, tp, top):
-     history = _coerce_messages(history)
-     history.append(("user", {"content": text}))
-     try:
-         ans = generate_reply(history, image, audio, mx, tp, top)
-     except Exception as e:
-         ans = f"(推理失敗:{e})"
-     history.append(("assistant", {"content": ans}))
-     return history, ""
-
 
+ import os, time, threading, torch, gradio as gr
+ from huggingface_hub import snapshot_download
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TextIteratorStreamer
+
+ SPACE_TITLE = "LanguageBridge — Multimodal Chatbot (Mistral-7B)"
+ PRIMARY_MODEL = "aciang/mistral7b-tk-sft-20251019-merged"
+ FALLBACK_MODEL = "unsloth/mistral-7b-instruct-v0.2-bnb-4bit"
+
+ # ---- Faster downloads + a fixed cache (/data persists on Spaces) ----
+ os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER","1")
+ os.environ.setdefault("HF_HOME","/data/.cache/hf")  # persistent cache
+ os.environ.setdefault("TRANSFORMERS_CACHE","/data/.cache/hf/transformers")
+ os.makedirs(os.environ["HF_HOME"], exist_ok=True)
+
+ # ---- Prefetch the tokenizer locally first (so the UI comes up in seconds) ----
+ def _ensure_tokenizer(model_id):
      try:
+         snapshot_download(model_id, allow_patterns=["tokenizer.*","*tokenizer*","special_tokens_map.json"], local_dir=None)
      except Exception as e:
+         print("[tok prefetch] skip:", e)
+
+ # ---- Model loading (with 4-bit fallback) ----
+ def load_llm(prefer_primary=True):
+     model_id = PRIMARY_MODEL if prefer_primary else FALLBACK_MODEL
+     use_4bit = (model_id != PRIMARY_MODEL)
+
+     if use_4bit:
+         bnb = BitsAndBytesConfig(
+             load_in_4bit=True, bnb_4bit_quant_type="nf4",
+             bnb_4bit_use_double_quant=True,
+             bnb_4bit_compute_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16
          )
+         kw = dict(device_map="auto", quantization_config=bnb, trust_remote_code=False)
+     else:
+         kw = dict(device_map="auto", trust_remote_code=False)
+
+     print(f"[load] try model = {model_id} | 4bit={use_4bit}")
+     t0 = time.time()
+     tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
+     if tok.pad_token is None: tok.pad_token = tok.eos_token
      tok.padding_side = "left"
+     mdl = AutoModelForCausalLM.from_pretrained(model_id, **kw)
+     mdl.eval()
+     print(f"[load] ok in {time.time()-t0:.1f}s")
+     return tok, mdl, model_id
+
+ # ---- Boot logic: load PRIMARY in a background thread; load FALLBACK instead on timeout ----
+ tokenizer = None
+ llm = None
+ active_model = None
+
+ def boot():
+     global tokenizer, llm, active_model
+     _ensure_tokenizer(PRIMARY_MODEL)
+     deadline = time.time() + 14*60  # switch if loading has not finished within 14 minutes (leaves a 16-minute buffer under the 30-minute limit)
+     try:
+         tokenizer, llm, active_model = load_llm(prefer_primary=True)
+     except Exception as e:
+         print("[boot] primary failed early:", e)
+     if llm is None or time.time() > deadline:
+         print("[boot] switching to FALLBACK for fast availability...")
+         tokenizer, llm, active_model = load_llm(prefer_primary=False)
 
+ boot_th = threading.Thread(target=boot); boot_th.start()
 
+ SYSTEM = (
+     "你是語言橋助教。回覆重點:1) 條列步驟 2) 簡潔正確 3) 不確定就說明不足並提出假設。"
+ )
 
  @torch.inference_mode()
+ def stream_answer(q, mx=256, temp=0.6, top_p=0.95):
+     boot_th.join()  # make sure the model has finished loading
+     prompt = f"{SYSTEM}\n\n使用者:{q}\n助教:"
      inputs = tokenizer(prompt, return_tensors="pt").to(llm.device)
      streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+     gen = dict(**inputs, streamer=streamer, max_new_tokens=int(mx),
+                temperature=float(temp), top_p=float(top_p),
+                do_sample=True if float(temp)>0 else False,
+                eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id)
+     t = threading.Thread(target=llm.generate, kwargs=gen); t.start()
+     buf = ""
+     for piece in streamer:
+         buf += piece
+         yield buf
+
+ with gr.Blocks(title=SPACE_TITLE, fill_height=True) as demo:
+     gr.Markdown(f"### {SPACE_TITLE}\n目前模型:`{active_model or 'loading…'}`\n(首次啟動若超時將自動切到 4-bit 權重)")
+     q = gr.Textbox(label="你的問題 / 指令")
+     mx = gr.Slider(64, 1024, value=512, step=32, label="max_new_tokens")
+     tp = gr.Slider(0.0, 1.2, value=0.6, step=0.05, label="temperature")
+     top = gr.Slider(0.5, 1.0, value=0.95, step=0.01, label="top_p")
+     go = gr.Button("送出 🚀", variant="primary")
+     out = gr.Textbox(label="輸出", lines=12)
+     go.click(stream_answer, inputs=[q, mx, tp, top], outputs=out)
+     demo.queue(api_open=False)
 
  if __name__ == "__main__":
+     demo.launch(share=False, show_error=True)
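
The stable-cache part of this fix assumes the Space has persistent storage attached; /data is only writable in that case, so `os.makedirs("/data/.cache/hf")` can fail at startup otherwise. A minimal sketch (names are illustrative, not from this commit) of picking a cache root that degrades gracefully:

import os, tempfile

def pick_cache_root(preferred="/data/.cache/hf"):
    # Use the persistent /data cache when it is writable; otherwise fall back
    # to a temporary directory so the app still boots (cache is lost on restart).
    try:
        os.makedirs(preferred, exist_ok=True)
        return preferred
    except OSError:
        fallback = os.path.join(tempfile.gettempdir(), "hf-cache")
        os.makedirs(fallback, exist_ok=True)
        return fallback

os.environ.setdefault("HF_HOME", pick_cache_root())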