# NOTE(review): the three lines that were here ("Spaces:", "Running", "Running")
# are Hugging Face Spaces page chrome accidentally captured in the paste;
# converted to this comment so the file parses as Python.
| import gradio as gr | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import os | |
# --- 🛠️ Configuration: fixed list of selectable models ---
# Replace the entries below with your real Hugging Face model repo IDs.
MODEL_LIST = [
    "telecomadm1145/mamba2_exp3",
    "telecomadm1145/mamba2_exp2",
]

# Model that is pre-selected in the dropdown.
DEFAULT_MODEL = MODEL_LIST[0]

# --- Global cache ---
# Layout: { "model_id": (model, tokenizer) }
MODEL_CACHE = {}
def get_device():
    """Pick the best available accelerator: CUDA, then Apple MPS, then CPU."""
    probes = (
        ("cuda", torch.cuda.is_available),
        ("mps", torch.backends.mps.is_available),
    )
    for name, is_available in probes:
        if is_available():
            return name
    return "cpu"
def load_model(model_id):
    """Load *model_id* together with its tokenizer and cache the pair.

    A single-slot cache policy is used: selecting a different model first
    evicts the previously loaded one so GPU memory is not exhausted.

    Args:
        model_id: Hugging Face model repo ID to load.

    Returns:
        ``(model, tokenizer, None)`` on success, or
        ``(None, None, error_message)`` if loading failed.
    """
    if model_id not in MODEL_CACHE:
        # Evict any previously cached model before loading a new one
        # (simple singleton cache to keep VRAM usage bounded).
        if MODEL_CACHE:
            print("正在切换模型,清理旧缓存...")
            MODEL_CACHE.clear()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        print(f"正在加载模型: {model_id} ...")
        try:
            device = get_device()
            # trust_remote_code is required for custom architectures
            # (e.g. these Mamba2 experiment repos).
            tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                trust_remote_code=True,
            )
            # Precision is picked automatically: float16 on CUDA,
            # float32 on MPS/CPU.
            dtype = torch.float16 if device == "cuda" else torch.float32
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                device_map=device,
                torch_dtype=dtype,
                trust_remote_code=True,
            )
            MODEL_CACHE[model_id] = (model, tokenizer)
            print(f"✅ 模型 {model_id} 加载成功!")
        except Exception as e:
            return None, None, f"❌ 模型加载失败: {str(e)}"
    # Single lookup instead of indexing the cache entry twice.
    model, tokenizer = MODEL_CACHE[model_id]
    return model, tokenizer, None
def generate_text(
    model_selector,
    prompt,
    max_new_tokens,
    temperature,
    top_p,
    top_k,
    repetition_penalty,
):
    """Generate a continuation of *prompt* with the selected model.

    Args:
        model_selector: model repo ID chosen in the UI dropdown.
        prompt: user prompt text.
        max_new_tokens: maximum number of tokens to generate.
        temperature: sampling temperature; 0 selects greedy decoding.
        top_p: nucleus-sampling probability mass (used only when sampling).
        top_k: top-k cutoff (used only when sampling).
        repetition_penalty: penalty applied to repeated tokens.

    Returns:
        The decoded text (prompt + continuation), or a user-facing
        error/warning message string.
    """
    if not prompt.strip():
        return "⚠️ 请输入提示词 (Prompt)。"
    if not model_selector:
        return "⚠️ 请选择一个模型。"
    # 1. Load the model (cached across calls).
    model, tokenizer, error_msg = load_model(model_selector)
    if error_msg:
        return error_msg
    # 2. Tokenize the prompt onto the model's device.
    device = model.device
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # 3. temperature ~ 0 means greedy search (do_sample=False). Sampling
    #    knobs are only passed when sampling, so transformers does not
    #    emit "flag is ignored" warnings on the greedy path.
    do_sample = temperature > 1e-5
    gen_kwargs = {
        "max_new_tokens": int(max_new_tokens),
        "do_sample": do_sample,
        "repetition_penalty": repetition_penalty,
        # Fall back to the tokenizer's pad token when no EOS is defined,
        # so generate() is not handed pad_token_id=None.
        "pad_token_id": (
            tokenizer.eos_token_id
            if tokenizer.eos_token_id is not None
            else tokenizer.pad_token_id
        ),
    }
    if do_sample:
        gen_kwargs.update(
            temperature=temperature,
            top_p=top_p,
            top_k=int(top_k),
        )
    try:
        with torch.no_grad():
            outputs = model.generate(**inputs, **gen_kwargs)
        # 4. Decode the full sequence (prompt included, as before).
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return generated_text
    except Exception as e:
        return f"❌ 生成过程中发生错误: {str(e)}"
# --- Gradio UI definition ---
with gr.Blocks(title="Mamba2 模型实验室") as demo:
    gr.Markdown("# 🐍 Mamba2 Text Generation Space")
    gr.Markdown("请从列表中选择模型,并在本地环境中进行推理测试。")

    with gr.Row():
        # Left column: model choice plus sampling controls.
        with gr.Column(scale=1, min_width=300):
            model_selector = gr.Dropdown(
                choices=MODEL_LIST,
                value=DEFAULT_MODEL,
                label="选择模型 (Select Model)",
                info="从预设列表中选择要加载的模型",
                interactive=True,
            )
            with gr.Accordion("⚙️ 采样参数 (Sampling Params)", open=True):
                max_tokens = gr.Slider(
                    minimum=1, maximum=128, value=32, step=1,
                    label="最大生成长度 (Max New Tokens)",
                )
                temp = gr.Slider(
                    minimum=0.0, maximum=2.0, value=0.7, step=0.1,
                    label="温度 (Temperature)",
                )
                top_p_slider = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.9, step=0.01,
                    label="Top-P",
                )
                top_k_slider = gr.Slider(
                    minimum=0, maximum=100, value=50, step=1,
                    label="Top-K",
                )
                rep_penalty = gr.Slider(
                    minimum=1.0, maximum=2.0, value=1.1, step=0.05,
                    label="重复惩罚 (Repetition Penalty)",
                )
            gr.Markdown("""
**注意**:
切换模型时,系统会自动下载新模型并加载到内存,初次切换可能需要一些时间。
""")

        # Right column: prompt in, generated text out.
        with gr.Column(scale=2):
            input_text = gr.TextArea(label="输入提示词 (Prompt)")
            generate_btn = gr.Button(
                "🚀 开始生成 (Generate)", variant="primary", size="lg"
            )
            output_text = gr.TextArea(label="生成结果 (Output)", interactive=False)

    # Wire the generate button to the inference callback.
    generate_btn.click(
        fn=generate_text,
        inputs=[
            model_selector,
            input_text,
            max_tokens,
            temp,
            top_p_slider,
            top_k_slider,
            rep_penalty,
        ],
        outputs=output_text,
    )

if __name__ == "__main__":
    demo.queue().launch()