import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

# ---------------------------------------------------------
MODEL_ID = "Yangyang1205/MobileLLM"
# ---------------------------------------------------------

# Module-level load status; generate_text reads these to report failures
# instead of crashing the UI when the model could not be loaded.
model_loaded = False
load_error_msg = ""


def load_model():
    """Load the tokenizer and model for MODEL_ID onto the CPU.

    Returns:
        (tokenizer, model) on success, (None, None) on failure.
        Sets the module-level model_loaded / load_error_msg flags either way.
    """
    global model_loaded, load_error_msg
    print(f"🚀 正在启动... 准备加载模型: {MODEL_ID}")
    try:
        # 1. Force-correct the config: this checkpoint requires tied
        #    input/output embeddings, which the shipped config does not set.
        config = AutoConfig.from_pretrained(MODEL_ID)
        config.tie_word_embeddings = True

        # 2. Load tokenizer (slow implementation) and model (.bin weights,
        #    custom modeling code allowed via trust_remote_code).
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            config=config,
            use_safetensors=False,
            trust_remote_code=True,
        )
        model = model.to("cpu")
        model.eval()
        model_loaded = True
        return tokenizer, model
    except Exception as e:
        # Remember the error text so the UI can surface it on each request.
        model_loaded = False
        load_error_msg = str(e)
        return None, None


tokenizer, model = load_model()


# --- Core generation function ---
def generate_text(prompt, max_len, temp):
    """Generate a continuation of `prompt`, returning only its first line.

    Args:
        prompt: the text to continue.
        max_len: maximum number of NEW tokens to generate.
        temp: sampling temperature (creativity).

    Returns:
        The first non-empty generated line (to drop unfinished sentences),
        or an error message string if the model failed to load or
        generation raised.
    """
    if not model_loaded:
        return f"模型未加载: {load_error_msg}"
    try:
        inputs = tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_len,
                do_sample=True,
                temperature=temp,
                repetition_penalty=1.2,
                pad_token_id=tokenizer.eos_token_id,
            )

        # FIX: decode only the newly generated token ids. The previous
        # approach decoded the full sequence and sliced the string by
        # len(prompt); detokenization does not always reproduce the prompt
        # byte-for-byte, so that slice could truncate mid-word or leave
        # prompt residue in the output.
        prompt_token_count = inputs["input_ids"].shape[1]
        new_text = tokenizer.decode(
            outputs[0][prompt_token_count:], skip_special_tokens=True
        )

        # Clean-up: keep only the first complete line so an unfinished
        # trailing sentence is discarded.
        if "\n" in new_text.strip():
            lines = [line for line in new_text.split('\n') if line.strip()]
            if lines:
                return lines[0]

        return new_text
    except Exception as e:
        return str(e)


# --- ✨ UI layout (Blocks) ✨ ---
# NOTE: theme/title kwargs are intentionally omitted from gr.Blocks() to
# avoid errors on Gradio versions that do not accept them.
with gr.Blocks() as demo:
    # Title area
    gr.Markdown(
        """
        # 📱 MobileLLM 80M 续写测试
        这是一个仅有 80M 参数的基座模型。它不会对话,但擅长**上下文模仿 (In-Context Learning)**。
        """
    )

    # Two-column layout
    with gr.Row():
        # Left column: input area
        with gr.Column():
            input_box = gr.Textbox(
                label="输入 Prompt (提示词)",
                lines=10,
                placeholder="在这里输入排比句...",
                value="The capital of China is Beijing.\nThe capital of Japan is Tokyo.\nThe capital of Germany is Berlin.\nThe capital of France is",
            )
            # Advanced parameters, collapsed by default
            with gr.Accordion("⚙️ 高级参数", open=False):
                slider_len = gr.Slider(
                    minimum=1, maximum=100, value=20, label="生成长度", step=1
                )
                slider_temp = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.6, label="温度 (创造力)", step=0.1
                )
            submit_btn = gr.Button("🚀 开始生成", variant="primary")

        # Right column: output area
        with gr.Column():
            output_box = gr.Textbox(
                label="模型续写结果",
                lines=10,
                interactive=False,
            )

    # Wire the button to the generation function
    submit_btn.click(
        fn=generate_text,
        inputs=[input_box, slider_len, slider_temp],
        outputs=output_box,
    )

    # Footer examples
    gr.Examples(
        examples=[
            ["The capital of China is Beijing.\nThe capital of Japan is Tokyo.\nThe capital of Germany is Berlin.\nThe capital of France is"],
            ["Artificial Intelligence is a field of computer science that"],
            ["def add(a, b):\n    return a + b\n\ndef multiply(a, b):"],
        ],
        inputs=input_box,
    )

if __name__ == "__main__":
    demo.launch()