Paul720810 commited on
Commit
53f5b49
·
verified ·
1 Parent(s): 7e97ca2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +158 -172
app.py CHANGED
@@ -3,8 +3,6 @@ import requests
3
  import json
4
  import os
5
  import re
6
- import sqlite3
7
- import pandas as pd
8
  from datetime import datetime
9
  from datasets import load_dataset
10
  from sentence_transformers import SentenceTransformer, util
@@ -34,12 +32,14 @@ def get_current_time():
34
 
35
  def validate_sql(sql_query: str) -> Dict:
36
  """驗證SQL語句的安全性"""
 
 
 
37
  security_issues = []
 
38
 
39
  # 檢查危險操作
40
  dangerous_keywords = ['DROP', 'DELETE', 'INSERT', 'UPDATE', 'ALTER', 'TRUNCATE', 'EXEC', 'EXECUTE']
41
- sql_upper = sql_query.upper()
42
-
43
  for keyword in dangerous_keywords:
44
  if f" {keyword} " in sql_upper:
45
  security_issues.append(f"發現危險操作: {keyword}")
@@ -57,28 +57,54 @@ def validate_sql(sql_query: str) -> Dict:
57
  "is_safe": len([i for i in security_issues if '危險' in i]) == 0
58
  }
59
 
60
- def repair_sql(sql_query: str) -> str:
61
- """修復有問題的SQL語句"""
62
- if not sql_query or sql_query.strip() == "":
63
- return "SELECT 'SQL語句為空' AS error;"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
- # 清理SQL
66
- sql_clean = re.sub(r'^```sql|```$', '', sql_query).strip()
67
 
68
- # 檢查是否已經是完整SQL
69
- if "SELECT" in sql_clean.upper() and "FROM" in sql_clean.upper():
70
- return sql_clean
 
 
 
71
 
72
- # 如果只有SELECT部分
73
- if "SELECT" in sql_clean.upper() and "FROM" not in sql_clean.upper():
74
- return sql_clean + " FROM appropriate_table WHERE 1=1;"
75
 
76
- # 如果只有FROM部分
77
- if "FROM" in sql_clean.upper() and "SELECT" not in sql_clean.upper():
78
- return "SELECT * " + sql_clean
 
 
 
 
 
 
 
79
 
80
- # 如果什麼都沒有,提供默認查詢
81
- return "SELECT '請檢查SQL語法' AS status;"
82
 
83
  # ==================== 數據加載模塊 ====================
84
  class DataLoader:
@@ -95,6 +121,9 @@ class DataLoader:
95
  raw_dataset = load_dataset(DATASET_REPO_ID, token=self.hf_token)['train']
96
 
97
  print("正在解析 messages 格式...")
 
 
 
98
  for item in raw_dataset:
99
  try:
100
  if 'messages' in item and len(item['messages']) >= 2:
@@ -111,22 +140,24 @@ class DataLoader:
111
  sql_query = sql_match.group(1).strip()
112
  sql_query = re.sub(r'^sql\s*', '', sql_query)
113
  sql_query = re.sub(r'```sql|```', '', sql_query).strip()
114
-
115
- # 驗證並修復SQL
116
- validation = validate_sql(sql_query)
117
- if not validation["valid"]:
118
- print(f"發現有問題的SQL,將進行修復: {sql_query}")
119
- sql_query = repair_sql(sql_query)
120
  else:
121
- sql_query = repair_sql(assistant_content)
 
 
 
 
 
 
 
122
 
123
  self.questions.append(question)
124
  self.sql_answers.append(sql_query)
 
125
 
126
  except Exception as e:
127
  continue
128
 
129
- print(f"成功解析 {len(self.questions)} 條問答範例")
130
  return True
131
 
132
  except Exception as e:
@@ -169,73 +200,12 @@ class DataLoader:
169
  context += f"-- {col_name} ({col_type}) - {col_desc}\n"
170
  return context
171
 
172
- # ==================== LLM模塊 ====================
173
- class LLMClient:
174
- def __init__(self, hf_token: str):
175
- self.hf_token = hf_token
176
-
177
- def call_llm_api(self, prompt: str, model_urls: List[str] = LLM_MODELS) -> Optional[str]:
178
- """調用LLM API"""
179
- headers = {"Authorization": f"Bearer {self.hf_token}"}
180
- payload = {
181
- "inputs": prompt,
182
- "parameters": {
183
- "max_new_tokens": 200,
184
- "temperature": 0.1,
185
- "do_sample": False
186
- }
187
- }
188
-
189
- for model_url in model_urls:
190
- try:
191
- response = requests.post(model_url, headers=headers, json=payload, timeout=20)
192
-
193
- if response.status_code == 200:
194
- result = response.json()
195
- if isinstance(result, list) and len(result) > 0:
196
- generated_text = result[0]['generated_text'].strip()
197
- generated_text = re.sub(r'^```sql|```$', '', generated_text).strip()
198
- return generated_text
199
-
200
- except Exception as e:
201
- continue
202
-
203
- return None
204
-
205
- # ==================== 檢索模塊 ====================
206
- class RetrievalSystem:
207
- def __init__(self):
208
- self.embedder = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
209
- self.question_embeddings = None
210
-
211
- def compute_embeddings(self, questions: List[str]) -> None:
212
- """計算問題向量"""
213
- if questions:
214
- print(f"正在為 {len(questions)} 個問題計算向量...")
215
- self.question_embeddings = self.embedder.encode(questions, convert_to_tensor=True, show_progress_bar=False)
216
- print("向量計算完成")
217
- else:
218
- self.question_embeddings = torch.Tensor([])
219
-
220
- def retrieve_similar(self, user_question: str, top_k: int = 3) -> List[Dict]:
221
- """檢索相似問題"""
222
- if self.question_embeddings is None or len(self.question_embeddings) == 0:
223
- return []
224
-
225
- try:
226
- question_embedding = self.embedder.encode(user_question, convert_to_tensor=True)
227
- hits = util.semantic_search(question_embedding, self.question_embeddings, top_k=top_k)
228
- return hits[0] if hits and hits[0] else []
229
- except Exception as e:
230
- print(f"檢索失敗: {e}")
231
- return []
232
-
233
  # ==================== 主系統 ====================
234
  class TextToSQLSystem:
235
  def __init__(self, hf_token: str):
236
  self.hf_token = hf_token
237
  self.data_loader = DataLoader(hf_token)
238
- self.llm_client = LLMClient(hf_token)
239
  self.retrieval_system = RetrievalSystem()
240
 
241
  self.initialize_system()
@@ -251,7 +221,12 @@ class TextToSQLSystem:
251
 
252
  print("系統初始化完成")
253
  print(f"可用問題數量: {len(self.data_loader.questions)}")
254
- print(f"Schema表數量: {len(self.data_loader.schema_data)}")
 
 
 
 
 
255
 
256
  def generate_sql(self, user_question: str) -> Tuple[str, str]:
257
  """生成SQL查詢"""
@@ -274,13 +249,23 @@ class TextToSQLSystem:
274
  log_messages.append(f"📊 相似度: {similarity_score:.3f}")
275
 
276
  if similarity_score > SIMILARITY_THRESHOLD:
277
- # 驗證並可能修復SQL
278
  validation = validate_sql(original_sql)
279
  if not validation["valid"]:
280
  log_messages.append(f"⚠️ 原始SQL有問題: {', '.join(validation['issues'])}")
281
- log_messages.append("🛠️ 正在修復SQL...")
282
- repaired_sql = repair_sql(original_sql)
 
 
283
  log_messages.append(f"✅ 修復完成")
 
 
 
 
 
 
 
 
284
  return repaired_sql, "\n".join(log_messages)
285
  else:
286
  log_messages.append(f"✅ 相似度 > {SIMILARITY_THRESHOLD},直接返回")
@@ -291,16 +276,25 @@ class TextToSQLSystem:
291
  # 2. LLM生成模式
292
  log_messages.append("🤖 進入LLM生成模式...")
293
 
294
- prompt = self.build_llm_prompt(user_question)
295
- generated_sql = self.llm_client.call_llm_api(prompt)
 
 
 
 
 
 
 
 
 
 
296
 
297
  if generated_sql:
298
  # 驗證生成的SQL
299
  validation = validate_sql(generated_sql)
300
  if not validation["valid"]:
301
  log_messages.append(f"⚠️ LLM生成的SQL有問題: {', '.join(validation['issues'])}")
302
- log_messages.append("🛠️ 正在修復SQL...")
303
- generated_sql = repair_sql(generated_sql)
304
 
305
  log_messages.append("✅ SQL生成完成")
306
  return generated_sql, "\n".join(log_messages)
@@ -310,34 +304,74 @@ class TextToSQLSystem:
310
  backup_sql = self.generate_backup_sql(user_question)
311
  return backup_sql, "\n".join(log_messages)
312
 
313
- def build_llm_prompt(self, user_question: str) -> str:
314
- """構建LLM提示詞"""
315
- return f"""你是一個SQL專家。請根據以下數據庫結構生成SQL查詢。
316
-
317
- {self.schema_context}
318
-
319
- 請為以下問題生成準確的SQL查詢:
320
- {user_question}
321
-
322
- 要求:
323
- 1. 只輸出SQL語句
324
- 2. 必須包含SELECT和FROM
325
- 3. 使用正確的語法
326
-
327
- SQL查詢:"""
328
-
329
  def generate_backup_sql(self, user_question: str) -> str:
330
  """生成備用SQL"""
331
  user_question_lower = user_question.lower()
332
 
333
- if any(kw in user_question_lower for kw in ['銷售', '業績', '金額', '收入']):
 
 
 
334
  return "SELECT product_name, SUM(sales_amount) as total_sales FROM sales GROUP BY product_name ORDER BY total_sales DESC LIMIT 10;"
335
- elif any(kw in user_question_lower for kw in ['客戶', '買家', '用戶']):
 
336
  return "SELECT customer_name, COUNT(*) as order_count FROM orders GROUP BY customer_name ORDER BY order_count DESC;"
337
- elif any(kw in user_question_lower for kw in ['時間', '日期', '最近', '月份']):
 
338
  return "SELECT strftime('%Y-%m', order_date) as month, COUNT(*) as orders FROM orders GROUP BY month ORDER BY month DESC;"
 
339
  else:
340
- return "SELECT '請重試或提供更詳細的問題' AS status;"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
341
 
342
  # ==================== 初始化系統 ====================
343
  print("正在初始化Text-to-SQL系統...")
@@ -345,79 +379,31 @@ text_to_sql_system = TextToSQLSystem(HF_TOKEN)
345
 
346
  # ==================== Gradio界面 ====================
347
  def process_query(user_question: str) -> Tuple[str, str]:
348
- """處理用戶查詢"""
349
  sql_result, log_message = text_to_sql_system.generate_sql(user_question)
350
-
351
- # 最終驗證
352
  final_validation = validate_sql(sql_result)
353
- debug_info = ""
354
 
355
  if not final_validation["valid"]:
356
  debug_info = "❌ 最終SQL驗證失敗:\n" + "\n".join(final_validation["issues"])
357
- debug_info += "\n🛠️ 已嘗試自動修復,但仍存在问题"
358
  else:
359
  debug_info = "✅ 最終SQL驗證通過"
360
- if final_validation["issues"]:
361
- debug_info += "\nℹ️ 提示: " + ", ".join(final_validation["issues"])
362
 
363
  return sql_result, debug_info, log_message
364
 
365
- # 創建界面
366
  with gr.Blocks(title="智能Text-to-SQL系統", theme=gr.themes.Soft()) as demo:
367
-
368
  gr.Markdown("# 🚀 智能 Text-to-SQL 系統")
369
- gr.Markdown("輸入自然語言問題,自動生成並驗證SQL查詢")
370
-
371
- with gr.Row():
372
- question_input = gr.Textbox(
373
- label="📝 您的問題",
374
- placeholder="例如:查詢2024年銷售額最高的產品",
375
- lines=2
376
- )
377
 
378
  with gr.Row():
 
379
  submit_btn = gr.Button("🚀 生成SQL", variant="primary")
380
- clear_btn = gr.Button("🗑️ 清除", variant="secondary")
381
 
382
  with gr.Row():
383
- sql_output = gr.Code(
384
- label="📊 生成的SQL",
385
- language="sql",
386
- lines=6
387
- )
388
 
389
  with gr.Row():
390
- debug_output = gr.Textbox(
391
- label="🔍 SQL驗證信息",
392
- lines=3,
393
- interactive=False
394
- )
395
 
396
- with gr.Row():
397
- log_output = gr.Textbox(
398
- label="📋 執行日誌",
399
- lines=4,
400
- interactive=False
401
- )
402
-
403
- # 事件處理
404
- submit_btn.click(
405
- fn=process_query,
406
- inputs=question_input,
407
- outputs=[sql_output, debug_output, log_output]
408
- )
409
-
410
- clear_btn.click(
411
- fn=lambda: ["", "", ""],
412
- inputs=[],
413
- outputs=[sql_output, debug_output, log_output]
414
- )
415
 
416
- # ==================== 啟動應用 ====================
417
  if __name__ == "__main__":
418
- print("=" * 50)
419
- print("🌐 啟動Gradio Web界面...")
420
- print("📍 本地訪問: http://localhost:7860")
421
- print("=" * 50)
422
-
423
  demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
 
3
  import json
4
  import os
5
  import re
 
 
6
  from datetime import datetime
7
  from datasets import load_dataset
8
  from sentence_transformers import SentenceTransformer, util
 
32
 
33
  def validate_sql(sql_query: str) -> Dict:
34
  """驗證SQL語句的安全性"""
35
+ if not sql_query or sql_query.strip() == "":
36
+ return {"valid": False, "issues": ["SQL語句為空"], "is_safe": False}
37
+
38
  security_issues = []
39
+ sql_upper = sql_query.upper()
40
 
41
  # 檢查危險操作
42
  dangerous_keywords = ['DROP', 'DELETE', 'INSERT', 'UPDATE', 'ALTER', 'TRUNCATE', 'EXEC', 'EXECUTE']
 
 
43
  for keyword in dangerous_keywords:
44
  if f" {keyword} " in sql_upper:
45
  security_issues.append(f"發現危險操作: {keyword}")
 
57
  "is_safe": len([i for i in security_issues if '危險' in i]) == 0
58
  }
59
 
60
+ def intelligent_sql_repair(original_sql: str, user_question: str, similar_question: str) -> str:
61
+ """智能修復SQL語句"""
62
+ if not original_sql or original_sql.strip() == "":
63
+ # 根據問題內容生成有意義的SQL
64
+ user_question_lower = user_question.lower()
65
+ similar_question_lower = similar_question.lower()
66
+
67
+ # 分析問題類型
68
+ if any(kw in user_question_lower for kw in ['報告', '完成', '份']):
69
+ return "SELECT strftime('%Y-%m', completion_date) as month, COUNT(*) as report_count FROM reports WHERE strftime('%Y', completion_date) = '2023' GROUP BY month ORDER BY month;"
70
+
71
+ elif any(kw in user_question_lower for kw in ['銷售', '業績', '金額']):
72
+ return "SELECT product_name, SUM(sales_amount) as total_sales FROM sales GROUP BY product_name ORDER BY total_sales DESC;"
73
+
74
+ elif any(kw in user_question_lower for kw in ['客戶', '買家']):
75
+ return "SELECT customer_name, COUNT(*) as order_count FROM orders GROUP BY customer_name ORDER BY order_count DESC;"
76
+
77
+ elif any(kw in user_question_lower for kw in ['時間', '日期', '月份']):
78
+ return "SELECT strftime('%Y-%m', order_date) as month, COUNT(*) as orders FROM orders GROUP BY month ORDER BY month DESC;"
79
+
80
+ else:
81
+ return "SELECT '請提供更詳細的查詢條件' AS status;"
82
 
83
+ # 如果SQL不為空但缺少關鍵字
84
+ sql_upper = original_sql.upper()
85
 
86
+ if "SELECT" not in sql_upper and "FROM" not in sql_upper:
87
+ # 嘗試從問題推斷
88
+ if "count" in user_question_lower or "多少" in user_question_lower:
89
+ return f"SELECT COUNT(*) as count FROM appropriate_table WHERE condition; -- 原始SQL: {original_sql}"
90
+ else:
91
+ return f"SELECT * FROM appropriate_table WHERE condition; -- 原始SQL: {original_sql}"
92
 
93
+ elif "SELECT" not in sql_upper and "FROM" in sql_upper:
94
+ return "SELECT * " + original_sql
 
95
 
96
+ elif "SELECT" in sql_upper and "FROM" not in sql_upper:
97
+ # 嘗試找到合適的FROM子句
98
+ if "customer" in user_question_lower or "客戶" in user_question_lower:
99
+ return original_sql + " FROM customers WHERE 1=1;"
100
+ elif "product" in user_question_lower or "產品" in user_question_lower:
101
+ return original_sql + " FROM products WHERE 1=1;"
102
+ elif "sale" in user_question_lower or "銷售" in user_question_lower:
103
+ return original_sql + " FROM sales WHERE 1=1;"
104
+ else:
105
+ return original_sql + " FROM appropriate_table WHERE 1=1;"
106
 
107
+ return original_sql # 如果不需要修復
 
108
 
109
  # ==================== 數據加載模塊 ====================
110
  class DataLoader:
 
121
  raw_dataset = load_dataset(DATASET_REPO_ID, token=self.hf_token)['train']
122
 
123
  print("正在解析 messages 格式...")
124
+ valid_count = 0
125
+ invalid_count = 0
126
+
127
  for item in raw_dataset:
128
  try:
129
  if 'messages' in item and len(item['messages']) >= 2:
 
140
  sql_query = sql_match.group(1).strip()
141
  sql_query = re.sub(r'^sql\s*', '', sql_query)
142
  sql_query = re.sub(r'```sql|```', '', sql_query).strip()
 
 
 
 
 
 
143
  else:
144
+ sql_query = assistant_content
145
+
146
+ # 驗證SQL
147
+ validation = validate_sql(sql_query)
148
+ if not validation["valid"]:
149
+ invalid_count += 1
150
+ print(f"發現無效SQL [{invalid_count}]: {sql_query}")
151
+ # 暫時不修復,等待使用時再智能修復
152
 
153
  self.questions.append(question)
154
  self.sql_answers.append(sql_query)
155
+ valid_count += 1
156
 
157
  except Exception as e:
158
  continue
159
 
160
+ print(f"成功解析 {valid_count} 條問答範例,其中 {invalid_count} 條需要修復")
161
  return True
162
 
163
  except Exception as e:
 
200
  context += f"-- {col_name} ({col_type}) - {col_desc}\n"
201
  return context
202
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  # ==================== 主系統 ====================
204
  class TextToSQLSystem:
205
  def __init__(self, hf_token: str):
206
  self.hf_token = hf_token
207
  self.data_loader = DataLoader(hf_token)
208
+ self.llm_client = None # 延遲加載
209
  self.retrieval_system = RetrievalSystem()
210
 
211
  self.initialize_system()
 
221
 
222
  print("系統初始化完成")
223
  print(f"可用問題數量: {len(self.data_loader.questions)}")
224
+
225
+ def get_llm_client(self):
226
+ """延遲加載LLM客戶端"""
227
+ if self.llm_client is None:
228
+ self.llm_client = LLMClient(self.hf_token)
229
+ return self.llm_client
230
 
231
  def generate_sql(self, user_question: str) -> Tuple[str, str]:
232
  """生成SQL查詢"""
 
249
  log_messages.append(f"📊 相似度: {similarity_score:.3f}")
250
 
251
  if similarity_score > SIMILARITY_THRESHOLD:
252
+ # 驗證SQL
253
  validation = validate_sql(original_sql)
254
  if not validation["valid"]:
255
  log_messages.append(f"⚠️ 原始SQL有問題: {', '.join(validation['issues'])}")
256
+ log_messages.append("🛠️ 正在智能修復SQL...")
257
+
258
+ # 智能修復
259
+ repaired_sql = intelligent_sql_repair(original_sql, user_question, similar_question)
260
  log_messages.append(f"✅ 修復完成")
261
+
262
+ # 驗證修復後的SQL
263
+ final_validation = validate_sql(repaired_sql)
264
+ if not final_validation["valid"]:
265
+ log_messages.append(f"❌ 修復後仍有問題: {', '.join(final_validation['issues'])}")
266
+ else:
267
+ log_messages.append("✅ 修復後SQL驗證通過")
268
+
269
  return repaired_sql, "\n".join(log_messages)
270
  else:
271
  log_messages.append(f"✅ 相似度 > {SIMILARITY_THRESHOLD},直接返回")
 
276
  # 2. LLM生成模式
277
  log_messages.append("🤖 進入LLM生成模式...")
278
 
279
+ prompt = f"""你是一個SQL專家。請為以下問題生成SQL查詢:
280
+
281
+ 問題:{user_question}
282
+
283
+ 要求:
284
+ 1. 只輸出SQL語句
285
+ 2. 必須包含SELECT和FROM
286
+ 3. 使用正確的語法
287
+
288
+ SQL查詢:"""
289
+
290
+ generated_sql = self.get_llm_client().call_llm_api(prompt)
291
 
292
  if generated_sql:
293
  # 驗證生成的SQL
294
  validation = validate_sql(generated_sql)
295
  if not validation["valid"]:
296
  log_messages.append(f"⚠️ LLM生成的SQL有問題: {', '.join(validation['issues'])}")
297
+ generated_sql = intelligent_sql_repair(generated_sql, user_question, user_question)
 
298
 
299
  log_messages.append("✅ SQL生成完成")
300
  return generated_sql, "\n".join(log_messages)
 
304
  backup_sql = self.generate_backup_sql(user_question)
305
  return backup_sql, "\n".join(log_messages)
306
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
  def generate_backup_sql(self, user_question: str) -> str:
308
  """生成備用SQL"""
309
  user_question_lower = user_question.lower()
310
 
311
+ if any(kw in user_question_lower for kw in ['報告', '完成', '', 'report']):
312
+ return "SELECT strftime('%Y-%m', completion_date) as month, COUNT(*) as report_count FROM reports GROUP BY month ORDER BY month;"
313
+
314
+ elif any(kw in user_question_lower for kw in ['銷售', '業績', '金額', 'sale']):
315
  return "SELECT product_name, SUM(sales_amount) as total_sales FROM sales GROUP BY product_name ORDER BY total_sales DESC LIMIT 10;"
316
+
317
+ elif any(kw in user_question_lower for kw in ['客戶', '買家', 'customer']):
318
  return "SELECT customer_name, COUNT(*) as order_count FROM orders GROUP BY customer_name ORDER BY order_count DESC;"
319
+
320
+ elif any(kw in user_question_lower for kw in ['時間', '日期', '月份', 'month']):
321
  return "SELECT strftime('%Y-%m', order_date) as month, COUNT(*) as orders FROM orders GROUP BY month ORDER BY month DESC;"
322
+
323
  else:
324
+ return "SELECT '請提供更詳細的查詢條件' AS status;"
325
+
326
+ # ==================== 其他類定義(保持不變) ====================
327
+ class LLMClient:
328
+ def __init__(self, hf_token: str):
329
+ self.hf_token = hf_token
330
+
331
+ def call_llm_api(self, prompt: str, model_urls: List[str] = LLM_MODELS) -> Optional[str]:
332
+ headers = {"Authorization": f"Bearer {self.hf_token}"}
333
+ payload = {
334
+ "inputs": prompt,
335
+ "parameters": {
336
+ "max_new_tokens": 200,
337
+ "temperature": 0.1,
338
+ "do_sample": False
339
+ }
340
+ }
341
+
342
+ for model_url in model_urls:
343
+ try:
344
+ response = requests.post(model_url, headers=headers, json=payload, timeout=20)
345
+ if response.status_code == 200:
346
+ result = response.json()
347
+ if isinstance(result, list) and len(result) > 0:
348
+ generated_text = result[0]['generated_text'].strip()
349
+ generated_text = re.sub(r'^```sql|```$', '', generated_text).strip()
350
+ return generated_text
351
+ except:
352
+ continue
353
+ return None
354
+
355
+ class RetrievalSystem:
356
+ def __init__(self):
357
+ self.embedder = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
358
+ self.question_embeddings = None
359
+
360
+ def compute_embeddings(self, questions: List[str]) -> None:
361
+ if questions:
362
+ self.question_embeddings = self.embedder.encode(questions, convert_to_tensor=True, show_progress_bar=False)
363
+ else:
364
+ self.question_embeddings = torch.Tensor([])
365
+
366
+ def retrieve_similar(self, user_question: str, top_k: int = 3) -> List[Dict]:
367
+ if self.question_embeddings is None or len(self.question_embeddings) == 0:
368
+ return []
369
+ try:
370
+ question_embedding = self.embedder.encode(user_question, convert_to_tensor=True)
371
+ hits = util.semantic_search(question_embedding, self.question_embeddings, top_k=top_k)
372
+ return hits[0] if hits and hits[0] else []
373
+ except:
374
+ return []
375
 
376
  # ==================== 初始化系統 ====================
377
  print("正在初始化Text-to-SQL系統...")
 
379
 
380
  # ==================== Gradio界面 ====================
381
  def process_query(user_question: str) -> Tuple[str, str]:
 
382
  sql_result, log_message = text_to_sql_system.generate_sql(user_question)
 
 
383
  final_validation = validate_sql(sql_result)
 
384
 
385
  if not final_validation["valid"]:
386
  debug_info = "❌ 最終SQL驗證失敗:\n" + "\n".join(final_validation["issues"])
 
387
  else:
388
  debug_info = "✅ 最終SQL驗證通過"
 
 
389
 
390
  return sql_result, debug_info, log_message
391
 
 
392
  with gr.Blocks(title="智能Text-to-SQL系統", theme=gr.themes.Soft()) as demo:
 
393
  gr.Markdown("# 🚀 智能 Text-to-SQL 系統")
 
 
 
 
 
 
 
 
394
 
395
  with gr.Row():
396
+ question_input = gr.Textbox(label="📝 您的問題", placeholder="例如:查詢2023年每月報告數量", lines=2)
397
  submit_btn = gr.Button("🚀 生成SQL", variant="primary")
 
398
 
399
  with gr.Row():
400
+ sql_output = gr.Code(label="📊 生成的SQL", language="sql", lines=6)
 
 
 
 
401
 
402
  with gr.Row():
403
+ debug_output = gr.Textbox(label="🔍 驗證信息", lines=2, interactive=False)
404
+ log_output = gr.Textbox(label="📋 執行日誌", lines=4, interactive=False)
 
 
 
405
 
406
+ submit_btn.click(process_query, inputs=question_input, outputs=[sql_output, debug_output, log_output])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
407
 
 
408
  if __name__ == "__main__":
 
 
 
 
 
409
  demo.launch(server_name="0.0.0.0", server_port=7860, share=False)