dongwenyao commited on
Commit
b824726
·
1 Parent(s): d0ac527

vllm inference

Browse files
Files changed (2) hide show
  1. app.py +99 -303
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,363 +1,159 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import AutoModelForCausalLM, AutoTokenizer
4
- from peft import PeftModel
 
 
5
  from modelscope.hub.snapshot_download import snapshot_download
6
  import os
7
-
8
- # --- 自动检测设备 ---
9
- # 检查是否有可用的 GPU (NVIDIA CUDA)
10
- if torch.cuda.is_available():
11
- device = "cuda"
12
- torch_dtype = torch.float16 # GPU 使用 float16
13
- device_map = "auto" # 自动分配到 GPU
14
- print("检测到 CUDA (GPU),将使用 GPU。")
15
- # 检查是否有可用的 Apple Silicon (MPS)
16
- elif torch.backends.mps.is_available():
17
- device = "mps"
18
- torch_dtype = torch.float16 # Apple Silicon GPU 也可以用 float16
19
- device_map = "mps" # 明确指定 mps
20
- print("检测到 MPS (Apple Silicon),将使用 MPS。")
21
- # 否则回退到 CPU
22
- else:
23
- device = "cpu"
24
- torch_dtype = torch.float32 # CPU 使用 float32 (float16 在 CPU 上很慢或不支持)
25
- device_map = "cpu" # 明确指定 cpu
26
- print("未检测到 GPU,将使用 CPU。")
27
- # ---------------------
28
-
29
 
30
  # --- 关键配置 ---
31
 
32
- # 1. 您的 LoRA 模型 ID
33
- # LORA_MODEL_ID = "/mnt/DataFlow/ydw/model/risemds/UniVectorSQL-7B-LoRA-all_steps_1030_new_data/"
34
- # lora_model_dir = LORA_MODEL_ID
35
- LORA_MODEL_ID = "risemds/UniVectorSQL-7B-LoRA-all_steps_1030_new_data"
36
-
37
- # 2. 基础模型 ID (!!!)
38
- # 您必须找到这个 LoRA 对应的基础模型是什么。
39
- # BASE_MODEL_ID = "/mnt/DataFlow/ydw/model/seeklhy/OmniSQL-7B/"
40
- BASE_MODEL_ID = "seeklhy/OmniSQL-7B"
41
 
42
  # -----------------
43
 
44
 
45
- # 使用 ignore_patterns 避免下载不需要的 checkpoint/文件
 
 
46
  # 我们将忽略 .md 文件和 training_args.bin 文件
47
- print(f"开始下载 LoRA 适配器: {LORA_MODEL_ID}")
48
- lora_model_dir = snapshot_download(
49
- LORA_MODEL_ID,
50
  revision='master',
51
- ignore_patterns=["*.md", "training_args.bin", "checkpoint-*"]
52
  )
53
- print(f"LoRA 适配器下载完成,路径: {lora_model_dir}")
54
 
55
- # 1. 加载 Tokenizer
56
- # 优先使用 LoRA 仓库中的 Tokenizer,因为它可能已更新(例如添加了新 token)
 
57
  try:
58
- tokenizer = AutoTokenizer.from_pretrained(lora_model_dir, trust_remote_code=True)
59
- print(" LoRA 目录加载 Tokenizer 成功。")
60
  except Exception as e:
61
- print(f" LoRA 目录加载 Tokenizer 失败 ({e}), 尝试从基础模型加载。")
62
- tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)
63
-
64
 
65
- # 2. 加载基础模型
66
- print(f"开始加载基础模型: {BASE_MODEL_ID}")
67
- model = AutoModelForCausalLM.from_pretrained(
68
- BASE_MODEL_ID,
69
- device_map=device_map, # <--- 修改:使用自动检测的 device_map
70
- torch_dtype=torch_dtype, # <--- 修改:使用自动检测的 torch_dtype
71
- trust_remote_code=True
72
- )
73
- print("基础模型加载完成。")
74
-
75
- # --- (新增) 修复 Qwen 模型的 pad_token_id ---
76
- # Qwen1.5 基础模型可能没有设置 pad_token_id,这会在推理时产生警告
77
- # 我们将其设置为 eos_token_id
78
  if tokenizer.pad_token_id is None:
79
  print("Tokenizer 未设置 pad_token_id,将其设置为 eos_token_id。")
80
  tokenizer.pad_token_id = tokenizer.eos_token_id
81
-
82
- if model.config.pad_token_id is None:
83
- print("Model config 未设置 pad_token_id,将其设置为 eos_token_id。")
84
- model.config.pad_token_id = tokenizer.eos_token_id
85
  # ----------------------------------------
86
 
87
 
88
- # 3. 加载并融合 LoRA 适配器
89
- print(f"开始加载 LoRA 适配器到基础模型...")
90
- model = PeftModel.from_pretrained(
91
- model,
92
- lora_model_dir,
93
- device_map=device_map # <--- 修改:传递自动检测的 device_map
 
 
 
 
 
 
 
 
 
 
 
94
  )
95
- print("LoRA 适配器加载成功。")
96
 
97
- # (可选) 合并权重以加快推理速度,但这会占用更多内存
98
- # 注意:在 CPU 上合并可能非常慢
99
- print("正在合并 LoRA 权重...")
100
- model = model.merge_and_unload()
101
- print("权重合并完成。")
102
 
103
- model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
- # 4. 定义推理函数 (*** 已修改为使用 Qwen 对话模板 ***)
106
  def inference(text_input):
107
  print(f"收到输入: {text_input}")
108
  try:
109
  # 1. 构建 Qwen (ChatML) 对话格式
110
- # 假设 text_input 是完整的用户提示 (包含指令、Schema等)
111
  messages = [
112
  {"role": "user", "content": text_input}
113
  ]
114
 
115
  # 2. 应用对话模板
116
- # tokenizer.apply_chat_template 会自动处理特殊 tokens (例如 <|im_start|>)
117
- # 并且 add_generation_prompt=True 会添加 <|im_start|>assistant\n
118
- # 这会告诉模型开始生成回复
119
- inputs = tokenizer.apply_chat_template(
120
  messages,
121
- return_tensors="pt",
122
- add_generation_prompt=True
123
- ).to(model.device)
124
-
125
- # 3. 获取输入 token 的长度
126
- # 这里的 inputs 是 token IDs tensor, shape [1, sequence_length]
127
- input_len = inputs.shape[1]
128
-
129
- # 4. 获取 EOT (End of Text) token IDs
130
- # Qwen 系列使用 <|im_end|> (151645) 和/或 <|endoftext|> (151643)
131
- # tokenizer.eos_token_id 应该已经正确设置 (来自 Qwen1.5 基础模型)
132
- eot_ids = [tokenizer.eos_token_id]
133
-
134
- # 额外检查 <|im_end|>,确保它也在停止列表中
135
- im_end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
136
- if im_end_token_id != tokenizer.unk_token_id and im_end_token_id not in eot_ids:
137
- eot_ids.append(im_end_token_id)
138
-
139
- print(f"使用 EOT token IDs: {eot_ids} (eos_token_id: {tokenizer.eos_token_id})")
140
-
141
- # 5. 生成
142
- # 使用 eos_token_id 列表来确保模型在 <|im_end|> 处停止
143
- outputs = model.generate(
144
- input_ids=inputs,
145
- max_new_tokens=2048,
146
- eos_token_id=eot_ids, # <--- 关键:告诉模型何时停止
147
- pad_token_id=tokenizer.eos_token_id # 避免 HF warning
148
  )
149
 
150
- # 6. 从输出的 token 序列中,只选择新生成的部分
151
- # outputs[0] 包含 (输入 + 生成) 的所有 tokens
152
- new_tokens = outputs[0][input_len:]
153
 
154
- # 7. 只解码新生成的 token
155
- # skip_special_tokens=True 会移除任何 <|im_end|>
156
- result = tokenizer.decode(new_tokens, skip_special_tokens=True)
 
 
 
157
 
158
  print(f"生成结果 (仅新内容): {result}")
159
  return result
 
160
  except Exception as e:
161
- print(f"推理时出错: {e}")
 
 
162
  return f"错误: {e}"
163
 
164
  # ----------------------------------------------------
165
 
 
166
  example = """You are a senior SQL engineer. Your task is to generate a single, correct, and executable SQL query to answer the user's question based on the provided database context.
167
-
168
- ## INSTRUCTIONS
169
- 1. **Backend Adherence**: The query MUST be written for the `clickhouse` database backend. This is a strict requirement.
170
- 2. **Follow Special Notes**: You MUST strictly follow all syntax, functions, or constraints described in the [Database Backend Notes]. Pay extremely close attention to this section, as it contains critical, non-standard rules.
171
- 3. **Schema Integrity**: The query MUST ONLY use the tables and columns provided in the [Database Schema]. Do not invent or guess table or column names.
172
- 4. **Answer the Question**: The query must directly and accurately answer the [Natural Language Question].
173
- 5. **Output Format**: Enclose the final SQL query in a single Markdown code block formatted for SQL (` ```sql ... ``` `).
174
- 6. **Embedding Match**: If the [EMBEDDING_MODEL_NAME] parameter is a valid string (e.g., 'intfloat/E5-Mistral-7B-Instruct'), you MUST generate a query that includes the WHERE [EMBEDDING_COLUMN_NAME] MATCH lembed(...) clause for vector search. Otherwise, if the embedding model name below the [EMBEDDING MODEL NAME] is None, you MUST generate a standard SQL query that OMITS the entire MATCH lembed(...) clause. The query should not perform any vector search.
175
- 7. **Embedding Name**: If a value is provided for the parameter `[EMBEDDING_MODEL_NAME]`, your generated query must contain a `lembed` function call. The first parameter to the `lembed` function MUST be the exact value of `[EMBEDDING_MODEL_NAME]`, formatted as a string literal (enclosed in single quotes). For example, if `[EMBEDDING_MODEL_NAME]` is `laion/CLIP-ViT-B-32-laion2B-s34B-b79K`, the generated SQL must include `MATCH lembed('laion/CLIP-ViT-B-32-laion2B-s34B-b79K', ...)`.
176
-
177
- ## DATABASE CONTEXT
178
-
179
- [DATABASE BACKEND]:
180
- clickhouse
181
-
182
- [DATABASE SCHEMA]:
183
- CREATE TABLE CAMPAIGN_RESULTS (
184
- `result_id` Nullable(Int64),
185
- `campaign_id` Nullable(Int64),
186
- `territory_id` Nullable(Int64),
187
- `menu_item_id` Nullable(Int64),
188
- `sales_increase_percentage` Nullable(Float64),
189
- `customer_engagement_score` Nullable(Float64),
190
- `feedback_improvement` Nullable(Float64)
191
- );
192
- CREATE TABLE CUSTOMERS (
193
- `customer_id` Nullable(Int64),
194
- `customer_name` Nullable(String),
195
- `email` Nullable(String),
196
- `phone_number` Nullable(String),
197
- `loyalty_points` Nullable(Int64)
198
- );
199
- CREATE TABLE CUSTOMER_FEEDBACK (
200
- `feedback_id` Nullable(Int64),
201
- `menu_item_id` Nullable(Int64),
202
- `customer_id` Nullable(Int64),
203
- `feedback_date` Nullable(String),
204
- `rating` Nullable(Float64),
205
- `comments` Nullable(String),
206
- `feedback_type` Nullable(String),
207
- `comments_embedding` Array(Float32)
208
- );
209
- CREATE TABLE INGREDIENTS (
210
- `ingredient_id` Nullable(Int64),
211
- `name` Nullable(String),
212
- `description` Nullable(String),
213
- `supplier_id` Nullable(Int64),
214
- `cost_per_unit` Nullable(Float64),
215
- `description_embedding` Array(Float32)
216
- );
217
- CREATE TABLE MARKETING_CAMPAIGNS (
218
- `campaign_id` Nullable(Int64),
219
- `campaign_name` Nullable(String),
220
- `start_date` Nullable(String),
221
- `end_date` Nullable(String),
222
- `budget` Nullable(Float64),
223
- `objective` Nullable(String),
224
- `territory_id` Nullable(Int64)
225
- );
226
- CREATE TABLE MENU_CATEGORIES (
227
- `category_id` Nullable(Int64),
228
- `category_name` Nullable(String),
229
- `description` Nullable(String),
230
- `parent_category_id` Nullable(Int64),
231
- `description_embedding` Array(Float32)
232
- );
233
- CREATE TABLE MENU_ITEMS (
234
- `menu_item_id` Nullable(Int64),
235
- `territory_id` Nullable(Int64),
236
- `name` Nullable(String),
237
- `price_inr` Nullable(Float64),
238
- `price_usd` Nullable(Float64),
239
- `price_eur` Nullable(Float64),
240
- `category_id` Nullable(Int64),
241
- `menu_type` Nullable(String),
242
- `calories` Nullable(Int64),
243
- `is_vegetarian` Nullable(Int64),
244
- `promotion_id` Nullable(Int64)
245
- );
246
- CREATE TABLE MENU_ITEM_INGREDIENTS (
247
- `menu_item_id` Nullable(Int64),
248
- `ingredient_id` Nullable(Int64),
249
- `quantity` Nullable(Float64),
250
- `unit_of_measurement` Nullable(String)
251
- );
252
- CREATE TABLE PRICING_STRATEGIES (
253
- `strategy_id` Nullable(Int64),
254
- `strategy_name` Nullable(String),
255
- `description` Nullable(String),
256
- `territory_id` Nullable(Int64),
257
- `effective_date` Nullable(String),
258
- `end_date` Nullable(String),
259
- `description_embedding` Array(Float32)
260
- );
261
- CREATE TABLE PROMOTIONS (
262
- `promotion_id` Nullable(Int64),
263
- `promotion_name` Nullable(String),
264
- `start_date` Nullable(String),
265
- `end_date` Nullable(String),
266
- `discount_percentage` Nullable(Float64),
267
- `category_id` Nullable(Int64),
268
- `territory_id` Nullable(Int64)
269
- );
270
- CREATE TABLE SALES_DATA (
271
- `sale_id` Nullable(Int64),
272
- `menu_item_id` Nullable(Int64),
273
- `territory_id` Nullable(Int64),
274
- `sale_date` Nullable(String),
275
- `quantity_sold` Nullable(Int64),
276
- `total_revenue` Nullable(Float64),
277
- `discount_applied` Nullable(Float64),
278
- `customer_id` Nullable(Int64)
279
- );
280
- CREATE TABLE SALES_FORECAST (
281
- `forecast_id` Nullable(Int64),
282
- `menu_item_id` Nullable(Int64),
283
- `territory_id` Nullable(Int64),
284
- `forecast_date` Nullable(String),
285
- `forecast_quantity` Nullable(Int64),
286
- `forecast_revenue` Nullable(Float64),
287
- `prediction_accuracy` Nullable(Float64)
288
- );
289
- CREATE TABLE SUPPLIERS (
290
- `supplier_id` Nullable(Int64),
291
- `supplier_name` Nullable(String),
292
- `contact_email` Nullable(String),
293
- `phone_number` Nullable(String),
294
- `address` Nullable(String)
295
- );
296
- CREATE TABLE TERRITORIES (
297
- `territory_id` Nullable(Int64),
298
- `territory_name` Nullable(String),
299
- `region` Nullable(String),
300
- `contact_email` Nullable(String),
301
- `local_tax_rate` Nullable(Float64),
302
- `currency_code` Nullable(String)
303
- );
304
- CREATE TABLE USERS (
305
- `user_id` Nullable(Int64),
306
- `user_name` Nullable(String),
307
- `email` Nullable(String),
308
- `role_id` Nullable(Int64),
309
- `territory_id` Nullable(Int64)
310
- );
311
- CREATE TABLE USER_ROLES (
312
- `role_id` Nullable(Int64),
313
- `role_name` Nullable(String),
314
- `description` Nullable(String),
315
- `permissions` Nullable(String),
316
- `description_embedding` Array(Float32)
317
- );
318
-
319
- [DATABASE BACKEND NOTES]:
320
- There are a few requirements you should comply with in addition:
321
- 1. When generating SQL queries, you should prioritize utilizing K-Nearest Neighbor (KNN) searches whenever contextually appropriate. However, you must avoid unnecessary/forced KNN implementations for:
322
- -- Traditional relational data queries (especially for columns like: id, age, price).
323
- -- Cases where standard SQL operators (equality, range, or aggregation functions) are more efficient and semantically appropriate.
324
- 2. Only columns with a vector type (like: Array(Float32)) support KNN queries. The names of these vector columns often end with "_embedding". You can perform KNN searches when the column name you need to query ends with "_embedding" or is otherwise identified as a vector column.
325
- 3. In ClickHouse, vector similarity search is performed using distance functions. You must explicitly calculate the distance in the SELECT clause using a function like L2Distance and give it an alias, typically "AS distance". This distance alias will not be implicitly generated.
326
- 4. The lembed function is used to transform a string into a semantic vector. This function should be used within a WITH clause to define the reference vector. The lembed function has two parameters: the first is the name of the embedding model used (default value: 'intfloat/E5-Mistral-7B-Instruct'), and the second is the string content to embed. The resulting vector should be given an alias in the WITH clause.
327
- 5. You must generate plausible and semantically relevant words or sentences for the second parameter of the lembed function based on the column's name, type, and comment. For example, if a column is named product_description_embedding and its comment is "Embedding of the product's features and marketing text", you could generate text like "durable and waterproof outdoor adventure camera".
328
- 6. Every KNN search query MUST conclude with "ORDER BY distance LIMIT N" to retrieve the top-N most similar results. The LIMIT clause is mandatory for performing a KNN search and ensuring predictable performance.
329
- 7. When combining a vector search with JOIN operations, the standard WHERE clause should be used to apply filters from any of the joined tables. The ORDER BY distance LIMIT N clause is applied after all filtering and joins are resolved.
330
- 8. A SELECT statement should typically be ordered by a single distance calculation to perform one primary KNN search. However, subqueries can perform their own independent KNN searches, each with its own WITH clause, distance calculation, and ORDER BY distance LIMIT N clause.
331
-
332
- ## Example of a ClickHouse KNN Query
333
- DB Schema: Some table on articles with a column content_embedding Array(Float32).
334
- Query Task: Identify the article ID of the single most relevant article discussing innovative algorithms in graph theory.
335
- Generated SQL:
336
- ```sql
337
- WITH\n lembed('intfloat/E5-Mistral-7B-Instruct', 'innovative algorithms in graph theory.') AS ref_vec_0\n\nSELECT id, L2Distance(articles.abstract_embedding, ref_vec_0) AS distance\nFROM articles\nORDER BY distance\nLIMIT 1;
338
- ```
339
-
340
- [EMBEDDING MODEL NAME]:
341
- intfloat/E5-Mistral-7B-Instruct
342
-
343
- ## NATURAL LANGUAGE QUESTION
344
- In vector searches, the `MATCH` operator performs an approximate nearest neighbor (ANN) search, which identifies items based on their similarity to a given vector. The `lembed()` function is used to convert text phrases into vector representations using a specific model, in this case, `intfloat/E5-Mistral-7B-Instruct`. This helps in finding items that align closely with the concept of "Popular menu items based on sales." The parameter `k = 5` specifies that only the top 5 categories, which are most similar in terms of the embedding, should be considered. The similarity is determined by calculating the Euclidean distance between vectors, where a smaller distance indicates higher similarity.
345
- Can you unveil the crown jewel of our vegetarian delights, the one that has soared to the top of the sales charts from the elite circle of our most cherished categories this year?
346
-
347
  Let's think step by step!
348
  """
 
 
349
 
350
- examples = [example]
351
 
352
- # 6. 创建 Gradio 界面 (原第5步)
353
  iface = gr.Interface(
354
- fn=inference,
355
- # 增加了行数以便容纳 Schema
356
- inputs=gr.Textbox(lines=10, label="输入查询 (Input Query)"),
357
  outputs=gr.Textbox(lines=10, label="模型输出 (Model Output)"),
358
- title="UniVectorSQL-7B-LoRA 推理",
359
  # description="这是一个 Text-to-SQL 模型。请输入您的问题 (Question) 和数据库模式 (Schema)。点击下方示例尝试。",
360
- examples=examples # <--- 添加示例
361
  )
362
  # ---------------------
363
 
 
1
  import gradio as gr
2
  import torch
3
+ # 导入 vllm 相关的库
4
+ from vllm import LLM, SamplingParams
5
+ # 保持 AutoTokenizer 用于处理聊天模板
6
+ from transformers import AutoTokenizer
7
  from modelscope.hub.snapshot_download import snapshot_download
8
  import os
9
+ import traceback
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
# --- Key configuration ---

# !!! (IMPORTANT) Put the ID of your *merged* model on ModelScope here.
# vLLM loads this full model directly (no separate LoRA adapter step).
#
# NOTE(review): this must point to the model that already has the LoRA
# (UniVectorSQL) weights merged in — not the plain base model.
MERGED_MODEL_ID = "zrwang/UniVectorSQL-7B-LoRA-merged"
# Example: if your merged model is "risemds/UniVectorSQL-7B-Merged", change the line above.

# -----------------


# --- vLLM model loading ---

# 1. Download the model snapshot.
# Skip docs, training artifacts and intermediate checkpoints to save bandwidth.
print(f"开始下载已合并的模型: {MERGED_MODEL_ID}")
model_dir = snapshot_download(
    MERGED_MODEL_ID,
    revision='master',
    ignore_patterns=["*.md", "training_args.bin", "checkpoint-*"]
)
print(f"模型下载完成,路径: {model_dir}")


# 2. Load the tokenizer.
# We need it to apply the chat template and to derive the stop-token IDs
# that vLLM uses to end generation.
try:
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    print("从模型目录加载 Tokenizer 成功。")
except Exception as e:
    print(f"从模型目录加载 Tokenizer 失败: {e}")
    # Without a tokenizer we cannot build prompts or stop tokens, so abort.
    # Bare `raise` re-raises with the original traceback intact.
    raise

# --- (kept) pad_token_id fix for the tokenizer ---
# Qwen1.5 base models may ship without a pad_token_id set.
if tokenizer.pad_token_id is None:
    print("Tokenizer 未设置 pad_token_id,将其设置为 eos_token_id。")
    tokenizer.pad_token_id = tokenizer.eos_token_id
# (We no longer patch model.config — vLLM handles padding itself.)
# ----------------------------------------


# 3. Load the model into vLLM.
print(f"开始加载 vLLM 模型: {model_dir}")

# Auto-detect the GPU count to size tensor parallelism.
if torch.cuda.is_available():
    gpu_count = torch.cuda.device_count()
    print(f"检测到 {gpu_count} 个 CUDA GPU。")
else:
    print("警告: 未检测到 CUDA GPU。vLLM 强烈建议在 GPU 上运行。")
    gpu_count = 1  # assume at least 1; vLLM 0.4.0+ also supports CPU (slow)

llm = LLM(
    model=model_dir,
    trust_remote_code=True,
    tensor_parallel_size=gpu_count,  # use every available GPU
    dtype="auto"  # vLLM picks bfloat16/float16 as appropriate
    # max_model_len=4096  # (optional) cap the context length if needed
)
print("vLLM 模型加载完成。")


# 4. Prepare stop tokens and SamplingParams once, up front.

# Qwen-family models end turns with <|im_end|> (151645) and/or
# <|endoftext|> (151643); start from the tokenizer's eos token.
# Guard against a missing eos_token_id so we never hand None to vLLM.
eot_ids = [tokenizer.eos_token_id] if tokenizer.eos_token_id is not None else []

# Also register <|im_end|> as a stop token if the tokenizer knows it.
# FIX: convert_tokens_to_ids can return None for unknown tokens on some
# tokenizers whose unk_token_id is a real int — the old `!= unk_token_id`
# test alone would then append None and crash SamplingParams.
im_end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
if (
    im_end_token_id is not None
    and im_end_token_id != tokenizer.unk_token_id
    and im_end_token_id not in eot_ids
):
    eot_ids.append(im_end_token_id)

print(f"vLLM 将使用 stop_token_ids: {eot_ids} (eos_token_id: {tokenizer.eos_token_id})")

# Equivalent of the old HF generate(max_new_tokens=2048, eos_token_id=eot_ids).
sampling_params = SamplingParams(
    max_tokens=2048,          # maps to HF max_new_tokens
    stop_token_ids=eot_ids,   # tells vLLM when to stop generating
    temperature=0.0,          # greedy decoding is usually best for SQL
    top_p=1.0,                # paired with temperature=0.0
)
# ----------------------------------------
 
100
def inference(text_input):
    """Run one Text-to-SQL generation via vLLM and return the new text only.

    The raw user prompt is wrapped in the Qwen (ChatML) conversation format,
    rendered to a plain string via the tokenizer's chat template, and handed
    to the global vLLM engine with the pre-built sampling parameters.
    """
    print(f"收到输入: {text_input}")
    try:
        # Wrap the whole prompt as a single user turn (ChatML format).
        chat_messages = [{"role": "user", "content": text_input}]

        # vLLM expects a *string* prompt, not token IDs:
        #   tokenize=False            -> return the rendered template text
        #   add_generation_prompt=True -> append "<|im_start|>assistant\n"
        prompt_text = tokenizer.apply_chat_template(
            chat_messages,
            tokenize=False,
            add_generation_prompt=True
        )

        # llm.generate takes a list of prompts; we send exactly one.
        request_outputs = llm.generate([prompt_text], sampling_params)

        # request_outputs[0]            -> RequestOutput for our prompt
        # .outputs[0]                   -> first completion (best_of=1)
        # .text                         -> newly generated text (prompt excluded)
        result = request_outputs[0].outputs[0].text.strip()

        print(f"生成结果 (仅新内容): {result}")
        return result

    except Exception as e:
        print(f"vLLM 推理时出错: {e}")
        # Dump the full traceback — vLLM errors are often buried deep.
        traceback.print_exc()
        return f"错误: {e}"
137
 
138
  # ----------------------------------------------------
139
 
140
+ # (示例保持不变)
141
  example = """You are a senior SQL engineer. Your task is to generate a single, correct, and executable SQL query to answer the user's question based on the provided database context.
142
+ ... (示例内容和之前一样) ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  Let's think step by step!
144
  """
145
+ # (为了简洁,省略了示例的完整内容,使用您原始的 'example' 变量即可)
146
+ examples = [example]
147
 
 
148
 
149
+ # 6. 创建 Gradio 界面 (保持不变)
150
# Gradio UI: one text box in, one text box out, with the sample prompt attached.
iface = gr.Interface(
    fn=inference,
    inputs=gr.Textbox(lines=10, label="输入查询 (Input Query)"),
    outputs=gr.Textbox(lines=10, label="模型输出 (Model Output)"),
    title="UniVectorSQL-7B (vLLM 推理)",
    # description="这是一个 Text-to-SQL 模型。请输入您的问题 (Question) 和数据库模式 (Schema)。点击下方示例尝试。",
    examples=examples
)
158
  # ---------------------
159
 
requirements.txt CHANGED
@@ -2,4 +2,5 @@ torch
2
  transformers
3
  peft
4
  modelscope
5
- gradio
 
 
2
  transformers
3
  peft
4
  modelscope
5
+ gradio
6
+ vllm