import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import os

# --- 🛠️ Configuration: define your fixed model list here ---
# Replace the strings below with your real Hugging Face model repo IDs.
MODEL_LIST = [
    "telecomadm1145/mamba2_exp3",
    "telecomadm1145/mamba2_exp2",
]

# Model selected by default
DEFAULT_MODEL = MODEL_LIST[0]

# --- Global cache ---
# Structure: { "model_id": (model, tokenizer) }
MODEL_CACHE = {}
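
# Optional override (a hedged sketch, not part of the original app): the list
# above could also be supplied via an environment variable, which would give
# the otherwise-unused `os` import a purpose. The variable name EXTRA_MODELS
# is an assumption for illustration; left commented out so behavior is unchanged.
# extra = os.environ.get("EXTRA_MODELS")
# if extra:
#     MODEL_LIST += [m.strip() for m in extra.split(",") if m.strip()]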
def get_device():
    """Detect the runtime device."""
    if torch.cuda.is_available():
        return "cuda"
    elif torch.backends.mps.is_available():
        return "mps"
    return "cpu"
def load_model(model_id):
    """
    Load and cache a model.
    Switching models clears the old cache to free accelerator memory.
    """
    global MODEL_CACHE

    # Check whether a (re)load is needed
    if model_id not in MODEL_CACHE:
        # Evict the previous entry (a simple single-slot cache policy,
        # which keeps GPU memory from filling up)
        if MODEL_CACHE:
            print("Switching models, clearing the old cache...")
            MODEL_CACHE.clear()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        print(f"Loading model: {model_id} ...")
        try:
            device = get_device()

            # Load the tokenizer
            tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                trust_remote_code=True
            )

            # Load the model.
            # Pick the precision automatically: float16 on GPU, float32 on CPU.
            dtype = torch.float16 if device == "cuda" else torch.float32
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                device_map=device,
                torch_dtype=dtype,
                trust_remote_code=True
            )

            MODEL_CACHE[model_id] = (model, tokenizer)
            print(f"✅ Model {model_id} loaded successfully!")
        except Exception as e:
            return None, None, f"❌ Failed to load model: {str(e)}"

    return MODEL_CACHE[model_id][0], MODEL_CACHE[model_id][1], None
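
# A minimal warm-up sketch (an optional addition, not in the original app):
# preloading the default model at startup hides the first-request latency on
# a fresh Space. Left commented out so startup behavior stays unchanged.
# _model, _tokenizer, _err = load_model(DEFAULT_MODEL)
# if _err:
#     print(_err)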
def generate_text(
    model_selector,
    prompt,
    max_new_tokens,
    temperature,
    top_p,
    top_k,
    repetition_penalty
):
    """
    Text generation logic.
    """
    if not prompt.strip():
        return "⚠️ Please enter a prompt."
    if not model_selector:
        return "⚠️ Please select a model."

    # 1. Load the model (cached)
    model, tokenizer, error_msg = load_model(model_selector)
    if error_msg:
        return error_msg

    # 2. Prepare the inputs
    device = model.device
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # 3. Handle the generation parameters.
    # A temperature of 0 means greedy search; in transformers this is
    # expressed by setting do_sample=False rather than temperature=0.
    do_sample = temperature > 1e-5

    try:
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=int(max_new_tokens),
                do_sample=do_sample,
                temperature=temperature if do_sample else 1.0,
                top_p=top_p,
                top_k=int(top_k),
                repetition_penalty=repetition_penalty,
                pad_token_id=tokenizer.eos_token_id
            )

        # 4. Decode
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return generated_text
    except Exception as e:
        return f"❌ An error occurred during generation: {str(e)}"
# --- Gradio UI ---
with gr.Blocks(title="Mamba2 Model Lab") as demo:
    gr.Markdown("# 🐍 Mamba2 Text Generation Space")
    gr.Markdown("Select a model from the list and run inference tests locally.")

    with gr.Row():
        # Left-hand control panel
        with gr.Column(scale=1, min_width=300):
            # --- Changed: use a dropdown menu ---
            model_selector = gr.Dropdown(
                choices=MODEL_LIST,
                value=DEFAULT_MODEL,
                label="Select Model",
                info="Choose the model to load from the preset list",
                interactive=True
            )

            with gr.Accordion("⚙️ Sampling Params", open=True):
                max_tokens = gr.Slider(minimum=1, maximum=128, value=32, step=1, label="Max New Tokens")
                temp = gr.Slider(minimum=0.0, maximum=2.0, value=0.7, step=0.1, label="Temperature")
                top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.01, label="Top-P")
                top_k_slider = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Top-K")
                rep_penalty = gr.Slider(minimum=1.0, maximum=2.0, value=1.1, step=0.05, label="Repetition Penalty")

            gr.Markdown("""
            **Note**:
            When switching models, the system downloads and loads the new model
            automatically; the first switch may take a while.
            """)
        # Right-hand input/output panel
        with gr.Column(scale=2):
            input_text = gr.TextArea(
                label="Prompt",
            )
            generate_btn = gr.Button("🚀 Generate", variant="primary", size="lg")
            output_text = gr.TextArea(
                label="Output",
                interactive=False,
            )

    # Wire up the event
    generate_btn.click(
        fn=generate_text,
        inputs=[
            model_selector,
            input_text,
            max_tokens,
            temp,
            top_p_slider,
            top_k_slider,
            rep_penalty
        ],
        outputs=output_text
    )
if __name__ == "__main__":
    demo.queue().launch()
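
# Note (an optional tweak, not in the original code): for local testing outside
# Hugging Face Spaces, the app can be exposed on the LAN, e.g.
#   demo.queue().launch(server_name="0.0.0.0", server_port=7860)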