import time
from typing import Dict, List, Optional, Tuple

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# ==========================================
# Helper: dtype map & loader with simple cache
# ==========================================

# Maps the UI dtype name to the value handed to `from_pretrained`.
# "auto" is passed through as the literal string, which transformers
# interprets as "pick the dtype stored in the checkpoint".
DTYPE_MAP = {
    "auto": "auto",
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
    "float16": torch.float16,
}

# Process-wide cache: (repo_id, device_map, dtype_name) -> (tokenizer, model).
# Unbounded by design — this app loads at most a couple of small repos.
_MODEL_CACHE: Dict[Tuple[str, str, str], Tuple[object, object]] = {}


def _dtype_from_name(name: str):
    """Resolve a dtype name from the UI to a torch dtype (or the string "auto").

    Unknown names fall back to "auto" rather than raising, so a stale UI
    value can never break model loading.
    """
    return DTYPE_MAP.get(name, "auto")


def load_model_and_tokenizer(repo_id: str, device_map: str = "cpu", dtype_name: str = "auto"):
    """Load & cache ``(tokenizer, model)`` keyed by (repo_id, device_map, dtype).

    No ``low_cpu_mem_usage``. Prefers the newer ``dtype=`` keyword; on
    ``TypeError`` (older transformers) falls back to ``torch_dtype=`` or
    omits the argument entirely for "auto".

    Args:
        repo_id: Hugging Face repository id, e.g. "caixiaoshun/mini-moe".
        device_map: Passed straight to ``from_pretrained`` ("cpu" or "auto").
        dtype_name: One of the DTYPE_MAP keys; anything else behaves as "auto".

    Returns:
        ``(tokenizer, model)`` with the model already in ``eval()`` mode.
    """
    key = (repo_id, device_map, dtype_name)
    if key in _MODEL_CACHE:
        return _MODEL_CACHE[key]

    # trust_remote_code must be set on the tokenizer too: repos that ship
    # custom model code usually ship a custom tokenizer class as well, and
    # loading it without the flag fails. Keeps tokenizer/model loads consistent.
    tok = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)

    dtype_val = _dtype_from_name(dtype_name)
    common_kwargs = dict(trust_remote_code=True, device_map=device_map)

    try:
        # Newer transformers releases accept `dtype=` (including the string
        # "auto", which _dtype_from_name already returns for dtype_name="auto").
        model = AutoModelForCausalLM.from_pretrained(repo_id, dtype=dtype_val, **common_kwargs)
    except TypeError:
        # Older releases only know `torch_dtype=`; they do not accept the
        # string "auto" there, so for "auto" we simply omit the argument.
        if dtype_name == "auto":
            model = AutoModelForCausalLM.from_pretrained(repo_id, **common_kwargs)
        else:
            model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=dtype_val, **common_kwargs)

    model.eval()
    _MODEL_CACHE[key] = (tok, model)
    return tok, model


# ==========================================
# Chat utilities & logging helpers
# ==========================================

def messages_to_pairs(messages: List[Dict[str, str]]) -> List[Tuple[str, str]]:
    """Convert role-based messages into (user, assistant) pairs for Gradio Chatbot.

    A trailing user message without a reply yields ``(user, "")``; messages
    that do not start a user turn (e.g. an assistant-first entry) are skipped.
    """
    pairs: List[Tuple[str, str]] = []
    i = 0
    while i < len(messages):
        msg = messages[i]
        if msg.get("role") == "user":
            user = msg.get("content", "")
            if i + 1 < len(messages) and messages[i + 1].get("role") == "assistant":
                assistant = messages[i + 1].get("content", "")
                pairs.append((user, assistant))
                i += 2
            else:
                pairs.append((user, ""))
                i += 1
        else:
            # Skip unexpected assistant-first cases
            i += 1
    return pairs


def _ts() -> str:
    """Current wall-clock time as HH:MM:SS for log lines."""
    return time.strftime("%H:%M:%S")


def append_log(logs: str, msg: str) -> str:
    """Append one timestamped line to the accumulated log text."""
    line = f"[{_ts()}] {msg}\n"
    return (logs + line) if logs else line


# ==========================================
# Model state helpers (reload only when repo_id changes)
# ==========================================

def ensure_model(model_state: Dict, repo_id: str, device_map: str, dtype_name: str, logs: str):
    """Return ``(model_state, tokenizer, model, logs)``, (re)loading only on repo change.

    device_map/dtype changes alone do NOT trigger a reload by design — the
    cache key in ``load_model_and_tokenizer`` still distinguishes them, so a
    later repo switch-back picks the right variant.
    """
    ms = model_state or {"repo_id": None, "tok": None, "model": None}
    if ms.get("repo_id") != repo_id or ms.get("model") is None:
        logs = append_log(logs, f"加载模型 {repo_id}(触发:repo 变更)…")
        tok, mdl = load_model_and_tokenizer(repo_id, device_map=device_map, dtype_name=dtype_name)
        ms = {"repo_id": repo_id, "tok": tok, "model": mdl}
        logs = append_log(logs, "模型加载完成。")
    else:
        logs = append_log(logs, f"使用已加载模型 {repo_id}(缓存)")
    return ms, ms["tok"], ms["model"], logs


# ==========================================
# Predict
# ==========================================

def predict(user_text: str, messages_state: List[Dict[str, str]], repo_id: str, device_map: str,
            dtype_name: str, max_new_token: int, top_k: int, logs_state: str, model_state: Dict):
    """Generator event handler: yields (chatbot pairs, messages, logs_box, logs, model_state).

    Yields multiple times so the UI repaints: once after the user turn is
    appended, once when inference starts, then per-character while
    pseudo-streaming the reply, and finally after the assistant turn is stored.
    """
    messages_state = messages_state or []
    logs_state = logs_state or ""

    # 1) Ensure model based on repo_id only. A download/load failure used to
    #    escape the generator and crash the handler silently; surface it in
    #    the log panel instead and stop this turn cleanly.
    try:
        model_state, tokenizer, model, logs_state = ensure_model(
            model_state, repo_id, device_map, dtype_name, logs_state)
    except Exception as e:
        logs_state = append_log(logs_state, f"[加载错误] {e}")
        yield messages_to_pairs(messages_state), messages_state, logs_state, logs_state, model_state
        return

    # 2) Append user turn & repaint immediately so the question shows up
    #    before inference starts.
    messages_state.append({"role": "user", "content": user_text or ""})
    logs_state = append_log(logs_state, f"收到输入:{(user_text or '').strip()[:60]}")
    chat_display = messages_to_pairs(messages_state)
    yield chat_display, messages_state, logs_state, logs_state, model_state

    # 3) Inference
    try:
        logs_state = append_log(logs_state, f"开始推理:max_new_token={int(max_new_token)}, top_k={int(top_k)}")
        yield chat_display, messages_state, logs_state, logs_state, model_state

        try:
            # Custom repos expose `model.chat(...)`; sampling kwargs are
            # best-effort — older signatures take only (messages, tokenizer).
            output = model.chat(
                messages_state,
                tokenizer,
                max_new_token=int(max_new_token),
                top_k=int(top_k),
            )
        except TypeError:
            output = model.chat(messages_state, tokenizer)

        # Pseudo-stream the finished reply character by character so the UI
        # feels live even though `chat` returned the full text at once.
        partial = ""
        for ch in str(output):
            partial += ch
            chat_display[-1] = (chat_display[-1][0], partial)
            yield chat_display, messages_state, logs_state, logs_state, model_state

        messages_state.append({"role": "assistant", "content": str(output)})
        logs_state = append_log(logs_state, f"推理完成,输出长度 {len(str(output))} 字符。")
        yield chat_display, messages_state, logs_state, logs_state, model_state
    except Exception as e:
        # Show the failure inline in the chat AND persist it in history/logs.
        err = f"[推理错误] {e}"
        logs_state = append_log(logs_state, err)
        chat_display[-1] = (chat_display[-1][0], err)
        messages_state.append({"role": "assistant", "content": err})
        yield chat_display, messages_state, logs_state, logs_state, model_state


def clear_chat():
    """Reset chatbot pairs and role/content history."""
    return [], []  # chatbot pairs, messages_state


def clear_logs_fn():
    """Reset both the visible log textbox and the log state."""
    return "", ""  # logs_box text, logs_state


def preload_on_repo_change(repo_id: str, device_map: str, dtype_name: str, logs_state: str, model_state: Dict):
    """Preload the model and write a log entry when the repo dropdown changes."""
    logs_state = logs_state or ""
    model_state, _, _, logs_state = ensure_model(model_state, repo_id, device_map, dtype_name, logs_state)
    return logs_state, model_state


# ==========================================
# Gradio UI
# ==========================================

with gr.Blocks(title="mini-moe Chat (Gradio)") as demo:
    messages_state = gr.State([])    # role/content history
    logs_state = gr.State("")        # accumulated log text
    model_state = gr.State({"repo_id": None, "tok": None, "model": None})  # currently loaded model

    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="对话", height=520)
            with gr.Row():
                user_box = gr.Textbox(label="输入", placeholder="请输入你的问题… (Shift+Enter 换行)", lines=3)
            with gr.Row():
                send_btn = gr.Button("发送", variant="primary")
                clear_btn = gr.Button("清空对话")
        with gr.Column(scale=1):
            gr.Markdown("## ⚙️ 设置")
            repo_dd = gr.Dropdown(
                label="模型仓库 (HF repo)",
                choices=["caixiaoshun/mini-moe", "caixiaoshun/mini-llm"],
                value="caixiaoshun/mini-moe",
            )
            device_dd = gr.Dropdown(label="device_map", choices=["cpu", "auto"], value="cpu")
            dtype_dd = gr.Dropdown(label="精度 (dtype/torch_dtype)", choices=["auto", "float32", "bfloat16", "float16"], value="auto")
            max_new_num = gr.Number(label="max_new_token", value=256, precision=0)
            top_k_num = gr.Number(label="top_k", value=5, precision=0)

    with gr.Accordion("📜 日志 (展开查看)", open=False):
        logs_box = gr.Textbox(label="运行日志", lines=12, interactive=False)
        log_clear_btn = gr.Button("清空日志")

    # Events: send / submit — both routes share the same generator handler.
    send_evt_inputs = [
        user_box, messages_state, repo_dd, device_dd, dtype_dd,
        max_new_num, top_k_num, logs_state, model_state
    ]
    send_evt_outputs = [chatbot, messages_state, logs_box, logs_state, model_state]

    send_btn.click(predict, inputs=send_evt_inputs, outputs=send_evt_outputs)
    user_box.submit(predict, inputs=send_evt_inputs, outputs=send_evt_outputs)

    # Clear the input box as a second, parallel handler on the same triggers.
    def _clear_text():
        return ""

    send_btn.click(_clear_text, inputs=None, outputs=user_box)
    user_box.submit(_clear_text, inputs=None, outputs=user_box)

    # Clear chat
    clear_btn.click(clear_chat, inputs=None, outputs=[chatbot, messages_state])

    # Clear logs
    log_clear_btn.click(clear_logs_fn, inputs=None, outputs=[logs_box, logs_state])

    # Preload on repo change (only reload on repo change)
    repo_dd.change(preload_on_repo_change,
                   inputs=[repo_dd, device_dd, dtype_dd, logs_state, model_state],
                   outputs=[logs_box, model_state])


if __name__ == "__main__":
    demo.queue().launch()  # set share=True if you want a public link