Hotfix: 4-bit fallback + hf_transfer + stable cache
app.py
CHANGED
@@ -1,132 +1,99 @@
 
-import os, time, torch, gradio as gr
-from …
-)
-kwargs = dict(device_map="auto", quantization_config=bnb, trust_remote_code=False)
+import os, time, threading, torch, gradio as gr
+from huggingface_hub import snapshot_download
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TextIteratorStreamer
+
+SPACE_TITLE = "LanguageBridge — Multimodal Chatbot (Mistral-7B)"
+PRIMARY_MODEL = "aciang/mistral7b-tk-sft-20251019-merged"
+FALLBACK_MODEL = "unsloth/mistral-7b-instruct-v0.2-bnb-4bit"
+
+# ---- Faster downloads + pinned cache (/data persists on Spaces) ----
+os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
+os.environ.setdefault("HF_HOME", "/data/.cache/hf")  # persistent cache
+os.environ.setdefault("TRANSFORMERS_CACHE", "/data/.cache/hf/transformers")
+os.makedirs(os.environ["HF_HOME"], exist_ok=True)
+
+# ---- Prefetch the tokenizer first so the UI loads in seconds ----
+def _ensure_tokenizer(model_id):
     try:
+        snapshot_download(model_id, allow_patterns=["tokenizer.*", "*tokenizer*", "special_tokens_map.json"], local_dir=None)
     except Exception as e:
-        print("[…
+        print("[tok prefetch] skip:", e)
+
+# ---- Model loading (with 4-bit fallback) ----
+def load_llm(prefer_primary=True):
+    model_id = PRIMARY_MODEL if prefer_primary else FALLBACK_MODEL
+    use_4bit = (model_id != PRIMARY_MODEL)
+
+    if use_4bit:
+        bnb = BitsAndBytesConfig(
+            load_in_4bit=True, bnb_4bit_quant_type="nf4",
+            bnb_4bit_use_double_quant=True,
+            bnb_4bit_compute_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16
         )
+        kw = dict(device_map="auto", quantization_config=bnb, trust_remote_code=False)
+    else:
+        kw = dict(device_map="auto", trust_remote_code=False)
+
+    print(f"[load] try model = {model_id} | 4bit={use_4bit}")
+    t0 = time.time()
+    tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
+    if tok.pad_token is None: tok.pad_token = tok.eos_token
     tok.padding_side = "left"
-    return tok, …
+    mdl = AutoModelForCausalLM.from_pretrained(model_id, **kw)
+    mdl.eval()
+    print(f"[load] ok in {time.time()-t0:.1f}s")
+    return tok, mdl, model_id
+
+# ---- Boot logic: load PRIMARY on a background thread; switch to FALLBACK on timeout ----
+tokenizer = None
+llm = None
+active_model = None
+
+def boot():
+    global tokenizer, llm, active_model
+    _ensure_tokenizer(PRIMARY_MODEL)
+    deadline = time.time() + 14*60  # switch if not loaded within 14 min (leaves a 16-min buffer under the 30-min limit)
+    try:
+        tokenizer, llm, active_model = load_llm(prefer_primary=True)
+    except Exception as e:
+        print("[boot] primary failed early:", e)
+    if llm is None or time.time() > deadline:
+        print("[boot] switching to FALLBACK for fast availability...")
+        tokenizer, llm, active_model = load_llm(prefer_primary=False)
 
+boot_th = threading.Thread(target=boot); boot_th.start()
 
+SYSTEM = (
+    "You are the LanguageBridge teaching assistant. Reply guidelines: 1) answer in listed steps 2) be concise and correct 3) if unsure, state what is missing and propose assumptions."
+)
 
 @torch.inference_mode()
-def stream_answer(…
-    # …
-    prompt = build_prompt(user_text)
+def stream_answer(q, mx=256, temp=0.6, top_p=0.95):
+    boot_th.join()  # make sure loading has finished
+    prompt = f"{SYSTEM}\n\nUser: {q}\nAssistant:"
     inputs = tokenizer(prompt, return_tensors="pt").to(llm.device)
-    from transformers import TextIteratorStreamer
     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-
-with gr.Blocks(title=TITLE, theme="soft") as demo:
-    gr.Markdown(f"## {TITLE}\nModel: `{MODEL_ID}` | message format fixed; short answers and low latency by default (streaming + warm-up)")
-
-    chat = gr.Chatbot(label="Chatbot", type="messages", height=420, show_copy_button=True)
-    text = gr.Textbox(label="Your question / instruction", placeholder="Enter text…", lines=3)
-    mx = gr.Slider(64, 1024, value=256, step=32, label="max_new_tokens")
-    temp = gr.Slider(0.0, 1.0, value=0.2, step=0.05, label="temperature")
-    top = gr.Slider(0.6, 1.0, value=0.95, step=0.01, label="top_p")
-    go = gr.Button("Send 🚀", variant="primary")
-    clr = gr.Button("Clear")
-
-    def respond(history, text, mx, temp, top):
-        history = history or []
-        if text and text.strip():
-            history.append({"role":"user","content":text})
-        # streaming generation
-        stream = stream_answer(history, text, mx, temp, top)
-        out = ""
-        for chunk in stream:
-            out = chunk
-            yield history + [{"role":"assistant","content":out}], ""
-        history.append({"role":"assistant","content":out})
-        yield history, ""
-
-    go.click(respond, inputs=[chat, text, mx, temp, top], outputs=[chat, text])
-    clr.click(lambda: ([], ""), outputs=[chat, text])
-
-# do not use the old queue parameters (avoids the Runtime error)
-demo.queue()
-warmup()
+    gen = dict(**inputs, streamer=streamer, max_new_tokens=int(mx),
+               temperature=float(temp), top_p=float(top_p),
+               do_sample=True if float(temp) > 0 else False,
+               eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id)
+    t = threading.Thread(target=llm.generate, kwargs=gen); t.start()
+    buf = ""
+    for tok in streamer:
+        buf += tok
+        yield buf
+
+with gr.Blocks(title=SPACE_TITLE, fill_height=True) as demo:
+    gr.Markdown(f"### {SPACE_TITLE}\nCurrent model: `{active_model or 'loading…'}`\n(if the first startup times out, the app switches to the 4-bit weights automatically)")
+    q = gr.Textbox(label="Your question / instruction")
+    mx = gr.Slider(64, 1024, value=512, step=32, label="max_new_tokens")
+    tp = gr.Slider(0.0, 1.2, value=0.6, step=0.05, label="temperature")
+    top = gr.Slider(0.5, 1.0, value=0.95, step=0.01, label="top_p")
+    go = gr.Button("Send 🚀", variant="primary")
+    out = gr.Textbox(label="Output", lines=12)
+    go.click(stream_answer, inputs=[q, mx, tp, top], outputs=out)
+    demo.queue(api_open=False)
 
 if __name__ == "__main__":
-    demo.launch(share=False,
-
-def _coerce_messages(history):
-    """Ensure the [{'role','content'}] message shape."""
-    fixed = []
-    for r, m in (history or []):
-        if isinstance(m, dict) and 'role' in m and 'content' in m:
-            fixed.append((r, m))
-        elif isinstance(m, str):
-            fixed.append((r, {"role": r, "content": m}))
-        else:
-            # most conservative fallback
-            fixed.append((r, {"role": r, "content": str(m)}))
-    return fixed
-
-def respond(history, text, image, audio, mx, tp, top):
-    history = _coerce_messages(history)
-    history.append(("user", {"content": text}))
-    try:
-        ans = generate_reply(history, image, audio, mx, tp, top)
-    except Exception as e:
-        ans = f"(inference failed: {e})"
-    history.append(("assistant", {"content": ans}))
-    return history, ""
+    demo.launch(share=False, show_error=True)
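
Note on the hf_transfer switch: setting HF_HUB_ENABLE_HF_TRANSFER=1 makes huggingface_hub fail downloads when the hf_transfer package itself is missing from the image. A minimal defensive sketch, assuming the same environment-variable scheme as app.py (the importlib probe is illustrative and not part of this commit):

import importlib.util, os

# Enable the accelerated downloader only when the package is importable;
# otherwise leave the default downloader in place instead of erroring at fetch time.
if importlib.util.find_spec("hf_transfer") is not None:
    os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
else:
    os.environ.pop("HF_HUB_ENABLE_HF_TRANSFER", None)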
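
Likewise, /data is only present and writable when the Space has persistent storage attached, so the unconditional os.makedirs on HF_HOME can raise on a storage-less Space. A hedged sketch that keeps the library defaults when /data is unavailable, reusing the HF_HOME and TRANSFORMERS_CACHE variables from above:

import os

# Point the HF cache at /data only when it is actually writable
# (persistent storage enabled); otherwise keep the default ~/.cache.
cache_root = "/data/.cache/hf"
try:
    os.makedirs(cache_root, exist_ok=True)
    writable = os.access(cache_root, os.W_OK)
except OSError:
    writable = False

if writable:
    os.environ.setdefault("HF_HOME", cache_root)
    os.environ.setdefault("TRANSFORMERS_CACHE", os.path.join(cache_root, "transformers"))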
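
For a quick check without the Gradio UI, the streaming generator can be driven directly; since it yields the accumulated text, only the last yield is the full reply. A sketch assuming app.py is importable as a module named app (that module name is an assumption, not part of the commit):

# Hypothetical smoke test for the streaming path; importing app starts the
# boot thread, and stream_answer joins it before generating.
from app import stream_answer

final = ""
for partial in stream_answer("Explain NF4 4-bit quantization in three steps.", mx=128, temp=0.2):
    final = partial  # each yield is the text so far
print(final)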