Hugging Face Spaces — status: Sleeping
| import gradio as gr | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig | |
# ---------------------------------------------------------
# Hugging Face Hub repository the demo loads.
MODEL_ID = "Yangyang1205/MobileLLM"
# ---------------------------------------------------------

# Load status, set by load_model(): model_loaded flips to True on
# success; otherwise load_error_msg keeps the failure text so the UI
# can surface it instead of crashing.
model_loaded = False
load_error_msg = ""
def load_model():
    """Download and prepare the model/tokenizer for CPU inference.

    Returns (tokenizer, model) on success, or (None, None) on failure.
    Instead of raising, any load error is recorded in the module-level
    flags so the UI can report it gracefully.
    """
    global model_loaded, load_error_msg
    print(f"🚀 正在启动... 准备加载模型: {MODEL_ID}")
    try:
        # 1. Force-correct the config: this checkpoint requires tied
        # input/output embeddings.
        config = AutoConfig.from_pretrained(MODEL_ID)
        config.tie_word_embeddings = True

        # 2. Load tokenizer and model weights.
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            config=config,
            use_safetensors=False,
            trust_remote_code=True,
        )
        model = model.to("cpu")
        model.eval()

        model_loaded = True
        return tokenizer, model
    except Exception as e:
        # Best-effort boundary: keep the app alive and remember why
        # loading failed so generate_text() can show the message.
        model_loaded = False
        load_error_msg = str(e)
        return None, None


tokenizer, model = load_model()
| # --- 核心生成函数 --- | |
def generate_text(prompt, max_len, temp):
    """Continue *prompt* with the loaded causal LM.

    Args:
        prompt: input text to continue.
        max_len: maximum number of new tokens to sample.
        temp: sampling temperature (higher = more random).

    Returns:
        The newly generated text (first non-empty line only when the
        continuation spans multiple lines), or an error message string
        if the model is unavailable or generation fails.
    """
    if not model_loaded:
        return f"模型未加载: {load_error_msg}"
    try:
        inputs = tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_len,
                do_sample=True,
                temperature=temp,
                repetition_penalty=1.2,
                pad_token_id=tokenizer.eos_token_id,
            )
        # FIX: decode only the generated token suffix. The previous
        # approach decoded everything and sliced by len(prompt), which
        # mis-slices whenever tokenizer decode does not reproduce the
        # prompt character-for-character (whitespace normalization,
        # special characters).
        input_len = inputs["input_ids"].shape[1]
        new_text = tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True)
        # Trim unfinished continuations: keep only the first non-empty line.
        if "\n" in new_text.strip():
            lines = [line for line in new_text.split("\n") if line.strip()]
            if lines:
                return lines[0]
        return new_text
    except Exception as e:
        # Surface the error in the output box rather than crashing the UI.
        return str(e)
# --- UI layout (gr.Blocks) ---
# NOTE: theme/title kwargs are intentionally omitted to avoid errors on
# older Gradio versions.
with gr.Blocks() as demo:
    # Header section.
    gr.Markdown(
        """
        # 📱 MobileLLM 80M 续写测试
        这是一个仅有 80M 参数的基座模型。它不会对话,但擅长**上下文模仿 (In-Context Learning)**。
        """
    )

    # Two-column layout: inputs on the left, output on the right.
    with gr.Row():
        with gr.Column():
            input_box = gr.Textbox(
                label="输入 Prompt (提示词)",
                lines=10,
                placeholder="在这里输入排比句...",
                value="The capital of China is Beijing.\nThe capital of Japan is Tokyo.\nThe capital of Germany is Berlin.\nThe capital of France is",
            )
            # Collapsible advanced sampling controls.
            with gr.Accordion("⚙️ 高级参数", open=False):
                slider_len = gr.Slider(minimum=1, maximum=100, value=20, label="生成长度", step=1)
                slider_temp = gr.Slider(minimum=0.1, maximum=1.0, value=0.6, label="温度 (创造力)", step=0.1)
            submit_btn = gr.Button("🚀 开始生成", variant="primary")

        with gr.Column():
            output_box = gr.Textbox(
                label="模型续写结果",
                lines=10,
                interactive=False,
            )

    # Wire the button to the generation function.
    submit_btn.click(
        fn=generate_text,
        inputs=[input_box, slider_len, slider_temp],
        outputs=output_box,
    )

    # Clickable example prompts.
    gr.Examples(
        examples=[
            ["The capital of China is Beijing.\nThe capital of Japan is Tokyo.\nThe capital of Germany is Berlin.\nThe capital of France is"],
            ["Artificial Intelligence is a field of computer science that"],
            ["def add(a, b):\n    return a + b\n\ndef multiply(a, b):"],
        ],
        inputs=input_box,
    )

if __name__ == "__main__":
    demo.launch()