import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from peft import PeftModel
import torch

# 1. Choose the base model
BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.3"  # you can swap in chatglm, qwen, etc.
LORA_WEIGHTS = "./lora-weights"  # if you pushed the weights to the HF Hub, use "your-username/your-model"
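# Note (an aside, not in the original app): if the base model repo is gated on
# the Hub (as mistralai repos can be), each from_pretrained() call needs an
# access token, e.g. token=os.environ["HF_TOKEN"] with the token stored as a
# Space secret -- a missing token is another common cause of Space build errors.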
# 2. Load the base model and apply the LoRA adapter
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,  # float16 only makes sense on GPU
    device_map="auto"
)
model = PeftModel.from_pretrained(base_model, LORA_WEIGHTS)
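# Optional (a sketch, not required): merge the adapter into the base weights so
# inference skips the LoRA indirection; peft's merge_and_unload() returns a
# plain transformers model. Uncomment if you want the merged variant:
# model = model.merge_and_unload()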
# Do NOT pass `device=` here: the model was loaded with device_map="auto"
# (accelerate), and transformers refuses to move such a model to a specific
# device -- the device=0 argument in the original code is what triggers the
# Space's runtime error.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer
)
# 3. Chat function
def chat_fn(history, user_input):
    # Rebuild the running dialogue as a plain-text prompt
    prompt = ""
    for user_msg, bot_msg in history:
        prompt += f"User: {user_msg}\nAssistant: {bot_msg}\n"
    prompt += f"User: {user_input}\nAssistant:"
    outputs = pipe(prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_p=0.9)
    # generated_text echoes the prompt, so keep only what follows the last "Assistant:"
    answer = outputs[0]["generated_text"].split("Assistant:")[-1].strip()
    history.append((user_input, answer))
    return history, history
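# Optional guard (a sketch, not in the original app): the model may keep
# generating a fake next turn. Asking the pipeline for only the continuation
# (return_full_text=False) and cutting the reply at the next "User:" marker
# stops it from writing both sides of the dialogue:
# outputs = pipe(prompt, max_new_tokens=256, return_full_text=False)
# answer = outputs[0]["generated_text"].split("User:")[0].strip()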
# 4. Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## 🤖 Test your own LoRA model")
    chatbot = gr.Chatbot(height=400)
    msg = gr.Textbox(label="Enter your question")
    clear = gr.Button("Clear conversation")
    state = gr.State([])

    msg.submit(chat_fn, [state, msg], [chatbot, state])
    clear.click(lambda: ([], []), None, [chatbot, state])

demo.launch()
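For the Space to build, the dependencies above also need a requirements.txt next to app.py. A minimal sketch (unpinned; pin versions to match what you trained with):

gradio
torch
transformers
peft
accelerate  # needed for device_map="auto"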