studyV2 / app.py
ZhongZhiYY's picture
Create app.py
683a0e5 verified
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from peft import PeftModel
import torch
# 1. Choose the base model.
# Other causal LMs (e.g. ChatGLM, Qwen) can be substituted here.
BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
# Local adapter directory; a Hub repo id such as "your-username/your-model"
# can be used instead if the weights were pushed to the Hugging Face Hub.
LORA_WEIGHTS = "./lora-weights"

# 2. Load the tokenizer, the base model and the LoRA adapter.
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    # Half precision only when a GPU is present; on CPU fall back to float32,
    # since float16 inference there is slow or unsupported for some ops
    # depending on the installed torch build.
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    device_map="auto",
)
model = PeftModel.from_pretrained(base_model, LORA_WEIGHTS)
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    # No `device=` argument: the model has already been placed by
    # `device_map="auto"`, and transformers refuses (or ignores) an explicit
    # pipeline device for a model dispatched that way.
)
# 3. Chat function
def chat_fn(history, user_input):
    """Generate a reply to ``user_input`` and record the turn in ``history``.

    Args:
        history: List of ``(user_message, assistant_reply)`` pairs from the
            earlier turns. It is extended in place with the new turn.
            ``None`` is treated as an empty conversation.
        user_input: Text the user just submitted.

    Returns:
        A 2-tuple ``(history, history)``: the same updated list twice, one
        value per output the caller wires this function to.
    """
    if history is None:
        history = []
    # Blank input: leave the conversation untouched instead of spending a
    # generation on an empty question.
    if not user_input or not user_input.strip():
        return history, history

    # Replay earlier turns as a plain transcript and leave the final
    # assistant slot open for the model to fill in.
    turns = [f"用户: {question}\n助手: {reply}\n" for question, reply in history]
    turns.append(f"用户: {user_input}\n助手:")
    prompt = "".join(turns)

    outputs = pipe(
        prompt,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        # Ask only for the continuation, so markers that appear in the prompt
        # (or in the user's own text) cannot confuse answer extraction.
        return_full_text=False,
    )
    generated = outputs[0]["generated_text"]
    # The model may keep writing further made-up turns; keep only the text
    # that precedes the next user marker.
    answer = generated.split("用户:", 1)[0].strip()
    history.append((user_input, answer))
    return history, history
# 4. Gradio UI
# Components and event listeners must be created inside the Blocks context;
# only the launch call sits outside it.
with gr.Blocks() as demo:
    gr.Markdown("## 🤖 测试你自己的 LoRA 大模型")
    chatbot = gr.Chatbot(height=400)
    msg = gr.Textbox(label="输入你的问题")
    clear = gr.Button("清空对话")
    # Per-session conversation history as a list of (user, assistant) pairs.
    state = gr.State([])

    # On submit: generate the reply, then empty the textbox so the sent
    # message does not linger in the input field.
    msg.submit(chat_fn, [state, msg], [chatbot, state]).then(
        lambda: "", None, msg
    )
    # Reset both the visible chat and the stored history.
    clear.click(lambda: ([], []), None, [chatbot, state])

demo.launch()