File size: 17,147 Bytes
5cf6351
68aa6fa
 
 
 
d0ac527
5cf6351
d0ac527
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5cf6351
d0ac527
 
 
 
 
 
68aa6fa
5cf6351
d0ac527
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5cf6351
 
d0ac527
 
 
 
 
 
 
 
 
5cf6351
d0ac527
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5cf6351
68aa6fa
5cf6351
d0ac527
68aa6fa
d0ac527
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5cf6351
d0ac527
 
68aa6fa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from modelscope.hub.snapshot_download import snapshot_download
import os

# --- Automatic device detection ---
# Pick the best available backend, in priority order:
# NVIDIA CUDA GPU > Apple Silicon MPS > plain CPU.
if torch.cuda.is_available():
    # NVIDIA GPU: float16 halves memory and is well supported.
    device, torch_dtype, device_map = "cuda", torch.float16, "auto"
    print("检测到 CUDA (GPU),将使用 GPU。")
elif torch.backends.mps.is_available():
    # Apple Silicon GPU also handles float16; pin device_map to "mps".
    device, torch_dtype, device_map = "mps", torch.float16, "mps"
    print("检测到 MPS (Apple Silicon),将使用 MPS。")
else:
    # CPU fallback: float16 is slow or unsupported on CPU, use float32.
    device, torch_dtype, device_map = "cpu", torch.float32, "cpu"
    print("未检测到 GPU,将使用 CPU。")
# ---------------------


# --- Key configuration ---

# 1. LoRA adapter model ID (ModelScope hub repo).
# LORA_MODEL_ID = "/mnt/DataFlow/ydw/model/risemds/UniVectorSQL-7B-LoRA-all_steps_1030_new_data/"
# lora_model_dir = LORA_MODEL_ID
LORA_MODEL_ID = "risemds/UniVectorSQL-7B-LoRA-all_steps_1030_new_data"

# 2. Base model ID (!!!)
# This MUST be the exact base model the LoRA adapter was trained on.
# BASE_MODEL_ID = "/mnt/DataFlow/ydw/model/seeklhy/OmniSQL-7B/"
BASE_MODEL_ID = "seeklhy/OmniSQL-7B" 

# -----------------


# Use ignore_patterns to skip files not needed at inference time:
# markdown docs, the trainer state, and intermediate checkpoints.
print(f"开始下载 LoRA 适配器: {LORA_MODEL_ID}")
lora_model_dir = snapshot_download(
    LORA_MODEL_ID,
    revision='master',
    ignore_patterns=["*.md", "training_args.bin", "checkpoint-*"] 
)
print(f"LoRA 适配器下载完成,路径: {lora_model_dir}")

# 1. Load the tokenizer.
# Prefer the tokenizer shipped inside the LoRA repo: it may have been
# updated during fine-tuning (e.g. new special tokens added).
try:
    tokenizer = AutoTokenizer.from_pretrained(lora_model_dir, trust_remote_code=True)
    print("从 LoRA 目录加载 Tokenizer 成功。")
except Exception as e:
    # Fall back to the base model's tokenizer if the LoRA repo lacks one.
    print(f"从 LoRA 目录加载 Tokenizer 失败 ({e}), 尝试从基础模型加载。")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)


# 2. Load the base model.
print(f"开始加载基础模型: {BASE_MODEL_ID}")
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    device_map=device_map,      # auto-detected placement ("auto"/"mps"/"cpu")
    torch_dtype=torch_dtype,    # auto-detected dtype (fp16 on GPU/MPS, fp32 on CPU)
    trust_remote_code=True
)
print("基础模型加载完成。")

# --- Fix missing pad_token_id (common on Qwen models) ---
# Qwen1.5 base models may ship without a pad_token_id, which produces
# warnings during inference; fall back to eos_token_id on both the
# tokenizer and the model config.
if tokenizer.pad_token_id is None:
    print("Tokenizer 未设置 pad_token_id,将其设置为 eos_token_id。")
    tokenizer.pad_token_id = tokenizer.eos_token_id
    
if model.config.pad_token_id is None:
    print("Model config 未设置 pad_token_id,将其设置为 eos_token_id。")
    model.config.pad_token_id = tokenizer.eos_token_id
# ----------------------------------------


# 3. Load the LoRA adapter on top of the base model.
print(f"开始加载 LoRA 适配器到基础模型...")
model = PeftModel.from_pretrained(
    model, 
    lora_model_dir,
    device_map=device_map       # keep the same placement as the base model
)
print("LoRA 适配器加载成功。")

# (Optional) Merge the adapter weights into the base model for faster
# inference, at the cost of extra memory.
# NOTE: merging on CPU can be very slow.
print("正在合并 LoRA 权重...")
model = model.merge_and_unload()
print("权重合并完成。")

# Inference mode: disables dropout and other training-only behavior.
model.eval()

# 4. Inference function (uses the Qwen ChatML chat template).
def inference(text_input):
    """Generate a model reply for a single user prompt.

    Args:
        text_input: The complete user prompt (instructions, schema,
            question, ...), sent as one "user" turn.

    Returns:
        The newly generated text only (prompt echo stripped, special
        tokens removed), or an error message string if generation fails.
    """
    print(f"收到输入: {text_input}")
    try:
        # 1. Build the Qwen (ChatML) conversation: the whole prompt is a
        #    single user message.
        messages = [
            {"role": "user", "content": text_input}
        ]
        
        # 2. Apply the chat template.
        # apply_chat_template inserts the special tokens (<|im_start|> ...)
        # and add_generation_prompt=True appends "<|im_start|>assistant\n"
        # so the model starts producing the reply.
        inputs = tokenizer.apply_chat_template(
            messages,
            return_tensors="pt",
            add_generation_prompt=True
        ).to(model.device)
        
        # 3. Prompt length in tokens — used later to slice off the echo
        #    of the input from the generated sequence.
        input_len = inputs.shape[1]
        
        # 4. Collect end-of-turn token IDs.
        # Qwen models use <|im_end|> (151645) and/or <|endoftext|> (151643);
        # eos_token_id should already be set correctly by the tokenizer.
        eot_ids = [tokenizer.eos_token_id]
        
        # Also add <|im_end|> if the tokenizer knows it.
        # FIX: convert_tokens_to_ids may return None for an unknown token,
        # and Qwen tokenizers can have unk_token_id == None, so the
        # original `!= unk_token_id` check only worked by accident —
        # guard against None explicitly.
        im_end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
        if (
            im_end_token_id is not None
            and im_end_token_id != tokenizer.unk_token_id
            and im_end_token_id not in eot_ids
        ):
            eot_ids.append(im_end_token_id)
        
        print(f"使用 EOT token IDs: {eot_ids} (eos_token_id: {tokenizer.eos_token_id})")

        # 5. Generate.
        # FIX: run under torch.no_grad() — inference must not build the
        # autograd graph (saves a large amount of memory).
        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs,
                max_new_tokens=2048,
                eos_token_id=eot_ids,  # stop at <|im_end|> / eos
                pad_token_id=tokenizer.eos_token_id  # silence the HF padding warning
            )
        
        # 6. Keep only the newly generated tokens; outputs[0] contains
        #    (prompt + generation).
        new_tokens = outputs[0][input_len:]
        
        # 7. Decode only the new tokens; skip_special_tokens=True removes
        #    any trailing <|im_end|>.
        result = tokenizer.decode(new_tokens, skip_special_tokens=True)

        print(f"生成结果 (仅新内容): {result}")
        return result
    except Exception as e:
        print(f"推理时出错: {e}")
        return f"错误: {e}"

# ----------------------------------------------------

# 5. Example prompt for the Gradio UI: a complete Text-to-SQL task
# (instructions, ClickHouse schema, vector-search backend notes, and the
# natural-language question). NOTE: this is a runtime string passed to
# the model verbatim — do not edit it casually.
example = """You are a senior SQL engineer. Your task is to generate a single, correct, and executable SQL query to answer the user's question based on the provided database context.

## INSTRUCTIONS
1.  **Backend Adherence**: The query MUST be written for the `clickhouse` database backend. This is a strict requirement.
2.  **Follow Special Notes**: You MUST strictly follow all syntax, functions, or constraints described in the [Database Backend Notes]. Pay extremely close attention to this section, as it contains critical, non-standard rules.
3.  **Schema Integrity**: The query MUST ONLY use the tables and columns provided in the [Database Schema]. Do not invent or guess table or column names.
4.  **Answer the Question**: The query must directly and accurately answer the [Natural Language Question].
5.  **Output Format**: Enclose the final SQL query in a single Markdown code block formatted for SQL (` ```sql ... ``` `).
6.  **Embedding Match**: If the [EMBEDDING_MODEL_NAME] parameter is a valid string (e.g., 'intfloat/E5-Mistral-7B-Instruct'), you MUST generate a query that includes the WHERE [EMBEDDING_COLUMN_NAME] MATCH lembed(...) clause for vector search. Otherwise, if embedding model name below the [EMBEDDING MODEL NAME] is None, , you MUST generate a standard SQL query that OMITS the entire MATCH lembed(...) clause. The query should not perform any vector search.
7.  **Embedding Name**: If a value is provided for the parameter `[EMBEDDING_MODEL_NAME]`, your generated query must contain a `lembed` function call. The first parameter to the `lembed` function MUST be the exact value of `[EMBEDDING_MODEL_NAME]`, formatted as a string literal (enclosed in single quotes). For example, if `[EMBEDDING_MODEL_NAME]` is `laion/CLIP-ViT-B-32-laion2B-s34B-b79K`, the generated SQL must include `MATCH lembed('laion/CLIP-ViT-B-32-laion2B-s34B-b79K', ...)`.

## DATABASE CONTEXT

[DATABASE BACKEND]:
clickhouse

[DATABASE SCHEMA]:
CREATE TABLE CAMPAIGN_RESULTS (
  `result_id` Nullable(Int64),
  `campaign_id` Nullable(Int64),
  `territory_id` Nullable(Int64),
  `menu_item_id` Nullable(Int64),
  `sales_increase_percentage` Nullable(Float64),
  `customer_engagement_score` Nullable(Float64),
  `feedback_improvement` Nullable(Float64)
);
CREATE TABLE CUSTOMERS (
  `customer_id` Nullable(Int64),
  `customer_name` Nullable(String),
  `email` Nullable(String),
  `phone_number` Nullable(String),
  `loyalty_points` Nullable(Int64)
);
CREATE TABLE CUSTOMER_FEEDBACK (
  `feedback_id` Nullable(Int64),
  `menu_item_id` Nullable(Int64),
  `customer_id` Nullable(Int64),
  `feedback_date` Nullable(String),
  `rating` Nullable(Float64),
  `comments` Nullable(String),
  `feedback_type` Nullable(String),
  `comments_embedding` Array(Float32)
);
CREATE TABLE INGREDIENTS (
  `ingredient_id` Nullable(Int64),
  `name` Nullable(String),
  `description` Nullable(String),
  `supplier_id` Nullable(Int64),
  `cost_per_unit` Nullable(Float64),
  `description_embedding` Array(Float32)
);
CREATE TABLE MARKETING_CAMPAIGNS (
  `campaign_id` Nullable(Int64),
  `campaign_name` Nullable(String),
  `start_date` Nullable(String),
  `end_date` Nullable(String),
  `budget` Nullable(Float64),
  `objective` Nullable(String),
  `territory_id` Nullable(Int64)
);
CREATE TABLE MENU_CATEGORIES (
  `category_id` Nullable(Int64),
  `category_name` Nullable(String),
  `description` Nullable(String),
  `parent_category_id` Nullable(Int64),
  `description_embedding` Array(Float32)
);
CREATE TABLE MENU_ITEMS (
  `menu_item_id` Nullable(Int64),
  `territory_id` Nullable(Int64),
  `name` Nullable(String),
  `price_inr` Nullable(Float64),
  `price_usd` Nullable(Float64),
  `price_eur` Nullable(Float64),
  `category_id` Nullable(Int64),
  `menu_type` Nullable(String),
  `calories` Nullable(Int64),
  `is_vegetarian` Nullable(Int64),
  `promotion_id` Nullable(Int64)
);
CREATE TABLE MENU_ITEM_INGREDIENTS (
  `menu_item_id` Nullable(Int64),
  `ingredient_id` Nullable(Int64),
  `quantity` Nullable(Float64),
  `unit_of_measurement` Nullable(String)
);
CREATE TABLE PRICING_STRATEGIES (
  `strategy_id` Nullable(Int64),
  `strategy_name` Nullable(String),
  `description` Nullable(String),
  `territory_id` Nullable(Int64),
  `effective_date` Nullable(String),
  `end_date` Nullable(String),
  `description_embedding` Array(Float32)
);
CREATE TABLE PROMOTIONS (
  `promotion_id` Nullable(Int64),
  `promotion_name` Nullable(String),
  `start_date` Nullable(String),
  `end_date` Nullable(String),
  `discount_percentage` Nullable(Float64),
  `category_id` Nullable(Int64),
  `territory_id` Nullable(Int64)
);
CREATE TABLE SALES_DATA (
  `sale_id` Nullable(Int64),
  `menu_item_id` Nullable(Int64),
  `territory_id` Nullable(Int64),
  `sale_date` Nullable(String),
  `quantity_sold` Nullable(Int64),
  `total_revenue` Nullable(Float64),
  `discount_applied` Nullable(Float64),
  `customer_id` Nullable(Int64)
);
CREATE TABLE SALES_FORECAST (
  `forecast_id` Nullable(Int64),
  `menu_item_id` Nullable(Int64),
  `territory_id` Nullable(Int64),
  `forecast_date` Nullable(String),
  `forecast_quantity` Nullable(Int64),
  `forecast_revenue` Nullable(Float64),
  `prediction_accuracy` Nullable(Float64)
);
CREATE TABLE SUPPLIERS (
  `supplier_id` Nullable(Int64),
  `supplier_name` Nullable(String),
  `contact_email` Nullable(String),
  `phone_number` Nullable(String),
  `address` Nullable(String)
);
CREATE TABLE TERRITORIES (
  `territory_id` Nullable(Int64),
  `territory_name` Nullable(String),
  `region` Nullable(String),
  `contact_email` Nullable(String),
  `local_tax_rate` Nullable(Float64),
  `currency_code` Nullable(String)
);
CREATE TABLE USERS (
  `user_id` Nullable(Int64),
  `user_name` Nullable(String),
  `email` Nullable(String),
  `role_id` Nullable(Int64),
  `territory_id` Nullable(Int64)
);
CREATE TABLE USER_ROLES (
  `role_id` Nullable(Int64),
  `role_name` Nullable(String),
  `description` Nullable(String),
  `permissions` Nullable(String),
  `description_embedding` Array(Float32)
);

[DATABASE BACKEND NOTES]:
There are a few requirements you should comply with in addition:
1. When generating SQL queries, you should prioritize utilizing K-Nearest Neighbor (KNN) searches whenever contextually appropriate. However, you must avoid unnecessary/forced KNN implementations for:
-- Traditional relational data queries (especially for columns like: id, age, price).
-- Cases where standard SQL operators (equality, range, or aggregation functions) are more efficient and semantically appropriate.
2. Only columns with a vector type (like: Array(Float32)) support KNN queries. The names of these vector columns often end with "_embedding". You can perform KNN searches when the column name you need to query ends with "_embedding" or is otherwise identified as a vector column.
3. In ClickHouse, vector similarity search is performed using distance functions. You must explicitly calculate the distance in the SELECT clause using a function like L2Distance and give it an alias, typically "AS distance". This distance alias will not be implicitly generated.
4. The lembed function is used to transform a string into a semantic vector. This function should be used within a WITH clause to define the reference vector. The lembed function has two parameters: the first is the name of the embedding model used (default value: 'intfloat/E5-Mistral-7B-Instruct'), and the second is the string content to embed. The resulting vector should be given an alias in the WITH clause.
5. You must generate plausible and semantically relevant words or sentences for the second parameter of the lembed function based on the column's name, type, and comment. For example, if a column is named product_description_embedding and its comment is "Embedding of the product's features and marketing text", you could generate text like "durable and waterproof outdoor adventure camera".
6. Every KNN search query MUST conclude with "ORDER BY distance LIMIT N" to retrieve the top-N most similar results. The LIMIT clause is mandatory for performing a KNN search and ensuring predictable performance.
7. When combining a vector search with JOIN operations, the standard WHERE clause should be used to apply filters from any of the joined tables. The ORDER BY distance LIMIT N clause is applied after all filtering and joins are resolved.
8. A SELECT statement should typically be ordered by a single distance calculation to perform one primary KNN search. However, subqueries can perform their own independent KNN searches, each with its own WITH clause, distance calculation, and ORDER BY distance LIMIT N clause.

## Example of a ClickHouse KNN Query
DB Schema: Some table on articles with a column content_embedding Array(Float32).
Query Task: Identify the article ID of the single most relevant article discussing innovative algorithms in graph theory.
Generated SQL:
```sql
    WITH\n    lembed('intfloat/E5-Mistral-7B-Instruct', 'innovative algorithms in graph theory.') AS ref_vec_0\n\nSELECT id, L2Distance(articles.abstract_embedding, ref_vec_0) AS distance\nFROM articles\nORDER BY distance\nLIMIT 1;
```

[EMBEDDING MODEL NAME]:
intfloat/E5-Mistral-7B-Instruct

## NATURAL LANGUAGE QUESTION
In vector searches, the `MATCH` operator performs an approximate nearest neighbor (ANN) search, which identifies items based on their similarity to a given vector. The `lembed()` function is used to convert text phrases into vector representations using a specific model, in this case, `intfloat/E5-Mistral-7B-Instruct`. This helps in finding items that align closely with the concept of "Popular menu items based on sales." The parameter `k = 5` specifies that only the top 5 categories, which are most similar in terms of the embedding, should be considered. The similarity is determined by calculating the Euclidean distance between vectors, where a smaller distance indicates higher similarity.
Can you unveil the crown jewel of our vegetarian delights, the one that has soared to the top of the sales charts from the elite circle of our most cherished categories this year?

Let's think step by step!
"""

# List of clickable examples shown in the Gradio UI.
examples = [example]

# 6. Build the Gradio UI (was step 5).
iface = gr.Interface(
    fn=inference, 
    # Extra rows so the database schema fits in the input box.
    inputs=gr.Textbox(lines=10, label="输入查询 (Input Query)"), 
    outputs=gr.Textbox(lines=10, label="模型输出 (Model Output)"),
    title="UniVectorSQL-7B-LoRA 推理",
    # description="这是一个 Text-to-SQL 模型。请输入您的问题 (Question) 和数据库模式 (Schema)。点击下方示例尝试。",
    examples=examples  # clickable example prompt defined above
)
# ---------------------

print("启动 Gradio 界面...")
# share=True is not required when running inside a Hugging Face or
# ModelScope Space.
iface.launch()