zr-wang commited on
Commit
d0ac527
·
1 Parent(s): 68aa6fa

remove extra download

Browse files
Files changed (1) hide show
  1. app.py +348 -21
app.py CHANGED
@@ -3,37 +3,364 @@ import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  from peft import PeftModel
5
  from modelscope.hub.snapshot_download import snapshot_download
 
6
 
7
- # 1. 确定您的基础模型 (Base Model)
8
- BASE_MODEL_ID = "seeklhy/OmniSQL-7B"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- # 2. 您的 LoRA 模型 ID
 
 
 
 
 
11
  LORA_MODEL_ID = "risemds/UniVectorSQL-7B-LoRA-all_steps_1030_new_data"
12
 
13
- # 3. 下载 ModelScope LoRA 权重
14
- # (您可能需要先登录 modelscope: `from modelscope.hub.api import HubApi; api = HubApi(); api.login('YOUR_TOKEN')`)
15
- lora_path = snapshot_download(LORA_MODEL_ID, revision='master') # 确保使用正确的 revision
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- # 4. 加载基础模型和 Tokenizer
18
- tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)
19
- model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID, device_map="auto", torch_dtype=torch.float16, trust_remote_code=True)
20
 
21
- # 5. 加载并融合 LoRA 适配器
22
- # PeftModel 会自动将 LoRA 权重加载到基础模型上
23
- model = PeftModel.from_pretrained(model, lora_path)
 
 
 
 
 
 
24
 
25
- # (可选) 如果需要,可以合并权重以加快推理
26
- # model = model.merge_and_unload()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  model.eval()
29
 
30
- # 6. 定义推理函数
31
  def inference(text_input):
32
- inputs = tokenizer(text_input, return_tensors="pt").to("cuda")
33
- outputs = model.generate(**inputs, max_new_tokens=100)
34
- result = tokenizer.decode(outputs[0], skip_special_tokens=True)
35
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
- # 7. 创建 Gradio 界面
38
- iface = gr.Interface(fn=inference, inputs="text", outputs="text")
39
  iface.launch()
 
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  from peft import PeftModel
5
  from modelscope.hub.snapshot_download import snapshot_download
6
+ import os
7
 
8
+ # --- 自动检测设备 ---
9
+ # 检查是否有可用的 GPU (NVIDIA CUDA)
10
+ if torch.cuda.is_available():
11
+ device = "cuda"
12
+ torch_dtype = torch.float16 # GPU 使用 float16
13
+ device_map = "auto" # 自动分配到 GPU
14
+ print("检测到 CUDA (GPU),将使用 GPU。")
15
+ # 检查是否有可用的 Apple Silicon (MPS)
16
+ elif torch.backends.mps.is_available():
17
+ device = "mps"
18
+ torch_dtype = torch.float16 # Apple Silicon GPU 也可以用 float16
19
+ device_map = "mps" # 明确指定 mps
20
+ print("检测到 MPS (Apple Silicon),将使用 MPS。")
21
+ # 否则回退到 CPU
22
+ else:
23
+ device = "cpu"
24
+ torch_dtype = torch.float32 # CPU 使用 float32 (float16 在 CPU 上很慢或不支持)
25
+ device_map = "cpu" # 明确指定 cpu
26
+ print("未检测到 GPU,将使用 CPU。")
27
+ # ---------------------
28
 
29
+
30
+ # --- 关键配置 ---
31
+
32
+ # 1. 您的 LoRA 模型 ID
33
+ # LORA_MODEL_ID = "/mnt/DataFlow/ydw/model/risemds/UniVectorSQL-7B-LoRA-all_steps_1030_new_data/"
34
+ # lora_model_dir = LORA_MODEL_ID
35
  LORA_MODEL_ID = "risemds/UniVectorSQL-7B-LoRA-all_steps_1030_new_data"
36
 
37
+ # 2. 基础模型 ID (!!!)
38
+ # 您必须找到这个 LoRA 对应的基础模型是什么。
39
+ # BASE_MODEL_ID = "/mnt/DataFlow/ydw/model/seeklhy/OmniSQL-7B/"
40
+ BASE_MODEL_ID = "seeklhy/OmniSQL-7B"
41
+
42
+ # -----------------
43
+
44
+
45
+ # 使用 ignore_patterns 避免下载不需要的 checkpoint/文件
46
+ # 我们将忽略 .md 文件和 training_args.bin 文件
47
+ print(f"开始下载 LoRA 适配器: {LORA_MODEL_ID}")
48
+ lora_model_dir = snapshot_download(
49
+ LORA_MODEL_ID,
50
+ revision='master',
51
+ ignore_patterns=["*.md", "training_args.bin", "checkpoint-*"]
52
+ )
53
+ print(f"LoRA 适配器下载完成,路径: {lora_model_dir}")
54
+
55
+ # 1. 加载 Tokenizer
56
+ # 优先使用 LoRA 仓库中的 Tokenizer,因为它可能已更新(例如添加了新 token)
57
+ try:
58
+ tokenizer = AutoTokenizer.from_pretrained(lora_model_dir, trust_remote_code=True)
59
+ print("从 LoRA 目录加载 Tokenizer 成功。")
60
+ except Exception as e:
61
+ print(f"从 LoRA 目录加载 Tokenizer 失败 ({e}), 尝试从基础模型加载。")
62
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)
63
 
 
 
 
64
 
65
+ # 2. 加载基础模型
66
+ print(f"开始加载基础模型: {BASE_MODEL_ID}")
67
+ model = AutoModelForCausalLM.from_pretrained(
68
+ BASE_MODEL_ID,
69
+ device_map=device_map, # <--- 修改:使用自动检测的 device_map
70
+ torch_dtype=torch_dtype, # <--- 修改:使用自动检测的 torch_dtype
71
+ trust_remote_code=True
72
+ )
73
+ print("基础模型加载完成。")
74
 
75
+ # --- (新增) 修复 Qwen 模型的 pad_token_id ---
76
+ # Qwen1.5 基础模型可能没有设置 pad_token_id,这会在推理时产生警告
77
+ # 我们将其设置为 eos_token_id
78
+ if tokenizer.pad_token_id is None:
79
+ print("Tokenizer 未设置 pad_token_id,将其设置为 eos_token_id。")
80
+ tokenizer.pad_token_id = tokenizer.eos_token_id
81
+
82
+ if model.config.pad_token_id is None:
83
+ print("Model config 未设置 pad_token_id,将其设置为 eos_token_id。")
84
+ model.config.pad_token_id = tokenizer.eos_token_id
85
+ # ----------------------------------------
86
+
87
+
88
+ # 3. 加载并融合 LoRA 适配器
89
+ print(f"开始加载 LoRA 适配器到基础模型...")
90
+ model = PeftModel.from_pretrained(
91
+ model,
92
+ lora_model_dir,
93
+ device_map=device_map # <--- 修改:传递自动检测的 device_map
94
+ )
95
+ print("LoRA 适配器加载成功。")
96
+
97
+ # (可选) 合并权重以加快推理速度,但这会占用更多内存
98
+ # 注意:在 CPU 上合并可能非常慢
99
+ print("正在合并 LoRA 权重...")
100
+ model = model.merge_and_unload()
101
+ print("权重合并完成。")
102
 
103
  model.eval()
104
 
105
+ # 4. 定义推理函数 (*** 已修改为使用 Qwen 对话模板 ***)
106
  def inference(text_input):
107
+ print(f"收到输入: {text_input}")
108
+ try:
109
+ # 1. 构建 Qwen (ChatML) 对话格式
110
+ # 假设 text_input 是完整的用户提示 (包含指令、Schema等)
111
+ messages = [
112
+ {"role": "user", "content": text_input}
113
+ ]
114
+
115
+ # 2. 应用对话模板
116
+ # tokenizer.apply_chat_template 会自动处理特殊 tokens (例如 <|im_start|>)
117
+ # 并且 add_generation_prompt=True 会添加 <|im_start|>assistant\n
118
+ # 这会告诉模型开始生成回复
119
+ inputs = tokenizer.apply_chat_template(
120
+ messages,
121
+ return_tensors="pt",
122
+ add_generation_prompt=True
123
+ ).to(model.device)
124
+
125
+ # 3. 获取输入 token 的长度
126
+ # 这里的 inputs 是 token IDs tensor, shape [1, sequence_length]
127
+ input_len = inputs.shape[1]
128
+
129
+ # 4. 获取 EOT (End of Text) token IDs
130
+ # Qwen 系列使用 <|im_end|> (151645) 和/或 <|endoftext|> (151643)
131
+ # tokenizer.eos_token_id 应该已经正确设置 (来自 Qwen1.5 基础模型)
132
+ eot_ids = [tokenizer.eos_token_id]
133
+
134
+ # 额外检查 <|im_end|>,确保它也在停止列表中
135
+ im_end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
136
+ if im_end_token_id != tokenizer.unk_token_id and im_end_token_id not in eot_ids:
137
+ eot_ids.append(im_end_token_id)
138
+
139
+ print(f"使用 EOT token IDs: {eot_ids} (eos_token_id: {tokenizer.eos_token_id})")
140
+
141
+ # 5. 生成
142
+ # 使用 eos_token_id 列表来确保模型在 <|im_end|> 处停止
143
+ outputs = model.generate(
144
+ input_ids=inputs,
145
+ max_new_tokens=2048,
146
+ eos_token_id=eot_ids, # <--- 关键:告诉模型何时停止
147
+ pad_token_id=tokenizer.eos_token_id # 避免 HF warning
148
+ )
149
+
150
+ # 6. 从输出的 token 序列中,只选择新生成的部分
151
+ # outputs[0] 包含 (输入 + 生成) 的所有 tokens
152
+ new_tokens = outputs[0][input_len:]
153
+
154
+ # 7. 只解码新生成的 token
155
+ # skip_special_tokens=True 会移除任何 <|im_end|>
156
+ result = tokenizer.decode(new_tokens, skip_special_tokens=True)
157
+
158
+ print(f"生成结果 (仅新内容): {result}")
159
+ return result
160
+ except Exception as e:
161
+ print(f"推理时出错: {e}")
162
+ return f"错误: {e}"
163
+
164
+ # ----------------------------------------------------
165
+
166
+ example = """You are a senior SQL engineer. Your task is to generate a single, correct, and executable SQL query to answer the user's question based on the provided database context.
167
+
168
+ ## INSTRUCTIONS
169
+ 1. **Backend Adherence**: The query MUST be written for the `clickhouse` database backend. This is a strict requirement.
170
+ 2. **Follow Special Notes**: You MUST strictly follow all syntax, functions, or constraints described in the [Database Backend Notes]. Pay extremely close attention to this section, as it contains critical, non-standard rules.
171
+ 3. **Schema Integrity**: The query MUST ONLY use the tables and columns provided in the [Database Schema]. Do not invent or guess table or column names.
172
+ 4. **Answer the Question**: The query must directly and accurately answer the [Natural Language Question].
173
+ 5. **Output Format**: Enclose the final SQL query in a single Markdown code block formatted for SQL (` ```sql ... ``` `).
174
+ 6. **Embedding Match**: If the [EMBEDDING_MODEL_NAME] parameter is a valid string (e.g., 'intfloat/E5-Mistral-7B-Instruct'), you MUST generate a query that includes the WHERE [EMBEDDING_COLUMN_NAME] MATCH lembed(...) clause for vector search. Otherwise, if embedding model name below the [EMBEDDING MODEL NAME] is None, , you MUST generate a standard SQL query that OMITS the entire MATCH lembed(...) clause. The query should not perform any vector search.
175
+ 7. **Embedding Name**: If a value is provided for the parameter `[EMBEDDING_MODEL_NAME]`, your generated query must contain a `lembed` function call. The first parameter to the `lembed` function MUST be the exact value of `[EMBEDDING_MODEL_NAME]`, formatted as a string literal (enclosed in single quotes). For example, if `[EMBEDDING_MODEL_NAME]` is `laion/CLIP-ViT-B-32-laion2B-s34B-b79K`, the generated SQL must include `MATCH lembed('laion/CLIP-ViT-B-32-laion2B-s34B-b79K', ...)`.
176
+
177
+ ## DATABASE CONTEXT
178
+
179
+ [DATABASE BACKEND]:
180
+ clickhouse
181
+
182
+ [DATABASE SCHEMA]:
183
+ CREATE TABLE CAMPAIGN_RESULTS (
184
+ `result_id` Nullable(Int64),
185
+ `campaign_id` Nullable(Int64),
186
+ `territory_id` Nullable(Int64),
187
+ `menu_item_id` Nullable(Int64),
188
+ `sales_increase_percentage` Nullable(Float64),
189
+ `customer_engagement_score` Nullable(Float64),
190
+ `feedback_improvement` Nullable(Float64)
191
+ );
192
+ CREATE TABLE CUSTOMERS (
193
+ `customer_id` Nullable(Int64),
194
+ `customer_name` Nullable(String),
195
+ `email` Nullable(String),
196
+ `phone_number` Nullable(String),
197
+ `loyalty_points` Nullable(Int64)
198
+ );
199
+ CREATE TABLE CUSTOMER_FEEDBACK (
200
+ `feedback_id` Nullable(Int64),
201
+ `menu_item_id` Nullable(Int64),
202
+ `customer_id` Nullable(Int64),
203
+ `feedback_date` Nullable(String),
204
+ `rating` Nullable(Float64),
205
+ `comments` Nullable(String),
206
+ `feedback_type` Nullable(String),
207
+ `comments_embedding` Array(Float32)
208
+ );
209
+ CREATE TABLE INGREDIENTS (
210
+ `ingredient_id` Nullable(Int64),
211
+ `name` Nullable(String),
212
+ `description` Nullable(String),
213
+ `supplier_id` Nullable(Int64),
214
+ `cost_per_unit` Nullable(Float64),
215
+ `description_embedding` Array(Float32)
216
+ );
217
+ CREATE TABLE MARKETING_CAMPAIGNS (
218
+ `campaign_id` Nullable(Int64),
219
+ `campaign_name` Nullable(String),
220
+ `start_date` Nullable(String),
221
+ `end_date` Nullable(String),
222
+ `budget` Nullable(Float64),
223
+ `objective` Nullable(String),
224
+ `territory_id` Nullable(Int64)
225
+ );
226
+ CREATE TABLE MENU_CATEGORIES (
227
+ `category_id` Nullable(Int64),
228
+ `category_name` Nullable(String),
229
+ `description` Nullable(String),
230
+ `parent_category_id` Nullable(Int64),
231
+ `description_embedding` Array(Float32)
232
+ );
233
+ CREATE TABLE MENU_ITEMS (
234
+ `menu_item_id` Nullable(Int64),
235
+ `territory_id` Nullable(Int64),
236
+ `name` Nullable(String),
237
+ `price_inr` Nullable(Float64),
238
+ `price_usd` Nullable(Float64),
239
+ `price_eur` Nullable(Float64),
240
+ `category_id` Nullable(Int64),
241
+ `menu_type` Nullable(String),
242
+ `calories` Nullable(Int64),
243
+ `is_vegetarian` Nullable(Int64),
244
+ `promotion_id` Nullable(Int64)
245
+ );
246
+ CREATE TABLE MENU_ITEM_INGREDIENTS (
247
+ `menu_item_id` Nullable(Int64),
248
+ `ingredient_id` Nullable(Int64),
249
+ `quantity` Nullable(Float64),
250
+ `unit_of_measurement` Nullable(String)
251
+ );
252
+ CREATE TABLE PRICING_STRATEGIES (
253
+ `strategy_id` Nullable(Int64),
254
+ `strategy_name` Nullable(String),
255
+ `description` Nullable(String),
256
+ `territory_id` Nullable(Int64),
257
+ `effective_date` Nullable(String),
258
+ `end_date` Nullable(String),
259
+ `description_embedding` Array(Float32)
260
+ );
261
+ CREATE TABLE PROMOTIONS (
262
+ `promotion_id` Nullable(Int64),
263
+ `promotion_name` Nullable(String),
264
+ `start_date` Nullable(String),
265
+ `end_date` Nullable(String),
266
+ `discount_percentage` Nullable(Float64),
267
+ `category_id` Nullable(Int64),
268
+ `territory_id` Nullable(Int64)
269
+ );
270
+ CREATE TABLE SALES_DATA (
271
+ `sale_id` Nullable(Int64),
272
+ `menu_item_id` Nullable(Int64),
273
+ `territory_id` Nullable(Int64),
274
+ `sale_date` Nullable(String),
275
+ `quantity_sold` Nullable(Int64),
276
+ `total_revenue` Nullable(Float64),
277
+ `discount_applied` Nullable(Float64),
278
+ `customer_id` Nullable(Int64)
279
+ );
280
+ CREATE TABLE SALES_FORECAST (
281
+ `forecast_id` Nullable(Int64),
282
+ `menu_item_id` Nullable(Int64),
283
+ `territory_id` Nullable(Int64),
284
+ `forecast_date` Nullable(String),
285
+ `forecast_quantity` Nullable(Int64),
286
+ `forecast_revenue` Nullable(Float64),
287
+ `prediction_accuracy` Nullable(Float64)
288
+ );
289
+ CREATE TABLE SUPPLIERS (
290
+ `supplier_id` Nullable(Int64),
291
+ `supplier_name` Nullable(String),
292
+ `contact_email` Nullable(String),
293
+ `phone_number` Nullable(String),
294
+ `address` Nullable(String)
295
+ );
296
+ CREATE TABLE TERRITORIES (
297
+ `territory_id` Nullable(Int64),
298
+ `territory_name` Nullable(String),
299
+ `region` Nullable(String),
300
+ `contact_email` Nullable(String),
301
+ `local_tax_rate` Nullable(Float64),
302
+ `currency_code` Nullable(String)
303
+ );
304
+ CREATE TABLE USERS (
305
+ `user_id` Nullable(Int64),
306
+ `user_name` Nullable(String),
307
+ `email` Nullable(String),
308
+ `role_id` Nullable(Int64),
309
+ `territory_id` Nullable(Int64)
310
+ );
311
+ CREATE TABLE USER_ROLES (
312
+ `role_id` Nullable(Int64),
313
+ `role_name` Nullable(String),
314
+ `description` Nullable(String),
315
+ `permissions` Nullable(String),
316
+ `description_embedding` Array(Float32)
317
+ );
318
+
319
+ [DATABASE BACKEND NOTES]:
320
+ There are a few requirements you should comply with in addition:
321
+ 1. When generating SQL queries, you should prioritize utilizing K-Nearest Neighbor (KNN) searches whenever contextually appropriate. However, you must avoid unnecessary/forced KNN implementations for:
322
+ -- Traditional relational data queries (especially for columns like: id, age, price).
323
+ -- Cases where standard SQL operators (equality, range, or aggregation functions) are more efficient and semantically appropriate.
324
+ 2. Only columns with a vector type (like: Array(Float32)) support KNN queries. The names of these vector columns often end with "_embedding". You can perform KNN searches when the column name you need to query ends with "_embedding" or is otherwise identified as a vector column.
325
+ 3. In ClickHouse, vector similarity search is performed using distance functions. You must explicitly calculate the distance in the SELECT clause using a function like L2Distance and give it an alias, typically "AS distance". This distance alias will not be implicitly generated.
326
+ 4. The lembed function is used to transform a string into a semantic vector. This function should be used within a WITH clause to define the reference vector. The lembed function has two parameters: the first is the name of the embedding model used (default value: 'intfloat/E5-Mistral-7B-Instruct'), and the second is the string content to embed. The resulting vector should be given an alias in the WITH clause.
327
+ 5. You must generate plausible and semantically relevant words or sentences for the second parameter of the lembed function based on the column's name, type, and comment. For example, if a column is named product_description_embedding and its comment is "Embedding of the product's features and marketing text", you could generate text like "durable and waterproof outdoor adventure camera".
328
+ 6. Every KNN search query MUST conclude with "ORDER BY distance LIMIT N" to retrieve the top-N most similar results. The LIMIT clause is mandatory for performing a KNN search and ensuring predictable performance.
329
+ 7. When combining a vector search with JOIN operations, the standard WHERE clause should be used to apply filters from any of the joined tables. The ORDER BY distance LIMIT N clause is applied after all filtering and joins are resolved.
330
+ 8. A SELECT statement should typically be ordered by a single distance calculation to perform one primary KNN search. However, subqueries can perform their own independent KNN searches, each with its own WITH clause, distance calculation, and ORDER BY distance LIMIT N clause.
331
+
332
+ ## Example of a ClickHouse KNN Query
333
+ DB Schema: Some table on articles with a column content_embedding Array(Float32).
334
+ Query Task: Identify the article ID of the single most relevant article discussing innovative algorithms in graph theory.
335
+ Generated SQL:
336
+ ```sql
337
+ WITH\n lembed('intfloat/E5-Mistral-7B-Instruct', 'innovative algorithms in graph theory.') AS ref_vec_0\n\nSELECT id, L2Distance(articles.abstract_embedding, ref_vec_0) AS distance\nFROM articles\nORDER BY distance\nLIMIT 1;
338
+ ```
339
+
340
+ [EMBEDDING MODEL NAME]:
341
+ intfloat/E5-Mistral-7B-Instruct
342
+
343
+ ## NATURAL LANGUAGE QUESTION
344
+ In vector searches, the `MATCH` operator performs an approximate nearest neighbor (ANN) search, which identifies items based on their similarity to a given vector. The `lembed()` function is used to convert text phrases into vector representations using a specific model, in this case, `intfloat/E5-Mistral-7B-Instruct`. This helps in finding items that align closely with the concept of "Popular menu items based on sales." The parameter `k = 5` specifies that only the top 5 categories, which are most similar in terms of the embedding, should be considered. The similarity is determined by calculating the Euclidean distance between vectors, where a smaller distance indicates higher similarity.
345
+ Can you unveil the crown jewel of our vegetarian delights, the one that has soared to the top of the sales charts from the elite circle of our most cherished categories this year?
346
+
347
+ Let's think step by step!
348
+ """
349
+
350
+ examples = [example]
351
+
352
+ # 6. 创建 Gradio 界面 (原第5步)
353
+ iface = gr.Interface(
354
+ fn=inference,
355
+ # 增加了行数以便容纳 Schema
356
+ inputs=gr.Textbox(lines=10, label="输入查询 (Input Query)"),
357
+ outputs=gr.Textbox(lines=10, label="模型输出 (Model Output)"),
358
+ title="UniVectorSQL-7B-LoRA 推理",
359
+ # description="这是一个 Text-to-SQL 模型。请输入您的问题 (Question) 和数据库模式 (Schema)。点击下方示例尝试。",
360
+ examples=examples # <--- 添加示例
361
+ )
362
+ # ---------------------
363
 
364
+ print("启动 Gradio 界面...")
365
+ # Hugging Face 或 ModelScope Space 中,share=True 不是必需的
366
  iface.launch()