# Hugging Face Space: mini-moe / mini-llm chat demo (Gradio app)
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| import torch | |
| from typing import List, Dict, Optional, Tuple | |
| import time | |
| # ========================================== | |
| # Helper: dtype map & loader with simple cache | |
| # ========================================== | |
| DTYPE_MAP = { | |
| "auto": "auto", | |
| "float32": torch.float32, | |
| "bfloat16": torch.bfloat16, | |
| "float16": torch.float16, | |
| } | |
| _MODEL_CACHE = {} | |
| def _dtype_from_name(name: str): | |
| return DTYPE_MAP.get(name, "auto") | |
def load_model_and_tokenizer(repo_id: str, device_map: str = "cpu", dtype_name: str = "auto"):
    """Load and memoize a (tokenizer, model) pair.

    Results are cached per (repo_id, device_map, dtype_name). Newer
    transformers versions accept `dtype=`; on TypeError we retry with the
    legacy `torch_dtype=` keyword (or omit dtype entirely for "auto").
    No low_cpu_mem_usage is used.
    """
    cache_key = (repo_id, device_map, dtype_name)
    cached = _MODEL_CACHE.get(cache_key)
    if cached is not None:
        return cached

    tok = AutoTokenizer.from_pretrained(repo_id)
    resolved_dtype = _dtype_from_name(dtype_name)
    shared_kwargs = dict(trust_remote_code=True, device_map=device_map)

    try:
        # Preferred modern keyword; "auto" resolves to the literal string "auto".
        dtype_arg = "auto" if dtype_name == "auto" else resolved_dtype
        model = AutoModelForCausalLM.from_pretrained(repo_id, dtype=dtype_arg, **shared_kwargs)
    except TypeError:
        # Older transformers: fall back to torch_dtype, or no dtype for "auto".
        if dtype_name == "auto":
            model = AutoModelForCausalLM.from_pretrained(repo_id, **shared_kwargs)
        else:
            model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=resolved_dtype, **shared_kwargs)

    model.eval()
    _MODEL_CACHE[cache_key] = (tok, model)
    return tok, model
| # ========================================== | |
| # Chat utilities & logging helpers | |
| # ========================================== | |
def messages_to_pairs(messages: List[Dict[str, str]]) -> List[Tuple[str, str]]:
    """Fold role-based messages into (user, assistant) tuples for gr.Chatbot.

    A user message immediately followed by an assistant message collapses
    into one pair; an unanswered user message pairs with "". Messages that
    do not start a user turn (e.g. assistant-first) are skipped.
    """
    pairs: List[Tuple[str, str]] = []
    idx, total = 0, len(messages)
    while idx < total:
        current = messages[idx]
        if current.get("role") != "user":
            # Not the start of a user turn — skip it.
            idx += 1
            continue
        user_text = current.get("content", "")
        has_reply = idx + 1 < total and messages[idx + 1].get("role") == "assistant"
        if has_reply:
            pairs.append((user_text, messages[idx + 1].get("content", "")))
            idx += 2
        else:
            pairs.append((user_text, ""))
            idx += 1
    return pairs
| def _ts() -> str: | |
| return time.strftime("%H:%M:%S") | |
| def append_log(logs: str, msg: str) -> str: | |
| line = f"[{_ts()}] {msg}\n" | |
| return (logs + line) if logs else line | |
| # ========================================== | |
| # Model state helpers (reload only when repo_id changes) | |
| # ========================================== | |
def ensure_model(model_state: Dict, repo_id: str, device_map: str, dtype_name: str, logs: str):
    """Return (state, tokenizer, model, logs), (re)loading only when needed.

    A load is triggered when the repo id differs from the cached one or no
    model is held yet; changing device_map / dtype alone does NOT reload —
    that is this app's deliberate contract.
    """
    state = model_state if model_state else {"repo_id": None, "tok": None, "model": None}
    needs_load = state.get("repo_id") != repo_id or state.get("model") is None
    if needs_load:
        logs = append_log(logs, f"加载模型 {repo_id}(触发:repo 变更)…")
        tok, mdl = load_model_and_tokenizer(repo_id, device_map=device_map, dtype_name=dtype_name)
        state = {"repo_id": repo_id, "tok": tok, "model": mdl}
        logs = append_log(logs, "模型加载完成。")
    else:
        logs = append_log(logs, f"使用已加载模型 {repo_id}(缓存)")
    return state, state["tok"], state["model"], logs
| # ========================================== | |
| # Predict | |
| # ========================================== | |
def predict(user_text: str,
            messages_state: List[Dict[str, str]],
            repo_id: str, device_map: str, dtype_name: str,
            max_new_token: int, top_k: int,
            logs_state: str,
            model_state: Dict):
    """Generator event handler for one chat turn.

    Yields (chatbot pairs, messages_state, logs text, logs state, model state)
    after each stage so Gradio repaints incrementally: once after appending
    the user message, once when inference starts, once per output character,
    and once after the assistant message is committed (or on error).
    """
    messages_state = messages_state or []
    logs_state = logs_state or ""
    # 1) Ensure model based on repo_id only
    model_state, tokenizer, model, logs_state = ensure_model(model_state, repo_id, device_map, dtype_name, logs_state)
    # 2) Append user & paint
    messages_state.append({"role": "user", "content": user_text or ""})
    logs_state = append_log(logs_state, f"收到输入:{(user_text or '').strip()[:60]}")
    chat_display = messages_to_pairs(messages_state)
    yield chat_display, messages_state, logs_state, logs_state, model_state
    # 3) Inference
    try:
        logs_state = append_log(logs_state, f"开始推理:max_new_token={int(max_new_token)}, top_k={int(top_k)}")
        yield chat_display, messages_state, logs_state, logs_state, model_state
        try:
            # model.chat is a repo-provided method (trust_remote_code);
            # assumed signature (messages, tokenizer, **sampling) — TODO confirm.
            output = model.chat(
                messages_state,
                tokenizer,
                max_new_token=int(max_new_token),
                top_k=int(top_k),
            )
        except TypeError:
            # Fallback for chat() signatures without sampling kwargs.
            output = model.chat(messages_state, tokenizer)
        # Simulated streaming: generation is already complete; repaint the
        # last chat pair one character at a time for a typing effect.
        partial = ""
        for ch in str(output):
            partial += ch
            chat_display[-1] = (chat_display[-1][0], partial)
            yield chat_display, messages_state, logs_state, logs_state, model_state
        messages_state.append({"role": "assistant", "content": str(output)})
        logs_state = append_log(logs_state, f"推理完成,输出长度 {len(str(output))} 字符。")
        yield chat_display, messages_state, logs_state, logs_state, model_state
    except Exception as e:
        # Surface the failure in both the log panel and the chat bubble.
        err = f"[推理错误] {e}"
        logs_state = append_log(logs_state, err)
        chat_display[-1] = (chat_display[-1][0], err)
        messages_state.append({"role": "assistant", "content": err})
        yield chat_display, messages_state, logs_state, logs_state, model_state
def clear_chat():
    """Reset the visible chatbot pairs and the stored message history."""
    empty_pairs: list = []
    empty_messages: list = []
    return empty_pairs, empty_messages
def clear_logs_fn():
    """Blank out both the log textbox and the log state string."""
    return ("", "")
def preload_on_repo_change(repo_id: str, device_map: str, dtype_name: str, logs_state: str, model_state: Dict):
    """Preload the model and write a log entry when the repo selection changes."""
    current_logs = logs_state if logs_state else ""
    model_state, _tok, _mdl, current_logs = ensure_model(
        model_state, repo_id, device_map, dtype_name, current_logs
    )
    return current_logs, model_state
| # ========================================== | |
| # Gradio UI | |
| # ========================================== | |
# Build the Blocks UI: chat column on the left, settings + log panel on the
# right; event wiring follows at the bottom of the `with` block.
with gr.Blocks(title="mini-moe Chat (Gradio)") as demo:
    # Per-session state holders.
    messages_state = gr.State([])  # role/content message history
    logs_state = gr.State("")  # accumulated log text
    model_state = gr.State({"repo_id": None, "tok": None, "model": None})  # currently loaded model
    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="对话", height=520)
            with gr.Row():
                user_box = gr.Textbox(label="输入", placeholder="请输入你的问题… (Shift+Enter 换行)", lines=3)
            with gr.Row():
                send_btn = gr.Button("发送", variant="primary")
                clear_btn = gr.Button("清空对话")
        with gr.Column(scale=1):
            gr.Markdown("## ⚙️ 设置")
            repo_dd = gr.Dropdown(
                label="模型仓库 (HF repo)",
                choices=["caixiaoshun/mini-moe", "caixiaoshun/mini-llm"],
                value="caixiaoshun/mini-moe",
            )
            device_dd = gr.Dropdown(label="device_map", choices=["cpu", "auto"], value="cpu")
            dtype_dd = gr.Dropdown(label="精度 (dtype/torch_dtype)", choices=["auto", "float32", "bfloat16", "float16"], value="auto")
            max_new_num = gr.Number(label="max_new_token", value=256, precision=0)
            top_k_num = gr.Number(label="top_k", value=5, precision=0)
            with gr.Accordion("📜 日志 (展开查看)", open=False):
                logs_box = gr.Textbox(label="运行日志", lines=12, interactive=False)
                log_clear_btn = gr.Button("清空日志")
    # Events: send / submit — predict is a generator, so outputs stream.
    send_evt_inputs = [
        user_box, messages_state, repo_dd, device_dd, dtype_dd, max_new_num, top_k_num, logs_state, model_state
    ]
    send_evt_outputs = [chatbot, messages_state, logs_box, logs_state, model_state]
    send_btn.click(predict, inputs=send_evt_inputs, outputs=send_evt_outputs)
    user_box.submit(predict, inputs=send_evt_inputs, outputs=send_evt_outputs)
    # Clear input after send (registered as a second handler on the same events).
    def _clear_text():
        """Return an empty string to blank the input textbox."""
        return ""
    send_btn.click(_clear_text, inputs=None, outputs=user_box)
    user_box.submit(_clear_text, inputs=None, outputs=user_box)
    # Clear chat
    clear_btn.click(clear_chat, inputs=None, outputs=[chatbot, messages_state])
    # Clear logs
    log_clear_btn.click(clear_logs_fn, inputs=None, outputs=[logs_box, logs_state])
    # Preload on repo change (only reload on repo change)
    repo_dd.change(preload_on_repo_change,
                   inputs=[repo_dd, device_dd, dtype_dd, logs_state, model_state],
                   outputs=[logs_box, model_state])

if __name__ == "__main__":
    demo.queue().launch()  # set share=True if you want a public link