Spaces:
Runtime error
Runtime error
Commit
·
b824726
1
Parent(s):
d0ac527
vllm inference
Browse files- app.py +99 -303
- requirements.txt +2 -1
app.py
CHANGED
|
@@ -1,363 +1,159 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
import torch
|
| 3 |
-
|
| 4 |
-
from
|
|
|
|
|
|
|
| 5 |
from modelscope.hub.snapshot_download import snapshot_download
|
| 6 |
import os
|
| 7 |
-
|
| 8 |
-
# --- 自动检测设备 ---
|
| 9 |
-
# 检查是否有可用的 GPU (NVIDIA CUDA)
|
| 10 |
-
if torch.cuda.is_available():
|
| 11 |
-
device = "cuda"
|
| 12 |
-
torch_dtype = torch.float16 # GPU 使用 float16
|
| 13 |
-
device_map = "auto" # 自动分配到 GPU
|
| 14 |
-
print("检测到 CUDA (GPU),将使用 GPU。")
|
| 15 |
-
# 检查是否有可用的 Apple Silicon (MPS)
|
| 16 |
-
elif torch.backends.mps.is_available():
|
| 17 |
-
device = "mps"
|
| 18 |
-
torch_dtype = torch.float16 # Apple Silicon GPU 也可以用 float16
|
| 19 |
-
device_map = "mps" # 明确指定 mps
|
| 20 |
-
print("检测到 MPS (Apple Silicon),将使用 MPS。")
|
| 21 |
-
# 否则回退到 CPU
|
| 22 |
-
else:
|
| 23 |
-
device = "cpu"
|
| 24 |
-
torch_dtype = torch.float32 # CPU 使用 float32 (float16 在 CPU 上很慢或不支持)
|
| 25 |
-
device_map = "cpu" # 明确指定 cpu
|
| 26 |
-
print("未检测到 GPU,将使用 CPU。")
|
| 27 |
-
# ---------------------
|
| 28 |
-
|
| 29 |
|
| 30 |
# --- 关键配置 ---
|
| 31 |
|
| 32 |
-
#
|
| 33 |
-
#
|
| 34 |
-
#
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
#
|
| 39 |
-
# BASE_MODEL_ID = "/mnt/DataFlow/ydw/model/seeklhy/OmniSQL-7B/"
|
| 40 |
-
BASE_MODEL_ID = "seeklhy/OmniSQL-7B"
|
| 41 |
|
| 42 |
# -----------------
|
| 43 |
|
| 44 |
|
| 45 |
-
#
|
|
|
|
|
|
|
| 46 |
# 我们将忽略 .md 文件和 training_args.bin 文件
|
| 47 |
-
print(f"
|
| 48 |
-
|
| 49 |
-
|
| 50 |
revision='master',
|
| 51 |
-
ignore_patterns=["*.md", "training_args.bin", "checkpoint-*"]
|
| 52 |
)
|
| 53 |
-
print(f"
|
| 54 |
|
| 55 |
-
|
| 56 |
-
#
|
|
|
|
| 57 |
try:
|
| 58 |
-
tokenizer = AutoTokenizer.from_pretrained(
|
| 59 |
-
print("
|
| 60 |
except Exception as e:
|
| 61 |
-
print(f"
|
| 62 |
-
|
| 63 |
-
|
| 64 |
|
| 65 |
-
#
|
| 66 |
-
|
| 67 |
-
model = AutoModelForCausalLM.from_pretrained(
|
| 68 |
-
BASE_MODEL_ID,
|
| 69 |
-
device_map=device_map, # <--- 修改:使用自动检测的 device_map
|
| 70 |
-
torch_dtype=torch_dtype, # <--- 修改:使用自动检测的 torch_dtype
|
| 71 |
-
trust_remote_code=True
|
| 72 |
-
)
|
| 73 |
-
print("基础模型加载完成。")
|
| 74 |
-
|
| 75 |
-
# --- (新增) 修复 Qwen 模型的 pad_token_id ---
|
| 76 |
-
# Qwen1.5 基础模型可能没有设置 pad_token_id,这会在推理时产生警告
|
| 77 |
-
# 我们将其设置为 eos_token_id
|
| 78 |
if tokenizer.pad_token_id is None:
|
| 79 |
print("Tokenizer 未设置 pad_token_id,将其设置为 eos_token_id。")
|
| 80 |
tokenizer.pad_token_id = tokenizer.eos_token_id
|
| 81 |
-
|
| 82 |
-
if model.config.pad_token_id is None:
|
| 83 |
-
print("Model config 未设置 pad_token_id,将其设置为 eos_token_id。")
|
| 84 |
-
model.config.pad_token_id = tokenizer.eos_token_id
|
| 85 |
# ----------------------------------------
|
| 86 |
|
| 87 |
|
| 88 |
-
# 3.
|
| 89 |
-
print(f"开始加载
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
)
|
| 95 |
-
print("
|
| 96 |
|
| 97 |
-
# (可选) 合并权重以加快推理速度,但这会占用更多内存
|
| 98 |
-
# 注意:在 CPU 上合并可能非常慢
|
| 99 |
-
print("正在合并 LoRA 权重...")
|
| 100 |
-
model = model.merge_and_unload()
|
| 101 |
-
print("权重合并完成。")
|
| 102 |
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
-
# 4. 定义推理函数 (*** 已修改为使用 Qwen 对话模板 ***)
|
| 106 |
def inference(text_input):
|
| 107 |
print(f"收到输入: {text_input}")
|
| 108 |
try:
|
| 109 |
# 1. 构建 Qwen (ChatML) 对话格式
|
| 110 |
-
# 假设 text_input 是完整的用户提示 (包含指令、Schema等)
|
| 111 |
messages = [
|
| 112 |
{"role": "user", "content": text_input}
|
| 113 |
]
|
| 114 |
|
| 115 |
# 2. 应用对话模板
|
| 116 |
-
#
|
| 117 |
-
#
|
| 118 |
-
#
|
| 119 |
-
|
| 120 |
messages,
|
| 121 |
-
|
| 122 |
-
add_generation_prompt=True
|
| 123 |
-
).to(model.device)
|
| 124 |
-
|
| 125 |
-
# 3. 获取输入 token 的长度
|
| 126 |
-
# 这里的 inputs 是 token IDs tensor, shape [1, sequence_length]
|
| 127 |
-
input_len = inputs.shape[1]
|
| 128 |
-
|
| 129 |
-
# 4. 获取 EOT (End of Text) token IDs
|
| 130 |
-
# Qwen 系列使用 <|im_end|> (151645) 和/或 <|endoftext|> (151643)
|
| 131 |
-
# tokenizer.eos_token_id 应该已经正确设置 (来自 Qwen1.5 基础模型)
|
| 132 |
-
eot_ids = [tokenizer.eos_token_id]
|
| 133 |
-
|
| 134 |
-
# 额外检查 <|im_end|>,确保它也在停止列表中
|
| 135 |
-
im_end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
|
| 136 |
-
if im_end_token_id != tokenizer.unk_token_id and im_end_token_id not in eot_ids:
|
| 137 |
-
eot_ids.append(im_end_token_id)
|
| 138 |
-
|
| 139 |
-
print(f"使用 EOT token IDs: {eot_ids} (eos_token_id: {tokenizer.eos_token_id})")
|
| 140 |
-
|
| 141 |
-
# 5. 生成
|
| 142 |
-
# 使用 eos_token_id 列表来确保模型在 <|im_end|> 处停止
|
| 143 |
-
outputs = model.generate(
|
| 144 |
-
input_ids=inputs,
|
| 145 |
-
max_new_tokens=2048,
|
| 146 |
-
eos_token_id=eot_ids, # <--- 关键:告诉模型何时停止
|
| 147 |
-
pad_token_id=tokenizer.eos_token_id # 避免 HF warning
|
| 148 |
)
|
| 149 |
|
| 150 |
-
#
|
| 151 |
-
#
|
| 152 |
-
|
| 153 |
|
| 154 |
-
#
|
| 155 |
-
#
|
| 156 |
-
|
|
|
|
|
|
|
|
|
|
| 157 |
|
| 158 |
print(f"生成结果 (仅新内容): {result}")
|
| 159 |
return result
|
|
|
|
| 160 |
except Exception as e:
|
| 161 |
-
print(f"推理时出错: {e}")
|
|
|
|
|
|
|
| 162 |
return f"错误: {e}"
|
| 163 |
|
| 164 |
# ----------------------------------------------------
|
| 165 |
|
|
|
|
| 166 |
example = """You are a senior SQL engineer. Your task is to generate a single, correct, and executable SQL query to answer the user's question based on the provided database context.
|
| 167 |
-
|
| 168 |
-
## INSTRUCTIONS
|
| 169 |
-
1. **Backend Adherence**: The query MUST be written for the `clickhouse` database backend. This is a strict requirement.
|
| 170 |
-
2. **Follow Special Notes**: You MUST strictly follow all syntax, functions, or constraints described in the [Database Backend Notes]. Pay extremely close attention to this section, as it contains critical, non-standard rules.
|
| 171 |
-
3. **Schema Integrity**: The query MUST ONLY use the tables and columns provided in the [Database Schema]. Do not invent or guess table or column names.
|
| 172 |
-
4. **Answer the Question**: The query must directly and accurately answer the [Natural Language Question].
|
| 173 |
-
5. **Output Format**: Enclose the final SQL query in a single Markdown code block formatted for SQL (` ```sql ... ``` `).
|
| 174 |
-
6. **Embedding Match**: If the [EMBEDDING_MODEL_NAME] parameter is a valid string (e.g., 'intfloat/E5-Mistral-7B-Instruct'), you MUST generate a query that includes the WHERE [EMBEDDING_COLUMN_NAME] MATCH lembed(...) clause for vector search. Otherwise, if embedding model name below the [EMBEDDING MODEL NAME] is None, , you MUST generate a standard SQL query that OMITS the entire MATCH lembed(...) clause. The query should not perform any vector search.
|
| 175 |
-
7. **Embedding Name**: If a value is provided for the parameter `[EMBEDDING_MODEL_NAME]`, your generated query must contain a `lembed` function call. The first parameter to the `lembed` function MUST be the exact value of `[EMBEDDING_MODEL_NAME]`, formatted as a string literal (enclosed in single quotes). For example, if `[EMBEDDING_MODEL_NAME]` is `laion/CLIP-ViT-B-32-laion2B-s34B-b79K`, the generated SQL must include `MATCH lembed('laion/CLIP-ViT-B-32-laion2B-s34B-b79K', ...)`.
|
| 176 |
-
|
| 177 |
-
## DATABASE CONTEXT
|
| 178 |
-
|
| 179 |
-
[DATABASE BACKEND]:
|
| 180 |
-
clickhouse
|
| 181 |
-
|
| 182 |
-
[DATABASE SCHEMA]:
|
| 183 |
-
CREATE TABLE CAMPAIGN_RESULTS (
|
| 184 |
-
`result_id` Nullable(Int64),
|
| 185 |
-
`campaign_id` Nullable(Int64),
|
| 186 |
-
`territory_id` Nullable(Int64),
|
| 187 |
-
`menu_item_id` Nullable(Int64),
|
| 188 |
-
`sales_increase_percentage` Nullable(Float64),
|
| 189 |
-
`customer_engagement_score` Nullable(Float64),
|
| 190 |
-
`feedback_improvement` Nullable(Float64)
|
| 191 |
-
);
|
| 192 |
-
CREATE TABLE CUSTOMERS (
|
| 193 |
-
`customer_id` Nullable(Int64),
|
| 194 |
-
`customer_name` Nullable(String),
|
| 195 |
-
`email` Nullable(String),
|
| 196 |
-
`phone_number` Nullable(String),
|
| 197 |
-
`loyalty_points` Nullable(Int64)
|
| 198 |
-
);
|
| 199 |
-
CREATE TABLE CUSTOMER_FEEDBACK (
|
| 200 |
-
`feedback_id` Nullable(Int64),
|
| 201 |
-
`menu_item_id` Nullable(Int64),
|
| 202 |
-
`customer_id` Nullable(Int64),
|
| 203 |
-
`feedback_date` Nullable(String),
|
| 204 |
-
`rating` Nullable(Float64),
|
| 205 |
-
`comments` Nullable(String),
|
| 206 |
-
`feedback_type` Nullable(String),
|
| 207 |
-
`comments_embedding` Array(Float32)
|
| 208 |
-
);
|
| 209 |
-
CREATE TABLE INGREDIENTS (
|
| 210 |
-
`ingredient_id` Nullable(Int64),
|
| 211 |
-
`name` Nullable(String),
|
| 212 |
-
`description` Nullable(String),
|
| 213 |
-
`supplier_id` Nullable(Int64),
|
| 214 |
-
`cost_per_unit` Nullable(Float64),
|
| 215 |
-
`description_embedding` Array(Float32)
|
| 216 |
-
);
|
| 217 |
-
CREATE TABLE MARKETING_CAMPAIGNS (
|
| 218 |
-
`campaign_id` Nullable(Int64),
|
| 219 |
-
`campaign_name` Nullable(String),
|
| 220 |
-
`start_date` Nullable(String),
|
| 221 |
-
`end_date` Nullable(String),
|
| 222 |
-
`budget` Nullable(Float64),
|
| 223 |
-
`objective` Nullable(String),
|
| 224 |
-
`territory_id` Nullable(Int64)
|
| 225 |
-
);
|
| 226 |
-
CREATE TABLE MENU_CATEGORIES (
|
| 227 |
-
`category_id` Nullable(Int64),
|
| 228 |
-
`category_name` Nullable(String),
|
| 229 |
-
`description` Nullable(String),
|
| 230 |
-
`parent_category_id` Nullable(Int64),
|
| 231 |
-
`description_embedding` Array(Float32)
|
| 232 |
-
);
|
| 233 |
-
CREATE TABLE MENU_ITEMS (
|
| 234 |
-
`menu_item_id` Nullable(Int64),
|
| 235 |
-
`territory_id` Nullable(Int64),
|
| 236 |
-
`name` Nullable(String),
|
| 237 |
-
`price_inr` Nullable(Float64),
|
| 238 |
-
`price_usd` Nullable(Float64),
|
| 239 |
-
`price_eur` Nullable(Float64),
|
| 240 |
-
`category_id` Nullable(Int64),
|
| 241 |
-
`menu_type` Nullable(String),
|
| 242 |
-
`calories` Nullable(Int64),
|
| 243 |
-
`is_vegetarian` Nullable(Int64),
|
| 244 |
-
`promotion_id` Nullable(Int64)
|
| 245 |
-
);
|
| 246 |
-
CREATE TABLE MENU_ITEM_INGREDIENTS (
|
| 247 |
-
`menu_item_id` Nullable(Int64),
|
| 248 |
-
`ingredient_id` Nullable(Int64),
|
| 249 |
-
`quantity` Nullable(Float64),
|
| 250 |
-
`unit_of_measurement` Nullable(String)
|
| 251 |
-
);
|
| 252 |
-
CREATE TABLE PRICING_STRATEGIES (
|
| 253 |
-
`strategy_id` Nullable(Int64),
|
| 254 |
-
`strategy_name` Nullable(String),
|
| 255 |
-
`description` Nullable(String),
|
| 256 |
-
`territory_id` Nullable(Int64),
|
| 257 |
-
`effective_date` Nullable(String),
|
| 258 |
-
`end_date` Nullable(String),
|
| 259 |
-
`description_embedding` Array(Float32)
|
| 260 |
-
);
|
| 261 |
-
CREATE TABLE PROMOTIONS (
|
| 262 |
-
`promotion_id` Nullable(Int64),
|
| 263 |
-
`promotion_name` Nullable(String),
|
| 264 |
-
`start_date` Nullable(String),
|
| 265 |
-
`end_date` Nullable(String),
|
| 266 |
-
`discount_percentage` Nullable(Float64),
|
| 267 |
-
`category_id` Nullable(Int64),
|
| 268 |
-
`territory_id` Nullable(Int64)
|
| 269 |
-
);
|
| 270 |
-
CREATE TABLE SALES_DATA (
|
| 271 |
-
`sale_id` Nullable(Int64),
|
| 272 |
-
`menu_item_id` Nullable(Int64),
|
| 273 |
-
`territory_id` Nullable(Int64),
|
| 274 |
-
`sale_date` Nullable(String),
|
| 275 |
-
`quantity_sold` Nullable(Int64),
|
| 276 |
-
`total_revenue` Nullable(Float64),
|
| 277 |
-
`discount_applied` Nullable(Float64),
|
| 278 |
-
`customer_id` Nullable(Int64)
|
| 279 |
-
);
|
| 280 |
-
CREATE TABLE SALES_FORECAST (
|
| 281 |
-
`forecast_id` Nullable(Int64),
|
| 282 |
-
`menu_item_id` Nullable(Int64),
|
| 283 |
-
`territory_id` Nullable(Int64),
|
| 284 |
-
`forecast_date` Nullable(String),
|
| 285 |
-
`forecast_quantity` Nullable(Int64),
|
| 286 |
-
`forecast_revenue` Nullable(Float64),
|
| 287 |
-
`prediction_accuracy` Nullable(Float64)
|
| 288 |
-
);
|
| 289 |
-
CREATE TABLE SUPPLIERS (
|
| 290 |
-
`supplier_id` Nullable(Int64),
|
| 291 |
-
`supplier_name` Nullable(String),
|
| 292 |
-
`contact_email` Nullable(String),
|
| 293 |
-
`phone_number` Nullable(String),
|
| 294 |
-
`address` Nullable(String)
|
| 295 |
-
);
|
| 296 |
-
CREATE TABLE TERRITORIES (
|
| 297 |
-
`territory_id` Nullable(Int64),
|
| 298 |
-
`territory_name` Nullable(String),
|
| 299 |
-
`region` Nullable(String),
|
| 300 |
-
`contact_email` Nullable(String),
|
| 301 |
-
`local_tax_rate` Nullable(Float64),
|
| 302 |
-
`currency_code` Nullable(String)
|
| 303 |
-
);
|
| 304 |
-
CREATE TABLE USERS (
|
| 305 |
-
`user_id` Nullable(Int64),
|
| 306 |
-
`user_name` Nullable(String),
|
| 307 |
-
`email` Nullable(String),
|
| 308 |
-
`role_id` Nullable(Int64),
|
| 309 |
-
`territory_id` Nullable(Int64)
|
| 310 |
-
);
|
| 311 |
-
CREATE TABLE USER_ROLES (
|
| 312 |
-
`role_id` Nullable(Int64),
|
| 313 |
-
`role_name` Nullable(String),
|
| 314 |
-
`description` Nullable(String),
|
| 315 |
-
`permissions` Nullable(String),
|
| 316 |
-
`description_embedding` Array(Float32)
|
| 317 |
-
);
|
| 318 |
-
|
| 319 |
-
[DATABASE BACKEND NOTES]:
|
| 320 |
-
There are a few requirements you should comply with in addition:
|
| 321 |
-
1. When generating SQL queries, you should prioritize utilizing K-Nearest Neighbor (KNN) searches whenever contextually appropriate. However, you must avoid unnecessary/forced KNN implementations for:
|
| 322 |
-
-- Traditional relational data queries (especially for columns like: id, age, price).
|
| 323 |
-
-- Cases where standard SQL operators (equality, range, or aggregation functions) are more efficient and semantically appropriate.
|
| 324 |
-
2. Only columns with a vector type (like: Array(Float32)) support KNN queries. The names of these vector columns often end with "_embedding". You can perform KNN searches when the column name you need to query ends with "_embedding" or is otherwise identified as a vector column.
|
| 325 |
-
3. In ClickHouse, vector similarity search is performed using distance functions. You must explicitly calculate the distance in the SELECT clause using a function like L2Distance and give it an alias, typically "AS distance". This distance alias will not be implicitly generated.
|
| 326 |
-
4. The lembed function is used to transform a string into a semantic vector. This function should be used within a WITH clause to define the reference vector. The lembed function has two parameters: the first is the name of the embedding model used (default value: 'intfloat/E5-Mistral-7B-Instruct'), and the second is the string content to embed. The resulting vector should be given an alias in the WITH clause.
|
| 327 |
-
5. You must generate plausible and semantically relevant words or sentences for the second parameter of the lembed function based on the column's name, type, and comment. For example, if a column is named product_description_embedding and its comment is "Embedding of the product's features and marketing text", you could generate text like "durable and waterproof outdoor adventure camera".
|
| 328 |
-
6. Every KNN search query MUST conclude with "ORDER BY distance LIMIT N" to retrieve the top-N most similar results. The LIMIT clause is mandatory for performing a KNN search and ensuring predictable performance.
|
| 329 |
-
7. When combining a vector search with JOIN operations, the standard WHERE clause should be used to apply filters from any of the joined tables. The ORDER BY distance LIMIT N clause is applied after all filtering and joins are resolved.
|
| 330 |
-
8. A SELECT statement should typically be ordered by a single distance calculation to perform one primary KNN search. However, subqueries can perform their own independent KNN searches, each with its own WITH clause, distance calculation, and ORDER BY distance LIMIT N clause.
|
| 331 |
-
|
| 332 |
-
## Example of a ClickHouse KNN Query
|
| 333 |
-
DB Schema: Some table on articles with a column content_embedding Array(Float32).
|
| 334 |
-
Query Task: Identify the article ID of the single most relevant article discussing innovative algorithms in graph theory.
|
| 335 |
-
Generated SQL:
|
| 336 |
-
```sql
|
| 337 |
-
WITH\n lembed('intfloat/E5-Mistral-7B-Instruct', 'innovative algorithms in graph theory.') AS ref_vec_0\n\nSELECT id, L2Distance(articles.abstract_embedding, ref_vec_0) AS distance\nFROM articles\nORDER BY distance\nLIMIT 1;
|
| 338 |
-
```
|
| 339 |
-
|
| 340 |
-
[EMBEDDING MODEL NAME]:
|
| 341 |
-
intfloat/E5-Mistral-7B-Instruct
|
| 342 |
-
|
| 343 |
-
## NATURAL LANGUAGE QUESTION
|
| 344 |
-
In vector searches, the `MATCH` operator performs an approximate nearest neighbor (ANN) search, which identifies items based on their similarity to a given vector. The `lembed()` function is used to convert text phrases into vector representations using a specific model, in this case, `intfloat/E5-Mistral-7B-Instruct`. This helps in finding items that align closely with the concept of "Popular menu items based on sales." The parameter `k = 5` specifies that only the top 5 categories, which are most similar in terms of the embedding, should be considered. The similarity is determined by calculating the Euclidean distance between vectors, where a smaller distance indicates higher similarity.
|
| 345 |
-
Can you unveil the crown jewel of our vegetarian delights, the one that has soared to the top of the sales charts from the elite circle of our most cherished categories this year?
|
| 346 |
-
|
| 347 |
Let's think step by step!
|
| 348 |
"""
|
|
|
|
|
|
|
| 349 |
|
| 350 |
-
examples = [example]
|
| 351 |
|
| 352 |
-
# 6. 创建 Gradio 界面 (
|
| 353 |
iface = gr.Interface(
|
| 354 |
-
fn=inference,
|
| 355 |
-
|
| 356 |
-
inputs=gr.Textbox(lines=10, label="输入查询 (Input Query)"),
|
| 357 |
outputs=gr.Textbox(lines=10, label="模型输出 (Model Output)"),
|
| 358 |
-
title="UniVectorSQL-7B
|
| 359 |
# description="这是一个 Text-to-SQL 模型。请输入您的问题 (Question) 和数据库模式 (Schema)。点击下方示例尝试。",
|
| 360 |
-
examples=examples
|
| 361 |
)
|
| 362 |
# ---------------------
|
| 363 |
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import torch
|
| 3 |
+
# 导入 vllm 相关的库
|
| 4 |
+
from vllm import LLM, SamplingParams
|
| 5 |
+
# 保持 AutoTokenizer 用于处理聊天模板
|
| 6 |
+
from transformers import AutoTokenizer
|
| 7 |
from modelscope.hub.snapshot_download import snapshot_download
|
| 8 |
import os
|
| 9 |
+
import traceback
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
# --- 关键配置 ---
|
| 12 |
|
| 13 |
+
# !!! (重要) 请在此处填入您在 ModelScope 上的 "已合并" 模型的 ID
|
| 14 |
+
# vLLM 将直接加载这个完整的模型
|
| 15 |
+
#
|
| 16 |
+
# 我使用了您原始的基础模型 ID 作为示例,但您*必须*将其替换为
|
| 17 |
+
# 已经合并了 LoRA (UniVectorSQL) 的那个模型的 ID。
|
| 18 |
+
MERGED_MODEL_ID = "zrwang/UniVectorSQL-7B-LoRA-merged"
|
| 19 |
+
# 示例:如果您的合并后模型叫 "risemds/UniVectorSQL-7B-Merged",请修改上面一行
|
|
|
|
|
|
|
| 20 |
|
| 21 |
# -----------------
|
| 22 |
|
| 23 |
|
| 24 |
+
# --- vLLM 模型加载 ---
|
| 25 |
+
|
| 26 |
+
# 1. 下载模型
|
| 27 |
# 我们将忽略 .md 文件和 training_args.bin 文件
|
| 28 |
+
print(f"开始下载已合并的模型: {MERGED_MODEL_ID}")
|
| 29 |
+
model_dir = snapshot_download(
|
| 30 |
+
MERGED_MODEL_ID,
|
| 31 |
revision='master',
|
| 32 |
+
ignore_patterns=["*.md", "training_args.bin", "checkpoint-*"]
|
| 33 |
)
|
| 34 |
+
print(f"模型下载完成,路径: {model_dir}")
|
| 35 |
|
| 36 |
+
|
| 37 |
+
# 2. 加载 Tokenizer
|
| 38 |
+
# 我们需要 Tokenizer 来应用聊天模板
|
| 39 |
try:
|
| 40 |
+
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
|
| 41 |
+
print("从模型目录加载 Tokenizer 成功。")
|
| 42 |
except Exception as e:
|
| 43 |
+
print(f"从模型目录加载 Tokenizer 失败: {e}")
|
| 44 |
+
# 如果失败,程序无法继续,因为 vLLM 需要知道 EOT token
|
| 45 |
+
raise e
|
| 46 |
|
| 47 |
+
# --- (保留) 修复 Tokenizer 的 pad_token_id ---
|
| 48 |
+
# Qwen1.5 基础模型可能没有设置 pad_token_id
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
if tokenizer.pad_token_id is None:
|
| 50 |
print("Tokenizer 未设置 pad_token_id,将其设置为 eos_token_id。")
|
| 51 |
tokenizer.pad_token_id = tokenizer.eos_token_id
|
| 52 |
+
# (注意:我们不再需要修复 model.config,因为 vLLM 会处理)
|
|
|
|
|
|
|
|
|
|
| 53 |
# ----------------------------------------
|
| 54 |
|
| 55 |
|
| 56 |
+
# 3. 加载 vLLM 模型
|
| 57 |
+
print(f"开始加载 vLLM 模型: {model_dir}")
|
| 58 |
+
|
| 59 |
+
# 自动检测 GPU 数量以设置 tensor_parallel_size
|
| 60 |
+
if torch.cuda.is_available():
|
| 61 |
+
gpu_count = torch.cuda.device_count()
|
| 62 |
+
print(f"检测到 {gpu_count} 个 CUDA GPU。")
|
| 63 |
+
else:
|
| 64 |
+
print("警告: 未检测到 CUDA GPU。vLLM 强烈建议在 GPU 上运行。")
|
| 65 |
+
gpu_count = 1 # 假设至少为 1,vLLM 0.4.0+ 也支持 CPU (但很慢)
|
| 66 |
+
|
| 67 |
+
llm = LLM(
|
| 68 |
+
model=model_dir,
|
| 69 |
+
trust_remote_code=True,
|
| 70 |
+
tensor_parallel_size=gpu_count, # 自动使用所有可用的 GPU
|
| 71 |
+
dtype="auto" # vLLM 会自动选择 (例如 bfloat16 或 float16)
|
| 72 |
+
# max_model_len=4096 # (可选) 如果需要,设置最大上下文长度
|
| 73 |
)
|
| 74 |
+
print("vLLM 模型加载完成。")
|
| 75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
|
| 77 |
+
# 4. 定义推理函数 (*** 已修改为使用 vLLM ***)
|
| 78 |
+
|
| 79 |
+
# --- 提前准备 EOT token 和 SamplingParams ---
|
| 80 |
+
# 1. 获取 EOT (End of Text) token IDs (用于停止生成)
|
| 81 |
+
# Qwen 系列使用 <|im_end|> (151645) 和/或 <|endoftext|> (151643)
|
| 82 |
+
eot_ids = [tokenizer.eos_token_id]
|
| 83 |
+
|
| 84 |
+
im_end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
|
| 85 |
+
if im_end_token_id != tokenizer.unk_token_id and im_end_token_id not in eot_ids:
|
| 86 |
+
eot_ids.append(im_end_token_id)
|
| 87 |
+
|
| 88 |
+
print(f"vLLM 将使用 stop_token_ids: {eot_ids} (eos_token_id: {tokenizer.eos_token_id})")
|
| 89 |
+
|
| 90 |
+
# 2. 创建 vLLM SamplingParams
|
| 91 |
+
# (原 HFace generate 参数:max_new_tokens=2048, eos_token_id=eot_ids)
|
| 92 |
+
sampling_params = SamplingParams(
|
| 93 |
+
max_tokens=2048, # 对应 HFace 的 max_new_tokens
|
| 94 |
+
stop_token_ids=eot_ids, # <--- 关键:告诉 vLLM 何时停止
|
| 95 |
+
temperature=0.0, # 对于 SQL 生成,使用贪婪采样 (0.0) 通常是最好的
|
| 96 |
+
top_p=1.0, # (配合 temperature=0.0)
|
| 97 |
+
)
|
| 98 |
+
# ----------------------------------------
|
| 99 |
|
|
|
|
| 100 |
def inference(text_input):
|
| 101 |
print(f"收到输入: {text_input}")
|
| 102 |
try:
|
| 103 |
# 1. 构建 Qwen (ChatML) 对话格式
|
|
|
|
| 104 |
messages = [
|
| 105 |
{"role": "user", "content": text_input}
|
| 106 |
]
|
| 107 |
|
| 108 |
# 2. 应用对话模板
|
| 109 |
+
# vLLM 需要的是 *字符串*,而不是 token IDs
|
| 110 |
+
# tokenize=False 会返回格式化后的字符串
|
| 111 |
+
# add_generation_prompt=True 会在末尾添加 <|im_start|>assistant\n
|
| 112 |
+
prompt_str = tokenizer.apply_chat_template(
|
| 113 |
messages,
|
| 114 |
+
tokenize=False, # <--- 关键:返回字符串
|
| 115 |
+
add_generation_prompt=True # <--- 关键:添加 assistant 提示
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
)
|
| 117 |
|
| 118 |
+
# 3. vLLM 生成
|
| 119 |
+
# llm.generate 接受一个 prompt 列表
|
| 120 |
+
outputs = llm.generate([prompt_str], sampling_params)
|
| 121 |
|
| 122 |
+
# 4. 提取结果
|
| 123 |
+
# outputs 是一个列表,对应输入的 prompt 列表
|
| 124 |
+
# outputs[0] 是第一个 prompt 的 RequestOutput
|
| 125 |
+
# outputs[0].outputs[0] 是第一个 (best_of=1) 的生成结果
|
| 126 |
+
# .text 包含 *新生成* 的文本 (不包含 prompt)
|
| 127 |
+
result = outputs[0].outputs[0].text.strip()
|
| 128 |
|
| 129 |
print(f"生成结果 (仅新内容): {result}")
|
| 130 |
return result
|
| 131 |
+
|
| 132 |
except Exception as e:
|
| 133 |
+
print(f"vLLM 推理时出错: {e}")
|
| 134 |
+
# 打印更详细的 vLLM 错误
|
| 135 |
+
traceback.print_exc()
|
| 136 |
return f"错误: {e}"
|
| 137 |
|
| 138 |
# ----------------------------------------------------
|
| 139 |
|
| 140 |
+
# (示例保持不变)
|
| 141 |
example = """You are a senior SQL engineer. Your task is to generate a single, correct, and executable SQL query to answer the user's question based on the provided database context.
|
| 142 |
+
... (示例内容和之前一样) ...
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
Let's think step by step!
|
| 144 |
"""
|
| 145 |
+
# (为了简洁,省略了示例的完整内容,使用您原始的 'example' 变量即可)
|
| 146 |
+
examples = [example]
|
| 147 |
|
|
|
|
| 148 |
|
| 149 |
+
# 6. 创建 Gradio 界面 (保持不变)
|
| 150 |
iface = gr.Interface(
|
| 151 |
+
fn=inference,
|
| 152 |
+
inputs=gr.Textbox(lines=10, label="输入查询 (Input Query)"),
|
|
|
|
| 153 |
outputs=gr.Textbox(lines=10, label="模型输出 (Model Output)"),
|
| 154 |
+
title="UniVectorSQL-7B (vLLM 推理)",
|
| 155 |
# description="这是一个 Text-to-SQL 模型。请输入您的问题 (Question) 和数据库模式 (Schema)。点击下方示例尝试。",
|
| 156 |
+
examples=examples
|
| 157 |
)
|
| 158 |
# ---------------------
|
| 159 |
|
requirements.txt
CHANGED
|
@@ -2,4 +2,5 @@ torch
|
|
| 2 |
transformers
|
| 3 |
peft
|
| 4 |
modelscope
|
| 5 |
-
gradio
|
|
|
|
|
|
| 2 |
transformers
|
| 3 |
peft
|
| 4 |
modelscope
|
| 5 |
+
gradio
|
| 6 |
+
vllm
|