import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# --- 🛠️ Configuration: define your fixed model list here ---
# Replace the strings below with your real Hugging Face Model Repo IDs
MODEL_LIST = [
    "telecomadm1145/mamba2_exp3",
    "telecomadm1145/mamba2_exp2",
]

# Model selected by default
DEFAULT_MODEL = MODEL_LIST[0]

# --- Global cache ---
# Structure: { "model_id": (model, tokenizer) }
MODEL_CACHE = {}


def get_device():
    """Detect the runtime device."""
    if torch.cuda.is_available():
        return "cuda"
    elif torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def load_model(model_id):
    """
    Load and cache a model.
    Switching models clears the old cache to free memory.
    """
    global MODEL_CACHE

    # Check whether a (re)load is needed
    if model_id not in MODEL_CACHE:
        # Evict the old entry (simple single-slot cache policy to keep VRAM from filling up)
        if MODEL_CACHE:
            print("Switching models, clearing the old cache...")
            MODEL_CACHE.clear()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        print(f"Loading model: {model_id} ...")
        try:
            device = get_device()

            # Load the tokenizer
            tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                trust_remote_code=True
            )

            # Load the model
            # Pick the precision automatically: float16 on GPU, float32 on CPU
            dtype = torch.float16 if device == "cuda" else torch.float32

            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                device_map=device,
                torch_dtype=dtype,
                trust_remote_code=True
            )

            MODEL_CACHE[model_id] = (model, tokenizer)
            print(f"✅ Model {model_id} loaded successfully!")
        except Exception as e:
            return None, None, f"❌ Failed to load the model: {str(e)}"

    return MODEL_CACHE[model_id][0], MODEL_CACHE[model_id][1], None


def generate_text(
    model_selector,
    prompt,
    max_new_tokens,
    temperature,
    top_p,
    top_k,
    repetition_penalty
):
    """Text generation logic."""
    if not prompt.strip():
        return "⚠️ Please enter a prompt."

    if not model_selector:
        return "⚠️ Please select a model."

    # 1. Load the model (cached)
    model, tokenizer, error_msg = load_model(model_selector)
    if error_msg:
        return error_msg

    # 2. Prepare the inputs
    device = model.device
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # 3. Handle the generation parameters
    # A temperature of 0 usually means greedy search; in transformers this is
    # expressed by setting do_sample=False
    do_sample = temperature > 1e-5

    try:
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=int(max_new_tokens),
                do_sample=do_sample,
                temperature=temperature if do_sample else 1.0,
                top_p=top_p,
                top_k=int(top_k),
                repetition_penalty=repetition_penalty,
                pad_token_id=tokenizer.eos_token_id
            )
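        # Note: model.generate returns the full sequence (prompt tokens plus the
        # newly generated tokens), so outputs[0] below still contains the prompt;
        # skip_special_tokens=True strips pad/EOS markers from the decoded string.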
        # 4. Decode
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return generated_text
    except Exception as e:
        return f"❌ An error occurred during generation: {str(e)}"


# --- Build the Gradio UI ---
with gr.Blocks(title="Mamba2 Model Lab") as demo:
    gr.Markdown("# 🐍 Mamba2 Text Generation Space")
    gr.Markdown("Pick a model from the list and run inference in your local environment.")

    with gr.Row():
        # Left-hand control column
        with gr.Column(scale=1, min_width=300):
            # --- Changed: use a dropdown menu ---
            model_selector = gr.Dropdown(
                choices=MODEL_LIST,
                value=DEFAULT_MODEL,
                label="Select Model",
                info="Choose the model to load from the preset list",
                interactive=True
            )

            with gr.Accordion("⚙️ Sampling Params", open=True):
                max_tokens = gr.Slider(minimum=1, maximum=128, value=32, step=1, label="Max New Tokens")
                temp = gr.Slider(minimum=0.0, maximum=2.0, value=0.7, step=0.1, label="Temperature")
                top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.01, label="Top-P")
                top_k_slider = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Top-K")
                rep_penalty = gr.Slider(minimum=1.0, maximum=2.0, value=1.1, step=0.05, label="Repetition Penalty")

            gr.Markdown("""
            **Note**: When you switch models, the system downloads the new model and loads it
            into memory; the first switch may take a while.
            """)

        # Right-hand input/output column
        with gr.Column(scale=2):
            input_text = gr.TextArea(
                label="Prompt",
            )

            generate_btn = gr.Button("🚀 Generate", variant="primary", size="lg")

            output_text = gr.TextArea(
                label="Output",
                interactive=False,
            )

    # Wire up the event
    generate_btn.click(
        fn=generate_text,
        inputs=[
            model_selector,
            input_text,
            max_tokens,
            temp,
            top_p_slider,
            top_k_slider,
            rep_penalty
        ],
        outputs=output_text
    )

if __name__ == "__main__":
    demo.queue().launch()
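
# --- Optional smoke test (illustrative sketch, not part of the original script) ---
# Exercises load_model/generate_text once without starting the Gradio server.
# The function name and default prompt are arbitrary choices for this sketch;
# run it with e.g. `python -c "import app; app.smoke_test()"`, where "app" is an
# assumption about the file name this script is saved under.
def smoke_test(prompt: str = "Hello, Mamba2!") -> None:
    result = generate_text(
        model_selector=DEFAULT_MODEL,
        prompt=prompt,
        max_new_tokens=16,
        temperature=0.0,  # 0 -> do_sample=False, i.e. greedy decoding
        top_p=0.9,
        top_k=50,
        repetition_penalty=1.1,
    )
    print(result)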