# mini-llm / app.py
# (Hugging Face Spaces page metadata, kept as a comment so the file parses:
#  author caixiaoshun — "Update app.py", commit d2605bc, verified)
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from typing import List, Dict, Optional, Tuple
import time
# ==========================================
# Helper: dtype map & loader with simple cache
# ==========================================
DTYPE_MAP = {
"auto": "auto",
"float32": torch.float32,
"bfloat16": torch.bfloat16,
"float16": torch.float16,
}
_MODEL_CACHE = {}
def _dtype_from_name(name: str):
return DTYPE_MAP.get(name, "auto")
def load_model_and_tokenizer(repo_id: str, device_map: str = "cpu", dtype_name: str = "auto"):
    """Load and cache (tokenizer, model) keyed by (repo_id, device_map, dtype_name).

    Newer transformers releases accept ``dtype=``; older ones only know
    ``torch_dtype=``. We try the modern keyword first and fall back on
    TypeError (omitting the keyword entirely for "auto", which older
    versions may not accept as a ``torch_dtype`` value).

    Args:
        repo_id: Hugging Face repository id, e.g. "caixiaoshun/mini-moe".
        device_map: Passed through to ``from_pretrained`` ("cpu" or "auto").
        dtype_name: A key of DTYPE_MAP; unknown names behave like "auto".

    Returns:
        ``(tokenizer, model)``; the model is switched to eval mode.
    """
    key = (repo_id, device_map, dtype_name)
    if key in _MODEL_CACHE:
        return _MODEL_CACHE[key]
    # Fix: the tokenizer must also trust remote code — these repos ship custom
    # code (the model below is already loaded with trust_remote_code=True,
    # and predict() relies on the repo's custom `model.chat` API).
    tok = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
    # _dtype_from_name("auto") returns the literal string "auto", so a single
    # call covers both cases here (the previous auto/non-auto branch was
    # redundant: both arms passed the same value).
    dtype_val = _dtype_from_name(dtype_name)
    common_kwargs = dict(trust_remote_code=True, device_map=device_map)
    try:
        model = AutoModelForCausalLM.from_pretrained(repo_id, dtype=dtype_val, **common_kwargs)
    except TypeError:
        # Legacy transformers: retry with torch_dtype=, or without any dtype
        # keyword when the user asked for "auto".
        if dtype_val == "auto":
            model = AutoModelForCausalLM.from_pretrained(repo_id, **common_kwargs)
        else:
            model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=dtype_val, **common_kwargs)
    model.eval()
    _MODEL_CACHE[key] = (tok, model)
    return tok, model
# ==========================================
# Chat utilities & logging helpers
# ==========================================
def messages_to_pairs(messages: List[Dict[str, str]]) -> List[Tuple[str, str]]:
    """Turn a role/content message list into (user, assistant) display pairs.

    A user message immediately followed by an assistant message forms one
    pair; a trailing/unanswered user message gets an empty assistant slot;
    any message that does not start a user turn (e.g. assistant-first) is
    skipped.
    """
    pairs: List[Tuple[str, str]] = []
    idx, total = 0, len(messages)
    while idx < total:
        current = messages[idx]
        if current.get("role") != "user":
            # Not the start of a user turn — drop it and move on.
            idx += 1
            continue
        user_text = current.get("content", "")
        followed_by_reply = idx + 1 < total and messages[idx + 1].get("role") == "assistant"
        if followed_by_reply:
            pairs.append((user_text, messages[idx + 1].get("content", "")))
            idx += 2
        else:
            pairs.append((user_text, ""))
            idx += 1
    return pairs
def _ts() -> str:
return time.strftime("%H:%M:%S")
def append_log(logs: str, msg: str) -> str:
    """Append one timestamped line to the log text and return the result."""
    entry = f"[{_ts()}] {msg}\n"
    if not logs:
        return entry
    return logs + entry
# ==========================================
# Model state helpers (reload only when repo_id changes)
# ==========================================
def ensure_model(model_state: Dict, repo_id: str, device_map: str, dtype_name: str, logs: str):
    """Return (state, tokenizer, model, logs), (re)loading only on repo change.

    ``model_state`` is a dict holding {"repo_id", "tok", "model"}. When the
    requested repo matches the one already loaded, the cached objects are
    reused; note device_map/dtype changes alone do NOT trigger a reload.
    """
    state = model_state or {"repo_id": None, "tok": None, "model": None}
    needs_load = state.get("repo_id") != repo_id or state.get("model") is None
    if needs_load:
        logs = append_log(logs, f"加载模型 {repo_id}(触发:repo 变更)…")
        tok, mdl = load_model_and_tokenizer(repo_id, device_map=device_map, dtype_name=dtype_name)
        state = {"repo_id": repo_id, "tok": tok, "model": mdl}
        logs = append_log(logs, "模型加载完成。")
    else:
        logs = append_log(logs, f"使用已加载模型 {repo_id}(缓存)")
    return state, state["tok"], state["model"], logs
# ==========================================
# Predict
# ==========================================
def predict(user_text: str,
            messages_state: List[Dict[str, str]],
            repo_id: str, device_map: str, dtype_name: str,
            max_new_token: int, top_k: int,
            logs_state: str,
            model_state: Dict):
    """Handle one chat turn as a Gradio generator.

    Each ``yield`` emits (chatbot pairs, messages_state, logs_box value,
    logs_state, model_state) so the UI repaints incrementally. The model is
    queried once via its repo-provided ``model.chat`` method; the finished
    reply is then re-yielded one character at a time to simulate streaming.

    NOTE(review): a failure inside ``ensure_model`` (e.g. download error)
    happens before the try block and is not surfaced in the chat window.
    """
    messages_state = messages_state or []
    logs_state = logs_state or ""
    # 1) Ensure model based on repo_id only
    model_state, tokenizer, model, logs_state = ensure_model(model_state, repo_id, device_map, dtype_name, logs_state)
    # 2) Append user & paint
    messages_state.append({"role": "user", "content": user_text or ""})
    logs_state = append_log(logs_state, f"收到输入:{(user_text or '').strip()[:60]}")
    chat_display = messages_to_pairs(messages_state)
    yield chat_display, messages_state, logs_state, logs_state, model_state
    # 3) Inference
    try:
        logs_state = append_log(logs_state, f"开始推理:max_new_token={int(max_new_token)}, top_k={int(top_k)}")
        yield chat_display, messages_state, logs_state, logs_state, model_state
        try:
            # Custom chat API shipped with the model repo (trust_remote_code);
            # some variants may not accept the sampling kwargs, hence the retry.
            output = model.chat(
                messages_state,
                tokenizer,
                max_new_token=int(max_new_token),
                top_k=int(top_k),
            )
        except TypeError:
            # Fallback: call without generation kwargs.
            output = model.chat(messages_state, tokenizer)
        # Fake streaming: grow the last pair's assistant text char by char.
        # chat_display is non-empty here because the user turn was appended above.
        partial = ""
        for ch in str(output):
            partial += ch
            chat_display[-1] = (chat_display[-1][0], partial)
            yield chat_display, messages_state, logs_state, logs_state, model_state
        messages_state.append({"role": "assistant", "content": str(output)})
        logs_state = append_log(logs_state, f"推理完成,输出长度 {len(str(output))} 字符。")
        yield chat_display, messages_state, logs_state, logs_state, model_state
    except Exception as e:
        # Surface the failure in both the chat window and the log panel
        # (and record it in history) instead of crashing the handler.
        err = f"[推理错误] {e}"
        logs_state = append_log(logs_state, err)
        chat_display[-1] = (chat_display[-1][0], err)
        messages_state.append({"role": "assistant", "content": err})
        yield chat_display, messages_state, logs_state, logs_state, model_state
def clear_chat():
    """Reset the conversation: empty chatbot pairs and empty message history."""
    return ([], [])
def clear_logs_fn():
    """Reset the log panel: empty textbox value and empty logs state."""
    return ("", "")
def preload_on_repo_change(repo_id: str, device_map: str, dtype_name: str, logs_state: str, model_state: Dict):
    """Preload the model when the repo dropdown changes, logging progress.

    Returns (logs text, updated model state); tokenizer/model handles are
    kept inside the state dict and not returned separately.
    """
    current_logs = logs_state or ""
    model_state, _tok, _mdl, current_logs = ensure_model(
        model_state, repo_id, device_map, dtype_name, current_logs
    )
    return current_logs, model_state
# ==========================================
# Gradio UI
# ==========================================
with gr.Blocks(title="mini-moe Chat (Gradio)") as demo:
    # Per-session server-side state.
    messages_state = gr.State([])  # role/content message history
    logs_state = gr.State("")  # accumulated log text
    model_state = gr.State({"repo_id": None, "tok": None, "model": None})  # currently loaded model
    with gr.Row():
        # Left column: chat window + input + send/clear buttons.
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="对话", height=520)
            with gr.Row():
                user_box = gr.Textbox(label="输入", placeholder="请输入你的问题… (Shift+Enter 换行)", lines=3)
            with gr.Row():
                send_btn = gr.Button("发送", variant="primary")
                clear_btn = gr.Button("清空对话")
        # Right column: model/device/dtype/generation settings and log panel.
        with gr.Column(scale=1):
            gr.Markdown("## ⚙️ 设置")
            repo_dd = gr.Dropdown(
                label="模型仓库 (HF repo)",
                choices=["caixiaoshun/mini-moe", "caixiaoshun/mini-llm"],
                value="caixiaoshun/mini-moe",
            )
            device_dd = gr.Dropdown(label="device_map", choices=["cpu", "auto"], value="cpu")
            dtype_dd = gr.Dropdown(label="精度 (dtype/torch_dtype)", choices=["auto", "float32", "bfloat16", "float16"], value="auto")
            max_new_num = gr.Number(label="max_new_token", value=256, precision=0)
            top_k_num = gr.Number(label="top_k", value=5, precision=0)
            with gr.Accordion("📜 日志 (展开查看)", open=False):
                logs_box = gr.Textbox(label="运行日志", lines=12, interactive=False)
                log_clear_btn = gr.Button("清空日志")
    # Events: send / submit — predict is a generator, yielding incremental updates.
    send_evt_inputs = [
        user_box, messages_state, repo_dd, device_dd, dtype_dd, max_new_num, top_k_num, logs_state, model_state
    ]
    send_evt_outputs = [chatbot, messages_state, logs_box, logs_state, model_state]
    send_btn.click(predict, inputs=send_evt_inputs, outputs=send_evt_outputs)
    user_box.submit(predict, inputs=send_evt_inputs, outputs=send_evt_outputs)
    # Clear input after send: a second handler on the same click/submit events.
    def _clear_text():
        # Returning "" empties the textbox component.
        return ""
    send_btn.click(_clear_text, inputs=None, outputs=user_box)
    user_box.submit(_clear_text, inputs=None, outputs=user_box)
    # Clear chat
    clear_btn.click(clear_chat, inputs=None, outputs=[chatbot, messages_state])
    # Clear logs
    log_clear_btn.click(clear_logs_fn, inputs=None, outputs=[logs_box, logs_state])
    # Preload on repo change (model is reloaded only when the repo changes;
    # device_map/dtype values are read but do not by themselves trigger a reload).
    repo_dd.change(preload_on_repo_change,
                   inputs=[repo_dd, device_dd, dtype_dd, logs_state, model_state],
                   outputs=[logs_box, model_state])


if __name__ == "__main__":
    # queue() enables queued/streamed event handling; set share=True if you
    # want a public link.
    demo.queue().launch()