# Text2VectorSQL / app.py
# (Hugging Face Space file header: initial commit 68aa6fa by zr-wang, 1.5 kB)
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from modelscope.hub.snapshot_download import snapshot_download
# --- Model configuration ------------------------------------------------
# Base model that the LoRA adapter was trained on.
BASE_MODEL_ID = "seeklhy/OmniSQL-7B"
# LoRA adapter weights hosted on ModelScope.
LORA_MODEL_ID = "risemds/UniVectorSQL-7B-LoRA-all_steps_1030_new_data"

# Fetch the LoRA adapter from ModelScope.
# (A private repo would first require:
#  from modelscope.hub.api import HubApi; HubApi().login('YOUR_TOKEN'))
adapter_dir = snapshot_download(LORA_MODEL_ID, revision='master')  # make sure the revision is correct

# Load the base model in fp16, letting HF place it on the available
# device(s) automatically, plus its matching tokenizer.
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)

# Attach the LoRA adapter on top of the base weights;
# PeftModel.from_pretrained loads the adapter into the base model.
model = PeftModel.from_pretrained(model, adapter_dir)
# (Optional) merge the adapter into the base weights for faster inference:
# model = model.merge_and_unload()
model.eval()
# 6. 定义推理函数
def inference(text_input):
    """Generate a completion for *text_input* with the LoRA-adapted model.

    Returns the decoded generation (note: includes the prompt itself,
    since the full output sequence is decoded).
    """
    # Use the model's actual device instead of hard-coding "cuda":
    # with device_map="auto" the model may land on CPU (e.g. on a
    # CPU-only Space), where .to("cuda") would raise at runtime.
    inputs = tokenizer(text_input, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=100)
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return result
# 7. 创建 Gradio 界面
# Expose the inference function through a minimal text-in/text-out web UI.
iface = gr.Interface(
    fn=inference,
    inputs="text",
    outputs="text",
)
iface.launch()