File size: 6,156 Bytes
5cf6351
68aa6fa
b824726
 
 
 
68aa6fa
d0ac527
b824726
d0ac527
 
 
b824726
 
 
 
 
 
 
d0ac527
 
 
 
b824726
 
 
d0ac527
b824726
 
 
d0ac527
b824726
d0ac527
b824726
d0ac527
b824726
 
 
d0ac527
b824726
 
d0ac527
b824726
 
 
5cf6351
b824726
 
d0ac527
 
 
b824726
d0ac527
 
 
b824726
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d0ac527
b824726
d0ac527
5cf6351
b824726
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5cf6351
68aa6fa
d0ac527
 
 
 
 
 
 
 
b824726
 
 
 
d0ac527
b824726
 
d0ac527
 
b824726
 
 
d0ac527
b824726
 
 
 
 
 
d0ac527
 
 
b824726
d0ac527
b824726
 
 
d0ac527
 
 
 
b824726
d0ac527
b824726
d0ac527
 
b824726
 
d0ac527
 
b824726
d0ac527
b824726
 
d0ac527
b824726
d0ac527
b824726
d0ac527
 
5cf6351
d0ac527
 
68aa6fa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import gradio as gr
import torch
# 导入 vllm 相关的库
from vllm import LLM, SamplingParams
# 保持 AutoTokenizer 用于处理聊天模板
from transformers import AutoTokenizer
from modelscope.hub.snapshot_download import snapshot_download
import os
import traceback

# --- 关键配置 ---

# !!! (重要) 请在此处填入您在 ModelScope 上的 "已合并" 模型的 ID
# vLLM 将直接加载这个完整的模型
#
# 我使用了您原始的基础模型 ID 作为示例,但您*必须*将其替换为
# 已经合并了 LoRA (UniVectorSQL) 的那个模型的 ID。
MERGED_MODEL_ID = "zrwang/UniVectorSQL-7B-LoRA-merged" 
# 示例:如果您的合并后模型叫 "risemds/UniVectorSQL-7B-Merged",请修改上面一行

# -----------------


# --- vLLM 模型加载 ---

# 1. 下载模型
# 我们将忽略 .md 文件和 training_args.bin 文件
print(f"开始下载已合并的模型: {MERGED_MODEL_ID}")
model_dir = snapshot_download(
    MERGED_MODEL_ID,
    revision='master',
    ignore_patterns=["*.md", "training_args.bin", "checkpoint-*"]  
)
print(f"模型下载完成,路径: {model_dir}")


# 2. 加载 Tokenizer
# 我们需要 Tokenizer 来应用聊天模板
try:
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    print("从模型目录加载 Tokenizer 成功。")
except Exception as e:
    print(f"从模型目录加载 Tokenizer 失败: {e}")
    # 如果失败,程序无法继续,因为 vLLM 需要知道 EOT token
    raise e

# --- (保留) 修复 Tokenizer 的 pad_token_id ---
# Qwen1.5 基础模型可能没有设置 pad_token_id
if tokenizer.pad_token_id is None:
    print("Tokenizer 未设置 pad_token_id,将其设置为 eos_token_id。")
    tokenizer.pad_token_id = tokenizer.eos_token_id
# (注意:我们不再需要修复 model.config,因为 vLLM 会处理)
# ----------------------------------------


# 3. 加载 vLLM 模型
print(f"开始加载 vLLM 模型: {model_dir}")

# 自动检测 GPU 数量以设置 tensor_parallel_size
if torch.cuda.is_available():
    gpu_count = torch.cuda.device_count()
    print(f"检测到 {gpu_count} 个 CUDA GPU。")
else:
    print("警告: 未检测到 CUDA GPU。vLLM 强烈建议在 GPU 上运行。")
    gpu_count = 1 # 假设至少为 1,vLLM 0.4.0+ 也支持 CPU (但很慢)

llm = LLM(
    model=model_dir,
    trust_remote_code=True,
    tensor_parallel_size=gpu_count, # 自动使用所有可用的 GPU
    dtype="auto"                    # vLLM 会自动选择 (例如 bfloat16 或 float16)
    # max_model_len=4096            # (可选) 如果需要,设置最大上下文长度
)
print("vLLM 模型加载完成。")


# 4. 定义推理函数 (*** 已修改为使用 vLLM ***)

# --- 提前准备 EOT token 和 SamplingParams ---
# 1. 获取 EOT (End of Text) token IDs (用于停止生成)
# Qwen 系列使用 <|im_end|> (151645) 和/或 <|endoftext|> (151643)
eot_ids = [tokenizer.eos_token_id]

im_end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
if im_end_token_id != tokenizer.unk_token_id and im_end_token_id not in eot_ids:
    eot_ids.append(im_end_token_id)

print(f"vLLM 将使用 stop_token_ids: {eot_ids} (eos_token_id: {tokenizer.eos_token_id})")

# 2. 创建 vLLM SamplingParams
# (原 HFace generate 参数:max_new_tokens=2048, eos_token_id=eot_ids)
sampling_params = SamplingParams(
    max_tokens=2048,        # 对应 HFace 的 max_new_tokens
    stop_token_ids=eot_ids, # <--- 关键:告诉 vLLM 何时停止
    temperature=0.0,        # 对于 SQL 生成,使用贪婪采样 (0.0) 通常是最好的
    top_p=1.0,              # (配合 temperature=0.0)
)
# ----------------------------------------

def inference(text_input):
    print(f"收到输入: {text_input}")
    try:
        # 1. 构建 Qwen (ChatML) 对话格式
        messages = [
            {"role": "user", "content": text_input}
        ]
        
        # 2. 应用对话模板
        # vLLM 需要的是 *字符串*,而不是 token IDs
        # tokenize=False 会返回格式化后的字符串
        # add_generation_prompt=True 会在末尾添加 <|im_start|>assistant\n
        prompt_str = tokenizer.apply_chat_template(
            messages,
            tokenize=False,             # <--- 关键:返回字符串
            add_generation_prompt=True  # <--- 关键:添加 assistant 提示
        )
        
        # 3. vLLM 生成
        # llm.generate 接受一个 prompt 列表
        outputs = llm.generate([prompt_str], sampling_params)
        
        # 4. 提取结果
        # outputs 是一个列表,对应输入的 prompt 列表
        # outputs[0] 是第一个 prompt 的 RequestOutput
        # outputs[0].outputs[0] 是第一个 (best_of=1) 的生成结果
        # .text 包含 *新生成* 的文本 (不包含 prompt)
        result = outputs[0].outputs[0].text.strip()

        print(f"生成结果 (仅新内容): {result}")
        return result
        
    except Exception as e:
        print(f"vLLM 推理时出错: {e}")
        # 打印更详细的 vLLM 错误
        traceback.print_exc()
        return f"错误: {e}"

# ----------------------------------------------------

# (示例保持不变)
example = """You are a senior SQL engineer. Your task is to generate a single, correct, and executable SQL query to answer the user's question based on the provided database context.
... (示例内容和之前一样) ...
Let's think step by step!
"""
# (为了简洁,省略了示例的完整内容,使用您原始的 'example' 变量即可)
examples = [example] 


# 6. 创建 Gradio 界面 (保持不变)
iface = gr.Interface(
    fn=inference,  
    inputs=gr.Textbox(lines=10, label="输入查询 (Input Query)"),  
    outputs=gr.Textbox(lines=10, label="模型输出 (Model Output)"),
    title="UniVectorSQL-7B (vLLM 推理)",
    # description="这是一个 Text-to-SQL 模型。请输入您的问题 (Question) 和数据库模式 (Schema)。点击下方示例尝试。",
    examples=examples 
)
# ---------------------

print("启动 Gradio 界面...")
# 在 Hugging Face 或 ModelScope Space 中,share=True 不是必需的
iface.launch()