Paul720810 committed on
Commit
954de7f
·
verified ·
1 Parent(s): 7b1b963

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +111 -129
app.py CHANGED
@@ -2,7 +2,7 @@ import gradio as gr
2
  import requests
3
  import json
4
  import os
5
- from datasets import load_dataset, Dataset
6
  from sentence_transformers import SentenceTransformer, util
7
  import torch
8
  from huggingface_hub import hf_hub_download
@@ -11,24 +11,23 @@ import re
11
  # --- 配置區 ---
12
  HF_TOKEN = os.environ.get("HF_TOKEN")
13
  DATASET_REPO_ID = "Paul720810/Text-to-SQL-Softline"
14
- # 使用正確的模型名稱(7B版本更適合免費使用)
15
- LLM_API_URL = "https://api-inference.huggingface.co/models/codellama/CodeLlama-7b-hf"
16
- SIMILARITY_THRESHOLD = 0.90
17
 
18
  print("--- [1/5] 開始初始化應用 ---")
19
 
20
  # --- 1. 載入知識庫 ---
21
- qa_dataset = None
22
- schema_data = {}
23
  questions = []
24
  sql_answers = []
 
25
 
26
  try:
27
  print(f"--- [2/5] 正在從 '{DATASET_REPO_ID}' 載入知識庫... ---")
28
  raw_dataset = load_dataset(DATASET_REPO_ID, token=HF_TOKEN)['train']
29
 
30
- # 解析新的 messages 格式
31
- print("--- > 檢測到 'messages' 格式,正在解析...")
32
 
33
  for item in raw_dataset:
34
  try:
@@ -36,80 +35,48 @@ try:
36
  user_content = item['messages'][0]['content']
37
  assistant_content = item['messages'][1]['content']
38
 
39
- # 從用戶消息中提取問題
40
  question_match = re.search(r'指令:\s*(.*?)(?:\n|$)', user_content)
41
- if question_match:
42
- question = question_match.group(1).strip()
43
- else:
44
- # 如果沒有找到指令,使用整個內容
45
- question = user_content
46
 
47
- # 從助手消息中提取SQL
48
  sql_match = re.search(r'SQL查詢:\s*(.*?)(?:\n|$)', assistant_content, re.DOTALL)
49
  if sql_match:
50
  sql_query = sql_match.group(1).strip()
51
- # 清理SQL語句
52
- sql_query = re.sub(r'^sql\s*', '', sql_query) # 移除開頭的sql
53
- sql_query = re.sub(r'```sql|```', '', sql_query).strip() # 移除代碼塊標記
54
  else:
55
  sql_query = assistant_content
56
 
57
  questions.append(question)
58
  sql_answers.append(sql_query)
59
 
60
- except (KeyError, IndexError, TypeError) as e:
61
  print(f"解析錯誤,跳過該條目: {e}")
62
  continue
63
 
64
- # 創建問答數據集
65
- if questions:
66
- qa_dataset = Dataset.from_dict({
67
- 'question': questions,
68
- 'sql': sql_answers
69
- })
70
- else:
71
- raise ValueError("沒有成功解析出任何問答對")
72
 
73
- # 載入並解析 Schema JSON
74
- schema_file_path = "sqlite_schema_FULL.json"
75
  try:
76
- hf_hub_download(repo_id=DATASET_REPO_ID, filename=schema_file_path,
77
- repo_type='dataset', local_dir='.', token=HF_TOKEN)
78
-
 
 
 
79
  with open(schema_file_path, 'r', encoding='utf-8') as f:
80
  schema_data = json.load(f)
81
  except Exception as e:
82
  print(f"警告: 無法載入Schema文件: {e}")
83
- schema_data = {}
84
-
85
- print(f"--- > 成功解析 {len(questions)} 條問答範例。 ---")
86
 
87
  except Exception as e:
88
- print(f"!!! 錯誤: 處理Dataset時發生問題: {e}")
89
- # 創建備用數據集
90
  questions = ["示例問題"]
91
- sql_answers = ["SELECT '請檢查數據集格式' AS error;"]
92
- qa_dataset = Dataset.from_dict({"question": questions, "sql": sql_answers})
93
 
94
- # --- 2. 構建 DDL 和初始化檢索模型 ---
95
- def load_schema_as_ddl(schema_dict: dict) -> str:
96
- ddl_string = "/* 數據庫結構 */\n"
97
- for table_name, columns in schema_dict.items():
98
- if not isinstance(columns, list):
99
- continue
100
- ddl_string += f"CREATE TABLE `{table_name}` (\n"
101
- ddl_cols = []
102
- for col in columns:
103
- col_name = col.get('name', 'unknown')
104
- col_type = col.get('type', 'TEXT')
105
- col_desc = col.get('description', '')
106
- ddl_cols.append(f" `{col_name}` {col_type} -- {col_desc}")
107
- ddl_string += ",\n".join(ddl_cols) + "\n);\n\n"
108
- return ddl_string
109
-
110
- SCHEMA_DDL = load_schema_as_ddl(schema_data)
111
-
112
- print("--- [3/5] 正在載入句向量模型 (all-MiniLM-L6-v2)... ---")
113
  embedder = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
114
 
115
  # 計算問題向量
@@ -118,10 +85,28 @@ if questions:
118
  question_embeddings = embedder.encode(questions, convert_to_tensor=True, show_progress_bar=True)
119
  print("--- > 向量計算完成! ---")
120
  else:
121
- print("--- [4/5] 警告:沒有可用的問題來計算向量。 ---")
122
  question_embeddings = torch.Tensor([])
123
 
124
- # --- 3. 混合系統核心邏輯 ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  def get_sql_query(user_question: str):
126
  if not user_question:
127
  return "請輸入您的問題。", "日誌:用戶未輸入問題。"
@@ -129,50 +114,43 @@ def get_sql_query(user_question: str):
129
  log_messages = []
130
 
131
  # 檢索相似問題
132
- if len(questions) > 0:
133
- question_embedding = embedder.encode(user_question, convert_to_tensor=True)
134
- hits = util.semantic_search(question_embedding, question_embeddings, top_k=3)
135
-
136
- if hits and hits[0]:
137
- most_similar_hit = hits[0][0]
138
- similarity_score = most_similar_hit['score']
139
- similar_question = questions[most_similar_hit['corpus_id']]
140
-
141
- log_messages.append(f"檢索到相似問題: '{similar_question}' (相似度: {similarity_score:.4f})")
142
 
143
- if similarity_score > SIMILARITY_THRESHOLD:
144
- sql_result = sql_answers[most_similar_hit['corpus_id']]
145
- log_messages.append(f"相似度 > {SIMILARITY_THRESHOLD},直接返回預先SQL")
146
- return sql_result, "\n".join(log_messages)
147
- else:
148
- log_messages.append("檢索失敗:找不到相似問題")
 
 
 
 
 
 
 
 
 
 
 
 
149
  else:
150
  log_messages.append("知識庫為空,跳過檢索")
151
 
152
  # LLM生成模式
153
  log_messages.append("進入LLM生成模式...")
154
 
155
- # 構建示例上下文
156
- examples_context = ""
157
- if 'hits' in locals() and hits and hits[0]:
158
- for i, hit in enumerate(hits[0][:2]):
159
- examples_context += f"問題: {questions[hit['corpus_id']]}\nSQL: {sql_answers[hit['corpus_id']]}\n\n"
160
-
161
- # 構建提示詞
162
- prompt = f"""你是一個SQL專家。請根據數據庫結構生成SQL查詢。
163
-
164
- 數據庫結構:
165
- {SCHEMA_DDL}
166
-
167
- 參考示例:
168
- {examples_context}
169
 
170
- 請為以下問題生成SQL查詢:
171
- {user_question}
172
 
173
- 只輸出SQL語句,不要其他內容:
174
 
175
- """
176
 
177
  log_messages.append("正在請求雲端LLM...")
178
 
@@ -180,7 +158,7 @@ def get_sql_query(user_question: str):
180
  payload = {
181
  "inputs": prompt,
182
  "parameters": {
183
- "max_new_tokens": 300,
184
  "temperature": 0.1,
185
  "do_sample": False
186
  }
@@ -194,74 +172,78 @@ def get_sql_query(user_question: str):
194
  if isinstance(result, list) and len(result) > 0:
195
  generated_text = result[0]['generated_text'].strip()
196
 
197
- # 清理輸出,只保留SQL
198
- if "```sql" in generated_text:
199
- generated_text = generated_text.split("```sql")[1].split("```")[0].strip()
200
- elif "```" in generated_text:
201
- generated_text = generated_text.split("```")[1].strip() if len(generated_text.split("```")) > 2 else generated_text
202
 
203
  log_messages.append("LLM生成成功!")
204
  return generated_text, "\n".join(log_messages)
205
- else:
206
- raise Exception(f"API返回格式異常: {result}")
207
  else:
208
- raise Exception(f"API錯誤: {response.status_code} - {response.text}")
209
 
210
  except Exception as e:
211
  error_msg = f"LLM API調用失敗: {str(e)}"
212
  log_messages.append(error_msg)
213
 
214
- # 提供備用答案
215
- backup_sql = "SELECT 'AI服務暫時不可用,請稍後重試' AS status;"
216
  return backup_sql, "\n".join(log_messages)
217
 
218
- # --- 4. 創建 Gradio Web 界面 ---
219
- print("--- [5/5] 正在創建 Gradio Web 界面... ---")
220
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
221
- gr.Markdown("# 🚀 智能 Text-to-SQL 系統 (混合模式)")
222
- gr.Markdown("輸入自然語言問題,系統會智能生成SQL查詢")
 
223
 
224
  with gr.Row():
225
  question_input = gr.Textbox(
226
- label="輸入您的問題",
227
- placeholder="例如:查詢去年的銷售總額",
228
- lines=2,
229
- scale=4
230
  )
231
- submit_button = gr.Button("生成SQL", variant="primary", scale=1)
 
 
 
232
 
233
  with gr.Row():
234
  sql_output = gr.Code(
235
- label="生成的 SQL 查詢",
236
  language="sql",
237
- lines=6
238
  )
239
 
240
  with gr.Row():
241
  log_output = gr.Textbox(
242
- label="系統日誌",
243
- lines=4,
244
  interactive=False
245
  )
246
 
247
- submit_button.click(
 
248
  fn=get_sql_query,
249
  inputs=question_input,
250
  outputs=[sql_output, log_output]
251
  )
252
 
 
 
 
 
 
 
 
253
  gr.Examples(
254
  examples=[
255
- "2024年最好的5個客戶以及業績",
256
- "比較2023年跟2024年的業績",
257
- "上週C組完成了幾份報告",
258
- "有沒有快到期的訂單?",
259
- "哪個客戶的付款最不及時?"
260
  ],
261
- inputs=question_input,
262
- label="示例問題"
263
  )
264
 
265
- print("--- 應用準備啟動 ---")
266
  if __name__ == "__main__":
267
- demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
 
2
  import requests
3
  import json
4
  import os
5
+ from datasets import load_dataset
6
  from sentence_transformers import SentenceTransformer, util
7
  import torch
8
  from huggingface_hub import hf_hub_download
 
11
  # --- 配置區 ---
12
  HF_TOKEN = os.environ.get("HF_TOKEN")
13
  DATASET_REPO_ID = "Paul720810/Text-to-SQL-Softline"
14
+ # 使用更可靠且免費的模型
15
+ LLM_API_URL = "https://api-inference.huggingface.co/models/microsoft/DialoGPT-large"
16
+ SIMILARITY_THRESHOLD = 0.85 # 降低閾值以提高檢索命中率
17
 
18
  print("--- [1/5] 開始初始化應用 ---")
19
 
20
  # --- 1. 載入知識庫 ---
 
 
21
  questions = []
22
  sql_answers = []
23
+ schema_data = {}
24
 
25
  try:
26
  print(f"--- [2/5] 正在從 '{DATASET_REPO_ID}' 載入知識庫... ---")
27
  raw_dataset = load_dataset(DATASET_REPO_ID, token=HF_TOKEN)['train']
28
 
29
+ # 解析 messages 格式
30
+ print("--- > 解析 messages 格式...")
31
 
32
  for item in raw_dataset:
33
  try:
 
35
  user_content = item['messages'][0]['content']
36
  assistant_content = item['messages'][1]['content']
37
 
38
+ # 提取問題
39
  question_match = re.search(r'指令:\s*(.*?)(?:\n|$)', user_content)
40
+ question = question_match.group(1).strip() if question_match else user_content
 
 
 
 
41
 
42
+ # 提取SQL
43
  sql_match = re.search(r'SQL查詢:\s*(.*?)(?:\n|$)', assistant_content, re.DOTALL)
44
  if sql_match:
45
  sql_query = sql_match.group(1).strip()
46
+ sql_query = re.sub(r'^sql\s*', '', sql_query)
47
+ sql_query = re.sub(r'```sql|```', '', sql_query).strip()
 
48
  else:
49
  sql_query = assistant_content
50
 
51
  questions.append(question)
52
  sql_answers.append(sql_query)
53
 
54
+ except Exception as e:
55
  print(f"解析錯誤,跳過該條目: {e}")
56
  continue
57
 
58
+ print(f"--- > 成功解析 {len(questions)} 條問答範例 ---")
 
 
 
 
 
 
 
59
 
60
+ # 載入Schema
 
61
  try:
62
+ schema_file_path = hf_hub_download(
63
+ repo_id=DATASET_REPO_ID,
64
+ filename="sqlite_schema_FULL.json",
65
+ repo_type='dataset',
66
+ token=HF_TOKEN
67
+ )
68
  with open(schema_file_path, 'r', encoding='utf-8') as f:
69
  schema_data = json.load(f)
70
  except Exception as e:
71
  print(f"警告: 無法載入Schema文件: {e}")
 
 
 
72
 
73
  except Exception as e:
74
+ print(f"錯誤: 載入數據集失敗: {e}")
 
75
  questions = ["示例問題"]
76
+ sql_answers = ["SELECT '數據庫連接成功' AS status;"]
 
77
 
78
+ # --- 2. 初始化檢索模型 ---
79
+ print("--- [3/5] 正在載入句向量模型... ---")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  embedder = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
81
 
82
  # 計算問題向量
 
85
  question_embeddings = embedder.encode(questions, convert_to_tensor=True, show_progress_bar=True)
86
  print("--- > 向量計算完成! ---")
87
  else:
88
+ print("--- [4/5] 警告:沒有可用的問題 ---")
89
  question_embeddings = torch.Tensor([])
90
 
91
+ # --- 3. 構建DDL ---
92
+ def build_schema_context(schema_dict):
93
+ if not schema_dict:
94
+ return "/* 無Schema信息 */"
95
+
96
+ context = "/* 數據庫表結構 */\n"
97
+ for table_name, columns in schema_dict.items():
98
+ if isinstance(columns, list):
99
+ context += f"\n-- 表: {table_name}\n"
100
+ for col in columns:
101
+ col_name = col.get('name', 'unknown')
102
+ col_type = col.get('type', 'TEXT')
103
+ col_desc = col.get('description', '')
104
+ context += f"-- {col_name} ({col_type}) - {col_desc}\n"
105
+ return context
106
+
107
+ SCHEMA_CONTEXT = build_schema_context(schema_data)
108
+
109
+ # --- 4. 核心邏輯 ---
110
  def get_sql_query(user_question: str):
111
  if not user_question:
112
  return "請輸入您的問題。", "日誌:用戶未輸入問題。"
 
114
  log_messages = []
115
 
116
  # 檢索相似問題
117
+ if len(questions) > 0 and len(question_embeddings) > 0:
118
+ try:
119
+ question_embedding = embedder.encode(user_question, convert_to_tensor=True)
120
+ hits = util.semantic_search(question_embedding, question_embeddings, top_k=3)
 
 
 
 
 
 
121
 
122
+ if hits and hits[0]:
123
+ most_similar_hit = hits[0][0]
124
+ similarity_score = most_similar_hit['score']
125
+ similar_question = questions[most_similar_hit['corpus_id']]
126
+
127
+ log_messages.append(f"檢索到相似問題: '{similar_question}' (相似度: {similarity_score:.3f})")
128
+
129
+ if similarity_score > SIMILARITY_THRESHOLD:
130
+ sql_result = sql_answers[most_similar_hit['corpus_id']]
131
+ log_messages.append(f"相似度 > {SIMILARITY_THRESHOLD},直接返回預先SQL")
132
+ return sql_result, "\n".join(log_messages)
133
+ else:
134
+ log_messages.append(f"相似度低於閾值 {SIMILARITY_THRESHOLD}")
135
+ else:
136
+ log_messages.append("檢索失敗:找不到相似問題")
137
+
138
+ except Exception as e:
139
+ log_messages.append(f"檢索過程出錯: {e}")
140
  else:
141
  log_messages.append("知識庫為空,跳過檢索")
142
 
143
  # LLM生成模式
144
  log_messages.append("進入LLM生成模式...")
145
 
146
+ # 構建提示詞 - 更簡單的版本
147
+ prompt = f"""請根據以下數據庫結構,為這個問題生成SQL查詢:
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
+ {SCHEMA_CONTEXT}
 
150
 
151
+ 問題:{user_question}
152
 
153
+ 請只輸出SQL語句:"""
154
 
155
  log_messages.append("正在請求雲端LLM...")
156
 
 
158
  payload = {
159
  "inputs": prompt,
160
  "parameters": {
161
+ "max_new_tokens": 200,
162
  "temperature": 0.1,
163
  "do_sample": False
164
  }
 
172
  if isinstance(result, list) and len(result) > 0:
173
  generated_text = result[0]['generated_text'].strip()
174
 
175
+ # 簡單清理
176
+ generated_text = re.sub(r'^```sql|```$', '', generated_text).strip()
 
 
 
177
 
178
  log_messages.append("LLM生成成功!")
179
  return generated_text, "\n".join(log_messages)
 
 
180
  else:
181
+ raise Exception(f"API錯誤: {response.status_code}")
182
 
183
  except Exception as e:
184
  error_msg = f"LLM API調用失敗: {str(e)}"
185
  log_messages.append(error_msg)
186
 
187
+ # 提供更有用的備用答案
188
+ backup_sql = "SELECT 'AI服務暫時不可用,請稍後再試或聯繫管理員' AS status;"
189
  return backup_sql, "\n".join(log_messages)
190
 
191
+ # --- 5. 創建界面 ---
192
+ print("--- [5/5] 正在創建 Web 界面... ---")
193
+
194
+ with gr.Blocks(title="智能Text-to-SQL系統") as demo:
195
+ gr.Markdown("# 🤖 智能 Text-to-SQL 系統")
196
+ gr.Markdown("輸入自然語言問題,自動生成SQL查詢")
197
 
198
  with gr.Row():
199
  question_input = gr.Textbox(
200
+ label="您的問題",
201
+ placeholder="例如:查詢去年的銷售數據",
202
+ lines=2
 
203
  )
204
+
205
+ with gr.Row():
206
+ submit_btn = gr.Button("生成SQL", variant="primary")
207
+ clear_btn = gr.Button("清除")
208
 
209
  with gr.Row():
210
  sql_output = gr.Code(
211
+ label="生成的SQL",
212
  language="sql",
213
+ lines=5
214
  )
215
 
216
  with gr.Row():
217
  log_output = gr.Textbox(
218
+ label="執行日誌",
219
+ lines=3,
220
  interactive=False
221
  )
222
 
223
+ # 綁定事件
224
+ submit_btn.click(
225
  fn=get_sql_query,
226
  inputs=question_input,
227
  outputs=[sql_output, log_output]
228
  )
229
 
230
+ clear_btn.click(
231
+ fn=lambda: ["", ""],
232
+ inputs=[],
233
+ outputs=[sql_output, log_output]
234
+ )
235
+
236
+ # 示例
237
  gr.Examples(
238
  examples=[
239
+ "查詢2024年銷售額最高的產品",
240
+ "顯示最近30天的訂單",
241
+ "統計每個客戶的訂單數量",
242
+ "找出庫存不足的商品"
 
243
  ],
244
+ inputs=question_input
 
245
  )
246
 
247
+ print("--- 應用啟動完成 ---")
248
  if __name__ == "__main__":
249
+ demo.launch(server_name="0.0.0.0", server_port=7860)