Spaces:
Runtime error
Runtime error
File size: 6,156 Bytes
5cf6351 68aa6fa b824726 68aa6fa d0ac527 b824726 d0ac527 b824726 d0ac527 b824726 d0ac527 b824726 d0ac527 b824726 d0ac527 b824726 d0ac527 b824726 d0ac527 b824726 d0ac527 b824726 5cf6351 b824726 d0ac527 b824726 d0ac527 b824726 d0ac527 b824726 d0ac527 5cf6351 b824726 5cf6351 68aa6fa d0ac527 b824726 d0ac527 b824726 d0ac527 b824726 d0ac527 b824726 d0ac527 b824726 d0ac527 b824726 d0ac527 b824726 d0ac527 b824726 d0ac527 b824726 d0ac527 b824726 d0ac527 b824726 d0ac527 b824726 d0ac527 b824726 d0ac527 5cf6351 d0ac527 68aa6fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
import gradio as gr
import torch
# 导入 vllm 相关的库
from vllm import LLM, SamplingParams
# 保持 AutoTokenizer 用于处理聊天模板
from transformers import AutoTokenizer
from modelscope.hub.snapshot_download import snapshot_download
import os
import traceback
# --- 关键配置 ---
# !!! (重要) 请在此处填入您在 ModelScope 上的 "已合并" 模型的 ID
# vLLM 将直接加载这个完整的模型
#
# 我使用了您原始的基础模型 ID 作为示例,但您*必须*将其替换为
# 已经合并了 LoRA (UniVectorSQL) 的那个模型的 ID。
MERGED_MODEL_ID = "zrwang/UniVectorSQL-7B-LoRA-merged"
# 示例:如果您的合并后模型叫 "risemds/UniVectorSQL-7B-Merged",请修改上面一行
# -----------------
# --- vLLM 模型加载 ---
# 1. 下载模型
# 我们将忽略 .md 文件和 training_args.bin 文件
print(f"开始下载已合并的模型: {MERGED_MODEL_ID}")
model_dir = snapshot_download(
MERGED_MODEL_ID,
revision='master',
ignore_patterns=["*.md", "training_args.bin", "checkpoint-*"]
)
print(f"模型下载完成,路径: {model_dir}")
# 2. 加载 Tokenizer
# 我们需要 Tokenizer 来应用聊天模板
try:
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
print("从模型目录加载 Tokenizer 成功。")
except Exception as e:
print(f"从模型目录加载 Tokenizer 失败: {e}")
# 如果失败,程序无法继续,因为 vLLM 需要知道 EOT token
raise e
# --- (保留) 修复 Tokenizer 的 pad_token_id ---
# Qwen1.5 基础模型可能没有设置 pad_token_id
if tokenizer.pad_token_id is None:
print("Tokenizer 未设置 pad_token_id,将其设置为 eos_token_id。")
tokenizer.pad_token_id = tokenizer.eos_token_id
# (注意:我们不再需要修复 model.config,因为 vLLM 会处理)
# ----------------------------------------
# 3. 加载 vLLM 模型
print(f"开始加载 vLLM 模型: {model_dir}")
# 自动检测 GPU 数量以设置 tensor_parallel_size
if torch.cuda.is_available():
gpu_count = torch.cuda.device_count()
print(f"检测到 {gpu_count} 个 CUDA GPU。")
else:
print("警告: 未检测到 CUDA GPU。vLLM 强烈建议在 GPU 上运行。")
gpu_count = 1 # 假设至少为 1,vLLM 0.4.0+ 也支持 CPU (但很慢)
llm = LLM(
model=model_dir,
trust_remote_code=True,
tensor_parallel_size=gpu_count, # 自动使用所有可用的 GPU
dtype="auto" # vLLM 会自动选择 (例如 bfloat16 或 float16)
# max_model_len=4096 # (可选) 如果需要,设置最大上下文长度
)
print("vLLM 模型加载完成。")
# 4. 定义推理函数 (*** 已修改为使用 vLLM ***)
# --- 提前准备 EOT token 和 SamplingParams ---
# 1. 获取 EOT (End of Text) token IDs (用于停止生成)
# Qwen 系列使用 <|im_end|> (151645) 和/或 <|endoftext|> (151643)
eot_ids = [tokenizer.eos_token_id]
im_end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
if im_end_token_id != tokenizer.unk_token_id and im_end_token_id not in eot_ids:
eot_ids.append(im_end_token_id)
print(f"vLLM 将使用 stop_token_ids: {eot_ids} (eos_token_id: {tokenizer.eos_token_id})")
# 2. 创建 vLLM SamplingParams
# (原 HFace generate 参数:max_new_tokens=2048, eos_token_id=eot_ids)
sampling_params = SamplingParams(
max_tokens=2048, # 对应 HFace 的 max_new_tokens
stop_token_ids=eot_ids, # <--- 关键:告诉 vLLM 何时停止
temperature=0.0, # 对于 SQL 生成,使用贪婪采样 (0.0) 通常是最好的
top_p=1.0, # (配合 temperature=0.0)
)
# ----------------------------------------
def inference(text_input):
print(f"收到输入: {text_input}")
try:
# 1. 构建 Qwen (ChatML) 对话格式
messages = [
{"role": "user", "content": text_input}
]
# 2. 应用对话模板
# vLLM 需要的是 *字符串*,而不是 token IDs
# tokenize=False 会返回格式化后的字符串
# add_generation_prompt=True 会在末尾添加 <|im_start|>assistant\n
prompt_str = tokenizer.apply_chat_template(
messages,
tokenize=False, # <--- 关键:返回字符串
add_generation_prompt=True # <--- 关键:添加 assistant 提示
)
# 3. vLLM 生成
# llm.generate 接受一个 prompt 列表
outputs = llm.generate([prompt_str], sampling_params)
# 4. 提取结果
# outputs 是一个列表,对应输入的 prompt 列表
# outputs[0] 是第一个 prompt 的 RequestOutput
# outputs[0].outputs[0] 是第一个 (best_of=1) 的生成结果
# .text 包含 *新生成* 的文本 (不包含 prompt)
result = outputs[0].outputs[0].text.strip()
print(f"生成结果 (仅新内容): {result}")
return result
except Exception as e:
print(f"vLLM 推理时出错: {e}")
# 打印更详细的 vLLM 错误
traceback.print_exc()
return f"错误: {e}"
# ----------------------------------------------------
# (示例保持不变)
example = """You are a senior SQL engineer. Your task is to generate a single, correct, and executable SQL query to answer the user's question based on the provided database context.
... (示例内容和之前一样) ...
Let's think step by step!
"""
# (为了简洁,省略了示例的完整内容,使用您原始的 'example' 变量即可)
examples = [example]
# 6. 创建 Gradio 界面 (保持不变)
iface = gr.Interface(
fn=inference,
inputs=gr.Textbox(lines=10, label="输入查询 (Input Query)"),
outputs=gr.Textbox(lines=10, label="模型输出 (Model Output)"),
title="UniVectorSQL-7B (vLLM 推理)",
# description="这是一个 Text-to-SQL 模型。请输入您的问题 (Question) 和数据库模式 (Schema)。点击下方示例尝试。",
examples=examples
)
# ---------------------
print("启动 Gradio 界面...")
# 在 Hugging Face 或 ModelScope Space 中,share=True 不是必需的
iface.launch() |