Text2VectorSQL / app.py
dongwenyao's picture
vllm inference
b824726
import gradio as gr
import torch
# 导入 vllm 相关的库
from vllm import LLM, SamplingParams
# 保持 AutoTokenizer 用于处理聊天模板
from transformers import AutoTokenizer
from modelscope.hub.snapshot_download import snapshot_download
import os
import traceback
# --- 关键配置 ---
# !!! (重要) 请在此处填入您在 ModelScope 上的 "已合并" 模型的 ID
# vLLM 将直接加载这个完整的模型
#
# 我使用了您原始的基础模型 ID 作为示例,但您*必须*将其替换为
# 已经合并了 LoRA (UniVectorSQL) 的那个模型的 ID。
MERGED_MODEL_ID = "zrwang/UniVectorSQL-7B-LoRA-merged"
# 示例:如果您的合并后模型叫 "risemds/UniVectorSQL-7B-Merged",请修改上面一行
# -----------------
# --- vLLM 模型加载 ---
# 1. 下载模型
# 我们将忽略 .md 文件和 training_args.bin 文件
print(f"开始下载已合并的模型: {MERGED_MODEL_ID}")
model_dir = snapshot_download(
MERGED_MODEL_ID,
revision='master',
ignore_patterns=["*.md", "training_args.bin", "checkpoint-*"]
)
print(f"模型下载完成,路径: {model_dir}")
# 2. 加载 Tokenizer
# 我们需要 Tokenizer 来应用聊天模板
try:
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
print("从模型目录加载 Tokenizer 成功。")
except Exception as e:
print(f"从模型目录加载 Tokenizer 失败: {e}")
# 如果失败,程序无法继续,因为 vLLM 需要知道 EOT token
raise e
# --- (保留) 修复 Tokenizer 的 pad_token_id ---
# Qwen1.5 基础模型可能没有设置 pad_token_id
if tokenizer.pad_token_id is None:
print("Tokenizer 未设置 pad_token_id,将其设置为 eos_token_id。")
tokenizer.pad_token_id = tokenizer.eos_token_id
# (注意:我们不再需要修复 model.config,因为 vLLM 会处理)
# ----------------------------------------
# 3. 加载 vLLM 模型
print(f"开始加载 vLLM 模型: {model_dir}")
# 自动检测 GPU 数量以设置 tensor_parallel_size
if torch.cuda.is_available():
gpu_count = torch.cuda.device_count()
print(f"检测到 {gpu_count} 个 CUDA GPU。")
else:
print("警告: 未检测到 CUDA GPU。vLLM 强烈建议在 GPU 上运行。")
gpu_count = 1 # 假设至少为 1,vLLM 0.4.0+ 也支持 CPU (但很慢)
llm = LLM(
model=model_dir,
trust_remote_code=True,
tensor_parallel_size=gpu_count, # 自动使用所有可用的 GPU
dtype="auto" # vLLM 会自动选择 (例如 bfloat16 或 float16)
# max_model_len=4096 # (可选) 如果需要,设置最大上下文长度
)
print("vLLM 模型加载完成。")
# 4. 定义推理函数 (*** 已修改为使用 vLLM ***)
# --- 提前准备 EOT token 和 SamplingParams ---
# 1. 获取 EOT (End of Text) token IDs (用于停止生成)
# Qwen 系列使用 <|im_end|> (151645) 和/或 <|endoftext|> (151643)
eot_ids = [tokenizer.eos_token_id]
im_end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
if im_end_token_id != tokenizer.unk_token_id and im_end_token_id not in eot_ids:
eot_ids.append(im_end_token_id)
print(f"vLLM 将使用 stop_token_ids: {eot_ids} (eos_token_id: {tokenizer.eos_token_id})")
# 2. 创建 vLLM SamplingParams
# (原 HFace generate 参数:max_new_tokens=2048, eos_token_id=eot_ids)
sampling_params = SamplingParams(
max_tokens=2048, # 对应 HFace 的 max_new_tokens
stop_token_ids=eot_ids, # <--- 关键:告诉 vLLM 何时停止
temperature=0.0, # 对于 SQL 生成,使用贪婪采样 (0.0) 通常是最好的
top_p=1.0, # (配合 temperature=0.0)
)
# ----------------------------------------
def inference(text_input):
print(f"收到输入: {text_input}")
try:
# 1. 构建 Qwen (ChatML) 对话格式
messages = [
{"role": "user", "content": text_input}
]
# 2. 应用对话模板
# vLLM 需要的是 *字符串*,而不是 token IDs
# tokenize=False 会返回格式化后的字符串
# add_generation_prompt=True 会在末尾添加 <|im_start|>assistant\n
prompt_str = tokenizer.apply_chat_template(
messages,
tokenize=False, # <--- 关键:返回字符串
add_generation_prompt=True # <--- 关键:添加 assistant 提示
)
# 3. vLLM 生成
# llm.generate 接受一个 prompt 列表
outputs = llm.generate([prompt_str], sampling_params)
# 4. 提取结果
# outputs 是一个列表,对应输入的 prompt 列表
# outputs[0] 是第一个 prompt 的 RequestOutput
# outputs[0].outputs[0] 是第一个 (best_of=1) 的生成结果
# .text 包含 *新生成* 的文本 (不包含 prompt)
result = outputs[0].outputs[0].text.strip()
print(f"生成结果 (仅新内容): {result}")
return result
except Exception as e:
print(f"vLLM 推理时出错: {e}")
# 打印更详细的 vLLM 错误
traceback.print_exc()
return f"错误: {e}"
# ----------------------------------------------------
# (示例保持不变)
example = """You are a senior SQL engineer. Your task is to generate a single, correct, and executable SQL query to answer the user's question based on the provided database context.
... (示例内容和之前一样) ...
Let's think step by step!
"""
# (为了简洁,省略了示例的完整内容,使用您原始的 'example' 变量即可)
examples = [example]
# 6. 创建 Gradio 界面 (保持不变)
iface = gr.Interface(
fn=inference,
inputs=gr.Textbox(lines=10, label="输入查询 (Input Query)"),
outputs=gr.Textbox(lines=10, label="模型输出 (Model Output)"),
title="UniVectorSQL-7B (vLLM 推理)",
# description="这是一个 Text-to-SQL 模型。请输入您的问题 (Question) 和数据库模式 (Schema)。点击下方示例尝试。",
examples=examples
)
# ---------------------
print("启动 Gradio 界面...")
# 在 Hugging Face 或 ModelScope Space 中,share=True 不是必需的
iface.launch()