MobileLLM-chat / app.py
Yangyang1205's picture
Update app.py
c0a0dd2 verified
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
# ---------------------------------------------------------
MODEL_ID = "Yangyang1205/MobileLLM"  # Hugging Face Hub repo id of the checkpoint to load
# ---------------------------------------------------------
model_loaded = False  # set True by load_model() on success; read by generate_text()
load_error_msg = ""   # human-readable failure reason when loading failed; shown to the user
def load_model():
    """Download and initialise the tokenizer and model from the Hub.

    Side effects: updates the module-level ``model_loaded`` flag and, on
    failure, records the exception text in ``load_error_msg``.

    Returns:
        ``(tokenizer, model)`` on success, or ``(None, None)`` if anything
        during loading raised.
    """
    global model_loaded, load_error_msg
    print(f"🚀 正在启动... 准备加载模型: {MODEL_ID}")
    try:
        # 1. Patch the config before loading: force tied input/output embeddings.
        cfg = AutoConfig.from_pretrained(MODEL_ID)
        cfg.tie_word_embeddings = True

        # 2. Load tokenizer and weights (slow tokenizer, .bin weights, remote code).
        tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)
        net = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            config=cfg,
            use_safetensors=False,
            trust_remote_code=True,
        )
        net = net.to("cpu")
        net.eval()
    except Exception as exc:
        # Best-effort boundary: remember why loading failed so the UI can show it.
        model_loaded = False
        load_error_msg = str(exc)
        return None, None
    model_loaded = True
    return tok, net
# Load once at import time so every Gradio callback reuses the same objects.
tokenizer, model = load_model()
# --- Core generation function ---
def generate_text(prompt, max_len, temp):
    """Continue *prompt* with the loaded causal LM.

    Args:
        prompt: input text to continue.
        max_len: maximum number of new tokens to sample.
        temp: sampling temperature (higher = more random).

    Returns:
        The generated continuation — only the first non-empty line when the
        model produced several lines — or an error message string if the
        model is not loaded / generation raised.
    """
    if not model_loaded:
        return f"模型未加载: {load_error_msg}"
    try:
        inputs = tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_len,
                do_sample=True,
                temperature=temp,
                repetition_penalty=1.2,
                pad_token_id=tokenizer.eos_token_id,
            )
        # BUG FIX: the old code decoded the full sequence and sliced the
        # string with full_response[len(prompt):]. Tokenizer round-trips are
        # not byte-exact (whitespace/normalization can shift lengths), which
        # could cut into or leak part of the prompt. Slice the *token ids*
        # past the prompt instead and decode only the newly generated tokens.
        prompt_len = inputs["input_ids"].shape[-1]
        new_text = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        # Keep only the first complete (non-empty) line to drop any
        # unfinished trailing sentence.
        if "\n" in new_text.strip():
            lines = [line for line in new_text.split("\n") if line.strip()]
            if lines:
                return lines[0]
        return new_text
    except Exception as e:
        # UI boundary: surface the error text in the output box, don't crash.
        return str(e)
# --- UI layout (Blocks) ---
# NOTE: the `theme` and `title` kwargs were deliberately removed from
# gr.Blocks() to avoid errors (presumably on the Gradio version this Space
# runs — TODO confirm).
with gr.Blocks() as demo:
    # Title area
    gr.Markdown(
        """
        # 📱 MobileLLM 80M 续写测试
        这是一个仅有 80M 参数的基座模型。它不会对话,但擅长**上下文模仿 (In-Context Learning)**。
        """
    )
    # Two-column layout
    with gr.Row():
        # Left: input area
        with gr.Column():
            input_box = gr.Textbox(
                label="输入 Prompt (提示词)",
                lines=10,
                placeholder="在这里输入排比句...",
                value="The capital of China is Beijing.\nThe capital of Japan is Tokyo.\nThe capital of Germany is Berlin.\nThe capital of France is"
            )
            # Advanced sampling parameters, collapsed by default
            with gr.Accordion("⚙️ 高级参数", open=False):
                slider_len = gr.Slider(minimum=1, maximum=100, value=20, label="生成长度", step=1)
                slider_temp = gr.Slider(minimum=0.1, maximum=1.0, value=0.6, label="温度 (创造力)", step=0.1)
            submit_btn = gr.Button("🚀 开始生成", variant="primary")
        # Right: output area
        with gr.Column():
            output_box = gr.Textbox(
                label="模型续写结果",
                lines=10,
                interactive=False
            )
    # Wire the button click to the generation function
    submit_btn.click(
        fn=generate_text,
        inputs=[input_box, slider_len, slider_temp],
        outputs=output_box
    )
    # Clickable example prompts at the bottom
    gr.Examples(
        examples=[
            ["The capital of China is Beijing.\nThe capital of Japan is Tokyo.\nThe capital of Germany is Berlin.\nThe capital of France is"],
            ["Artificial Intelligence is a field of computer science that"],
            ["def add(a, b):\n    return a + b\n\ndef multiply(a, b):"],
        ],
        inputs=input_box
    )
if __name__ == "__main__":
    # Start the Gradio server when run directly as a script.
    demo.launch()