Paul720810 committed on
Commit
7431cad
·
verified ·
1 Parent(s): 5fc665e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +137 -217
app.py CHANGED
@@ -14,14 +14,7 @@ import numpy as np
14
  # ==================== 配置區 ====================
15
  HF_TOKEN = os.environ.get("HF_TOKEN", None) # 建議從環境變數讀取
16
  DATASET_REPO_ID = "Paul720810/Text-to-SQL-Softline"
17
- SIMILARITY_THRESHOLD = 0.6
18
-
19
- # 多個備用LLM模型 (注意:在當前邏輯中並未使用)
20
- LLM_MODELS = [
21
- "https://api-inference.huggingface.co/models/gpt2",
22
- "https://api-inference.huggingface.co/models/distilgpt2",
23
- "https://api-inference.huggingface.co/models/microsoft/DialoGPT-small"
24
- ]
25
 
26
  print("=" * 60)
27
  print("🤖 智能 Text-to-SQL 系統啟動中...")
@@ -45,67 +38,48 @@ def validate_sql(sql_query: str) -> Dict:
45
  security_issues = []
46
  sql_upper = sql_clean.upper()
47
 
48
- # 檢查危險操作
49
  dangerous_keywords = ['DROP', 'DELETE', 'INSERT', 'UPDATE', 'ALTER', 'TRUNCATE', 'EXEC', 'EXECUTE']
50
  for keyword in dangerous_keywords:
51
  if f" {keyword} " in f" {sql_upper} ":
52
  security_issues.append(f"危險操作: {keyword}")
53
 
54
- # 檢查基本語法
55
  if "SELECT" not in sql_upper:
56
  security_issues.append("缺少SELECT")
57
-
58
  if "FROM" not in sql_upper:
59
  security_issues.append("缺少FROM")
60
 
61
  is_valid = not security_issues
62
  is_safe = all('危險' not in issue for issue in security_issues)
63
 
64
- return {
65
- "valid": is_valid,
66
- "issues": security_issues,
67
- "is_safe": is_safe,
68
- "empty": False
69
- }
70
 
71
  def analyze_question_type(question: str) -> Dict:
72
- """分析問題類型和關鍵詞"""
73
  question_lower = question.lower()
74
 
75
  analysis = {
76
  "type": "unknown",
77
  "keywords": [],
78
- "has_count": False,
79
- "has_date": False,
80
- "has_group": False,
81
- "has_comparison": False
82
  }
83
 
84
- # 檢測關鍵詞
85
- keywords_sets = {
86
- "sales": ["銷售", "業績", "金額", "收入", "sale", "revenue"],
87
- "customer": ["客戶", "買家", "用戶", "customer", "client"],
88
- "product": ["產品", "商品", "項目", "product", "item"],
89
- "time": ["時間", "日期", "月份", "年", "月", "最近", "date", "month", "year"],
90
- "report": ["報告", "完成", "", "report", "complete"],
91
- "count": ["多少", "幾個", "數量", "count", "how many"],
92
- "comparison": ["比較", "vs", " versus", "對比", "相比"]
93
- }
94
-
95
- for category, keywords in keywords_sets.items():
96
- if any(keyword in question_lower for keyword in keywords):
97
- analysis["keywords"].append(category)
98
-
99
- # 特殊檢測
100
- analysis["has_count"] = "count" in analysis["keywords"]
101
- analysis["has_date"] = "time" in analysis["keywords"]
102
- analysis["has_group"] = any(word in question_lower for word in ["每", "各", "group", "每個"])
103
- analysis["has_comparison"] = "comparison" in analysis["keywords"]
104
-
105
- # 確定主要類型
106
- if analysis["keywords"]:
107
- analysis["type"] = analysis["keywords"][0]
108
-
109
  return analysis
110
 
111
  # ==================== 完整數據加載模塊 ====================
@@ -114,77 +88,47 @@ class CompleteDataLoader:
114
  self.hf_token = hf_token
115
  self.questions = []
116
  self.sql_answers = []
117
- self.sql_quality = [] # 記錄每個SQL的質量評分
118
  self.schema_data = {}
119
 
120
  def load_complete_dataset(self) -> bool:
121
- """加載完整數據集(包括空白SQL)"""
122
  try:
123
  print(f"[{get_current_time()}] 正在加載完整數據集 '{DATASET_REPO_ID}'...")
124
  raw_dataset = load_dataset(DATASET_REPO_ID, token=self.hf_token)['train']
125
 
126
- print("解析全部 messages 格式...")
127
- total_count, empty_count, valid_count = 0, 0, 0
128
-
129
  for item in raw_dataset:
130
  try:
131
  if 'messages' in item and len(item['messages']) >= 2:
132
  user_content = item['messages'][0]['content']
133
  assistant_content = item['messages'][1]['content']
134
 
135
- # 提取問題
136
  question_match = re.search(r'指令:\s*(.*?)(?:\n|$)', user_content)
137
  question = question_match.group(1).strip() if question_match else user_content
138
 
139
- # 提取SQL
140
  sql_match = re.search(r'SQL查詢:\s*(.*?)(?:\n|$)', assistant_content, re.DOTALL)
141
- if sql_match:
142
- sql_query = sql_match.group(1).strip()
143
- sql_query = re.sub(r'^sql\s*', '', sql_query, flags=re.IGNORECASE)
144
- sql_query = re.sub(r'```sql|```', '', sql_query).strip()
145
- else:
146
- sql_query = assistant_content
147
-
148
- # 保存所有數據
149
  self.questions.append(question)
150
  self.sql_answers.append(sql_query)
151
-
152
- # 評估SQL質量
153
- validation = validate_sql(sql_query)
154
- quality_score = 1.0 if validation["valid"] else 0.3
155
- self.sql_quality.append(quality_score)
156
-
157
- total_count += 1
158
- if validation["empty"]:
159
- empty_count += 1
160
- if validation["valid"]:
161
- valid_count += 1
162
  except Exception:
163
  continue
164
 
165
- print(f"數據加載完成: 總數 {total_count}, 有效 {valid_count}, 空白 {empty_count}")
166
  return True
167
-
168
  except Exception as e:
169
  print(f"數據集加載失敗: {e}")
170
  return False
171
 
172
  def load_schema(self) -> bool:
173
- """加載數據庫Schema"""
174
  try:
175
- schema_file_path = hf_hub_download(
176
- repo_id=DATASET_REPO_ID,
177
- filename="sqlite_schema_FULL.json",
178
- repo_type='dataset',
179
- token=self.hf_token
180
- )
181
  with open(schema_file_path, 'r', encoding='utf-8') as f:
182
  self.schema_data = json.load(f)
183
  print("Schema加載成功")
184
  return True
185
  except Exception as e:
186
  print(f"Schema加載失敗: {e}")
187
- self.schema_data = {}
188
  return False
189
 
190
  # ==================== 檢索系統 ====================
@@ -197,19 +141,18 @@ class RetrievalSystem:
197
  print(f"SentenceTransformer 模型加載失敗: {e}")
198
  self.embedder = None
199
 
200
- def compute_embeddings(self, questions: List[str]) -> None:
201
  if self.embedder and questions:
202
  print(f"正在為 {len(questions)} 個問題計算向量...")
203
  self.question_embeddings = self.embedder.encode(questions, convert_to_tensor=True, show_progress_bar=True)
204
  print("向量計算完成")
205
 
206
- def retrieve_similar(self, user_question: str, top_k: int = 5) -> List[Dict]:
207
- if self.embedder is None or self.question_embeddings is None or len(self.question_embeddings) == 0:
208
- return []
209
  try:
210
  question_embedding = self.embedder.encode(user_question, convert_to_tensor=True)
211
  hits = util.semantic_search(question_embedding, self.question_embeddings, top_k=top_k)
212
- return hits[0] if hits and hits[0] else []
213
  except Exception as e:
214
  print(f"檢索錯誤: {e}")
215
  return []
@@ -223,138 +166,131 @@ class CompleteTextToSQLSystem:
223
  self.initialize_system()
224
 
225
  def initialize_system(self):
226
- """初始化系統組件"""
227
  print("正在初始化完整數據系統...")
228
-
229
  self.data_loader.load_complete_dataset()
230
  self.data_loader.load_schema()
231
-
232
- # 為所有問題計算向量
233
  if self.data_loader.questions:
234
  self.retrieval_system.compute_embeddings(self.data_loader.questions)
235
-
236
  print(f"系統初始化完成,載入問題總數: {len(self.data_loader.questions)}")
237
 
238
- # ===== 輔助函數 (作為類別方法) =====
239
- def get_available_tables(self) -> Dict:
240
- """從schema中獲取所有可用的表和欄位"""
241
- if not self.data_loader.schema_data:
242
- return {}
243
-
244
- tables = {}
245
- for table_name, columns_list in self.data_loader.schema_data.items():
246
- if isinstance(columns_list, list):
247
- column_names = [col["name"] for col in columns_list if "name" in col]
248
- tables[table_name] = column_names
249
-
250
- return tables
251
-
252
- def extract_number(self, text: str, default: int = 10) -> int:
253
- """從文字中提取數字"""
254
- numbers = re.findall(r'\d+', text)
255
- return int(numbers[0]) if numbers else default
256
 
257
  def generate_sql_from_question(self, question: str, analysis: Dict) -> str:
258
- """根據問題分析和真實Schema生成智能SQL"""
259
- question_lower = question.lower()
260
- available_tables = self.get_available_tables().keys()
261
-
262
- # 1. 每月/每日完成數量 - 使用 JobTimeline 相關表
263
- if any(kw in question_lower for kw in ["每月", "每日", "昨天", "完成"]) and analysis["has_count"]:
264
- group_match = re.search(r'([a-z]組)', question_lower)
265
- if group_match:
266
- group = group_match.group(1).replace('組', '').upper()
267
- group_mapping = {'A': 'TA', 'B': 'TB', 'C': 'TC', 'D': 'TD'}
268
- table_suffix = group_mapping.get(group, 'TA')
269
- table_name = f"JobTimeline_{table_suffix}"
270
-
271
- if "昨天" in question_lower:
272
- return f"SELECT COUNT(*) as 完成數量 FROM {table_name} WHERE DATE(end_time) = DATE('now','-1 day');"
273
- elif "每月" in question_lower:
274
- year_match = re.search(r'(\d{4})年?', question_lower)
275
- year = year_match.group(1) if year_match else datetime.now().strftime('%Y')
276
- return f"""SELECT strftime('%Y-%m', end_time) as 月份, COUNT(*) as 完成數量 FROM {table_name} WHERE strftime('%Y', end_time) = '{year}' AND end_time IS NOT NULL GROUP BY strftime('%Y-%m', end_time) ORDER BY 月份;"""
277
- return "SELECT strftime('%Y-%m', jt.end_time) as 月份, COUNT(*) as 完成數量 FROM JobTimeline jt WHERE jt.end_time IS NOT NULL GROUP BY strftime('%Y-%m', jt.end_time) ORDER BY 月份;"
278
 
279
- # 2. 評級分析 - 使用 TSR53SampleDescription.OverallRating
280
- elif any(kw in question_lower for kw in ["評級", "rating", "等級"]) and "TSR53SampleDescription" in available_tables:
281
- if any(kw in question_lower for kw in ["分佈", "統計", "多少"]):
282
- return "SELECT OverallRating as 評級, COUNT(*) as 數量, ROUND(COUNT(*) * 100.0 / (SELECT COUNT(*) FROM TSR53SampleDescription), 2) as 百分比 FROM TSR53SampleDescription WHERE OverallRating IS NOT NULL GROUP BY OverallRating ORDER BY 數量 DESC;"
283
- elif "fail" in question_lower or "失敗" in question_lower:
284
- return "SELECT JobNo as 工作單號, ApplicantName as 申請方, OverallRating as 評級 FROM TSR53SampleDescription WHERE OverallRating = 'Fail' ORDER BY JobNo;"
285
-
286
- # 3. 金額相關查詢 - 使用 TSR53Invoice
287
- elif any(kw in question_lower for kw in ["金額", "總額", "收入", "invoice"]) and any(kw in question_lower for kw in ["最高", "最大", "top"]):
288
- limit_num = self.extract_number(question_lower, default=10)
289
- return f"""WITH JobTotalAmount AS (SELECT JobNo, SUM(LocalAmount) AS TotalAmount FROM (SELECT DISTINCT JobNo, InvoiceCreditNoteNo, LocalAmount FROM TSR53Invoice WHERE LocalAmount IS NOT NULL) GROUP BY JobNo) SELECT jta.JobNo as 工作單號, sd.ApplicantName as 申請方, jta.TotalAmount as 總金額 FROM JobTotalAmount jta JOIN TSR53SampleDescription sd ON sd.JobNo = jta.JobNo ORDER BY jta.TotalAmount DESC LIMIT {limit_num};"""
290
-
291
- # 4. 公司/客戶相關查詢
292
- elif any(kw in question_lower for kw in ["公司", "客戶", "申請方", "付款方"]):
293
- if any(kw in question_lower for kw in ["最多", "top", "排名"]):
294
- return "SELECT ApplicantName as 申請方名稱, COUNT(*) as 工作單數量 FROM TSR53SampleDescription WHERE ApplicantName IS NOT NULL GROUP BY ApplicantName ORDER BY 工作單數量 DESC LIMIT 10;"
295
- return "SELECT ApplicantName as 申請方, InvoiceToName as 付款方, COUNT(*) as 工作單數量 FROM TSR53SampleDescription WHERE ApplicantName IS NOT NULL GROUP BY ApplicantName, InvoiceToName ORDER BY 工作單數量 DESC;"
296
 
297
- # ... 其他規則可以繼續添加 ...
 
 
 
 
 
 
 
 
 
298
 
299
- # 預設查詢 - 顯示基本工作單資訊
300
- return "SELECT JobNo as 工作單號, ApplicantName as 申請方, InvoiceToName as 付款方, OverallRating as 評級 FROM TSR53SampleDescription LIMIT 20;"
 
 
 
 
 
 
 
 
 
 
 
 
 
301
 
302
- def repair_empty_sql(self, original_sql: str, user_question: str, similar_question: str) -> str:
303
- """修復空白或無效的SQL"""
304
- validation = validate_sql(original_sql)
305
- if not validation["valid"]:
306
- analysis = analyze_question_type(user_question)
307
- repaired_sql = self.generate_sql_from_question(user_question, analysis)
308
- return f"-- 根據類似問題 '{similar_question}' (原SQL無效) 自動生成的查詢\n{repaired_sql}"
309
- return original_sql
 
 
 
 
 
310
 
311
  def generate_sql(self, user_question: str) -> Tuple[str, str]:
312
- """主流程:生成SQL查詢"""
313
  log_messages = [f"⏰ {get_current_time()} 開始處理"]
314
 
315
  if not user_question or not user_question.strip():
316
  return "請輸入您的問題。", "錯誤: 問題為空"
317
 
318
  # 1. 檢索最相似的問題
319
- if self.data_loader.questions:
320
- hits = self.retrieval_system.retrieve_similar(user_question)
 
 
 
321
 
322
- if hits:
323
- best_hit = hits[0]
324
- similarity_score = best_hit['score']
325
  corpus_id = best_hit['corpus_id']
326
  similar_question = self.data_loader.questions[corpus_id]
327
  original_sql = self.data_loader.sql_answers[corpus_id]
328
 
329
- log_messages.append(f"🔍 檢索到最相似問題: '{similar_question}'")
330
- log_messages.append(f"📊 相似度: {similarity_score:.3f}")
331
-
332
- if similarity_score > SIMILARITY_THRESHOLD:
333
- repaired_sql = self.repair_empty_sql(original_sql, user_question, similar_question)
334
- log_messages.append(f"✅ 相似度高於閾值 {SIMILARITY_THRESHOLD},採用檢索結果。")
335
- return repaired_sql, "\n".join(log_messages)
336
  else:
337
- log_messages.append(f"ℹ️ 相似度低於閾值 {SIMILARITY_THRESHOLD},轉為智能生成。")
338
-
339
- # 2. 如果檢索失敗或相似度不足,智能生成SQL
340
- log_messages.append("🤖 找不到高相似度結果,啟用智能生成規則...")
 
 
341
  analysis = analyze_question_type(user_question)
342
- intelligent_sql = self.generate_sql_from_question(user_question, analysis)
 
343
 
344
- log_messages.append(f"📋 問題分析: {analysis['type']} 類型, 關鍵詞: {analysis['keywords']}")
345
  log_messages.append("✅ 智能生成完成。")
346
-
347
  return intelligent_sql, "\n".join(log_messages)
348
 
349
  # ==================== 初始化系統 ====================
350
- print("準備初始化 Text-to-SQL 系統...")
351
- # 檢查 HF_TOKEN 是否存在
352
  if HF_TOKEN is None:
353
- print("\n" + "="*60)
354
- print("⚠️ 警告: Hugging Face Token 未設置。")
355
- print("請在環境變數中設定 HF_TOKEN 才能從私人數據集下載資料。")
356
- print("="*60 + "\n")
357
- # 這裡可以選擇退出或繼續,但下載會失敗
358
  text_to_sql_system = None
359
  else:
360
  text_to_sql_system = CompleteTextToSQLSystem(HF_TOKEN)
@@ -366,51 +302,35 @@ def process_query(user_question: str) -> Tuple[str, str, str]:
366
  return "系統未初始化", error_msg, error_msg
367
 
368
  sql_result, log_message = text_to_sql_system.generate_sql(user_question)
369
- return sql_result, "✅ SQL生成完成", log_message
370
 
371
  with gr.Blocks(title="智慧Text-to-SQL系統", theme=gr.themes.Soft()) as demo:
372
- gr.Markdown("# 🚀 智慧 Text-to-SQL 系統")
373
- gr.Markdown("📊 **模式**: 讀取雲端數據集並結合「檢索」與「規則生成」兩種模式。")
374
 
375
  with gr.Row():
376
- question_input = gr.Textbox(
377
- label="📝 請在此輸入您的問題",
378
- placeholder="例如:2023年每月完成多少份報告? 或 哪個客戶的訂單總金額最高?",
379
- lines=3,
380
- scale=4
381
- )
382
  submit_btn = gr.Button("🚀 生成SQL", variant="primary", scale=1)
383
 
384
  with gr.Accordion("🔍 結果與日誌", open=True):
385
- sql_output = gr.Code(
386
- label="📊 生成的SQL查詢",
387
- language="sql",
388
- lines=8
389
- )
390
  status_output = gr.Textbox(label="🔍 執行狀態", interactive=False)
391
- log_output = gr.Textbox(label="📋 詳細日誌", lines=5, interactive=False)
392
 
393
- # 預設範例
394
  gr.Examples(
395
  examples=[
396
- "昨天完成了多少個工作單?",
397
- "A組每月完成數量是多少?",
398
- "哪個申請方的失敗評級最多?",
399
- "找出總金額最高的10筆訂單",
400
- "統計所有評級的分佈"
401
  ],
402
  inputs=question_input
403
  )
404
 
405
- submit_btn.click(
406
- process_query,
407
- inputs=question_input,
408
- outputs=[sql_output, status_output, log_output]
409
- )
410
-
411
  if __name__ == "__main__":
412
- print("Gradio 介面啟動中...")
413
- if text_to_sql_system is None:
414
- print("無法啟動 Gradio,因為系統初始化失敗。")
415
- else:
416
  demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
 
 
 
 
14
  # ==================== 配置區 ====================
15
  HF_TOKEN = os.environ.get("HF_TOKEN", None) # 建議從環境變數讀取
16
  DATASET_REPO_ID = "Paul720810/Text-to-SQL-Softline"
17
+ SIMILARITY_THRESHOLD = 0.65 # 適度提高閾值,確保檢索到的問題意圖更一致
 
 
 
 
 
 
 
18
 
19
  print("=" * 60)
20
  print("🤖 智能 Text-to-SQL 系統啟動中...")
 
38
  security_issues = []
39
  sql_upper = sql_clean.upper()
40
 
 
41
  dangerous_keywords = ['DROP', 'DELETE', 'INSERT', 'UPDATE', 'ALTER', 'TRUNCATE', 'EXEC', 'EXECUTE']
42
  for keyword in dangerous_keywords:
43
  if f" {keyword} " in f" {sql_upper} ":
44
  security_issues.append(f"危險操作: {keyword}")
45
 
 
46
  if "SELECT" not in sql_upper:
47
  security_issues.append("缺少SELECT")
 
48
  if "FROM" not in sql_upper:
49
  security_issues.append("缺少FROM")
50
 
51
  is_valid = not security_issues
52
  is_safe = all('危險' not in issue for issue in security_issues)
53
 
54
+ return {"valid": is_valid, "issues": security_issues, "is_safe": is_safe, "empty": False}
 
 
 
 
 
55
 
56
  def analyze_question_type(question: str) -> Dict:
57
+ """增強的問題分析 - 更精確的意圖識別"""
58
  question_lower = question.lower()
59
 
60
  analysis = {
61
  "type": "unknown",
62
  "keywords": [],
63
+ "has_count": "多少" in question_lower or "幾個" in question_lower or "數量" in question_lower,
64
+ "has_date": "時間" in question_lower or "日期" in question_lower or "月份" in question_lower or "年" in question_lower,
65
+ "has_group": "每" in question_lower or "各" in question_lower or "分組" in question_lower,
66
+ "specific_intent": "general_query" # 新增:具體意圖,預設為通用查詢
67
  }
68
 
69
+ # **更精確的意圖識別**
70
+ if "每月" in question_lower and ("完成" in question_lower or "報告" in question_lower or "工作單" in question_lower):
71
+ analysis["specific_intent"] = "monthly_completion_count"
72
+ analysis["type"] = "time_series"
73
+ elif ("評級" in question_lower or "pass" in question_lower or "fail" in question_lower) and ("統計" in question_lower or "分佈" in question_lower or "多少" in question_lower):
74
+ analysis["specific_intent"] = "rating_distribution"
75
+ analysis["type"] = "statistics"
76
+ elif "金額" in question_lower and ("最高" in question_lower or "top" in question_lower or "排名" in question_lower):
77
+ analysis["specific_intent"] = "amount_ranking"
78
+ analysis["type"] = "ranking"
79
+ elif ("公司" in question_lower or "客戶" in question_lower or "申請方" in question_lower) and ("統計" in question_lower or "數量" in question_lower or "排名" in question_lower):
80
+ analysis["specific_intent"] = "company_statistics"
81
+ analysis["type"] = "statistics"
82
+
 
 
 
 
 
 
 
 
 
 
 
83
  return analysis
84
 
85
  # ==================== 完整數據加載模塊 ====================
 
88
  self.hf_token = hf_token
89
  self.questions = []
90
  self.sql_answers = []
91
+ self.sql_quality = []
92
  self.schema_data = {}
93
 
94
  def load_complete_dataset(self) -> bool:
 
95
  try:
96
  print(f"[{get_current_time()}] 正在加載完整數據集 '{DATASET_REPO_ID}'...")
97
  raw_dataset = load_dataset(DATASET_REPO_ID, token=self.hf_token)['train']
98
 
 
 
 
99
  for item in raw_dataset:
100
  try:
101
  if 'messages' in item and len(item['messages']) >= 2:
102
  user_content = item['messages'][0]['content']
103
  assistant_content = item['messages'][1]['content']
104
 
 
105
  question_match = re.search(r'指令:\s*(.*?)(?:\n|$)', user_content)
106
  question = question_match.group(1).strip() if question_match else user_content
107
 
 
108
  sql_match = re.search(r'SQL查詢:\s*(.*?)(?:\n|$)', assistant_content, re.DOTALL)
109
+ sql_query = sql_match.group(1).strip() if sql_match else assistant_content
110
+ sql_query = re.sub(r'```sql|```', '', sql_query).strip()
111
+
 
 
 
 
 
112
  self.questions.append(question)
113
  self.sql_answers.append(sql_query)
 
 
 
 
 
 
 
 
 
 
 
114
  except Exception:
115
  continue
116
 
117
+ print(f"數據加載完成: 總數 {len(self.questions)}")
118
  return True
 
119
  except Exception as e:
120
  print(f"數據集加載失敗: {e}")
121
  return False
122
 
123
  def load_schema(self) -> bool:
 
124
  try:
125
+ schema_file_path = hf_hub_download(repo_id=DATASET_REPO_ID, filename="sqlite_schema_FULL.json", repo_type='dataset', token=self.hf_token)
 
 
 
 
 
126
  with open(schema_file_path, 'r', encoding='utf-8') as f:
127
  self.schema_data = json.load(f)
128
  print("Schema加載成功")
129
  return True
130
  except Exception as e:
131
  print(f"Schema加載失敗: {e}")
 
132
  return False
133
 
134
  # ==================== 檢索系統 ====================
 
141
  print(f"SentenceTransformer 模型加載失敗: {e}")
142
  self.embedder = None
143
 
144
def compute_embeddings(self, questions: List[str]):
    """Encode ``questions`` and cache the result on ``self.question_embeddings``.

    Does nothing when no embedder is available or ``questions`` is empty.
    """
    if not (self.embedder and questions):
        return

    total = len(questions)
    print(f"正在為 {total} 個問題計算向量...")
    encoded = self.embedder.encode(
        questions,
        convert_to_tensor=True,
        show_progress_bar=True,
    )
    self.question_embeddings = encoded
    print("向量計算完成")
149
 
150
def retrieve_similar(self, user_question: str, top_k: int = 1) -> List[Dict]:
    """Return the ``top_k`` nearest stored questions for ``user_question``.

    The result is the first hit list produced by ``util.semantic_search``.
    An empty list is returned when the embedder or the cached embeddings
    are missing, when the search yields nothing, or when it raises (the
    error is printed).
    """
    ready = self.embedder is not None and self.question_embeddings is not None
    if not ready:
        return []

    try:
        query_vector = self.embedder.encode(user_question, convert_to_tensor=True)
        results = util.semantic_search(query_vector, self.question_embeddings, top_k=top_k)
        if not results:
            return []
        return results[0]
    except Exception as err:
        print(f"檢索錯誤: {err}")
        return []
 
166
  self.initialize_system()
167
 
168
  def initialize_system(self):
 
169
  print("正在初始化完整數據系統...")
 
170
  self.data_loader.load_complete_dataset()
171
  self.data_loader.load_schema()
 
 
172
  if self.data_loader.questions:
173
  self.retrieval_system.compute_embeddings(self.data_loader.questions)
 
174
  print(f"系統初始化完成,載入問題總數: {len(self.data_loader.questions)}")
175
 
176
def extract_year(self, text: str) -> str:
    """Return the first four-digit year found in ``text``.

    Only a standalone run of exactly four digits starting with ``19`` or
    ``20`` counts, so the leading digits of a longer number (for example
    a job number) are not mistaken for a year.

    Args:
        text: Free text, typically the user's question.

    Returns:
        The year as a four-character string, or the current year when
        ``text`` contains none.
    """
    year_match = re.search(r'(?<!\d)((?:19|20)\d{2})(?!\d)', text)
    if year_match:
        return year_match.group(1)
    return datetime.now().strftime('%Y')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
 
181
  def generate_sql_from_question(self, question: str, analysis: Dict) -> str:
182
+ """通用SQL生成器 (作為最終備用)"""
183
+ # 此函數現在作為無法識別具體意圖時的通用後備方案
184
+ return f"""-- 通用查詢範本
185
+ SELECT
186
+ JobNo as 工作單號,
187
+ ApplicantName as 申請方,
188
+ OverallRating as 評級
189
+ FROM TSR53SampleDescription
190
+ LIMIT 20;"""
 
 
 
 
 
 
 
 
 
 
 
191
 
192
def intelligent_repair_sql(self, user_question: str, similar_question: str) -> str:
    """Build a query from the intent of ``user_question``.

    The intent comes from ``analyze_question_type``; each recognised
    intent maps to a fixed SQL template. Unrecognised intents fall back to
    ``self.generate_sql_from_question``. Every result starts with a SQL
    comment naming ``similar_question``.

    Args:
        user_question: The question being answered. An explicit row count
            ("top 5", "前10", "5筆") sets the LIMIT of the ranking
            templates.
        similar_question: Text of the retrieved question, used only in the
            leading comment.

    Returns:
        The SQL text, prefixed by the explanatory comment line.
    """
    analysis = analyze_question_type(user_question)
    intent = analysis["specific_intent"]

    # A `--` comment ends at the first newline. Collapse whitespace so a
    # multi-line similar_question cannot spill into the executable SQL.
    reference = " ".join(str(similar_question).split())
    comment = f"-- 根據類似問題 '{reference}' (原SQL無效) 進行智能修復\n"

    def requested_limit(default: int) -> int:
        # Accept "top N" / "前N" or a number followed by a counter word.
        # "N個月" is a duration, not a row count, so it is excluded.
        found = re.search(
            r'(?:top\s*|前\s*)(\d{1,4})(?!\d)'
            r'|(?<!\d)(\d{1,4})\s*(?:筆|個(?!月)|名|家|位)',
            user_question,
            re.IGNORECASE,
        )
        if not found:
            return default
        value = int(found.group(1) or found.group(2))
        return value if value > 0 else default

    if intent == "monthly_completion_count":
        year = self.extract_year(user_question)
        return comment + "\n".join((
            f"-- 查詢 {year} 年每月完成的工作單數量",
            "SELECT",
            "    strftime('%Y-%m', jt.end_time) as 月份,",
            "    COUNT(*) as 完成數量",
            "FROM JobTimeline jt",
            f"WHERE strftime('%Y', jt.end_time) = '{year}' AND jt.end_time IS NOT NULL",
            "GROUP BY strftime('%Y-%m', jt.end_time)",
            "ORDER BY 月份;",
        ))

    if intent == "rating_distribution":
        return comment + "\n".join((
            "-- 查詢評級分佈統計",
            "SELECT",
            "    OverallRating as 評級,",
            "    COUNT(*) as 數量,",
            "    ROUND(COUNT(*) * 100.0 / (SELECT COUNT(*) FROM TSR53SampleDescription"
            " WHERE OverallRating IS NOT NULL), 2) as 百分比",
            "FROM TSR53SampleDescription",
            "WHERE OverallRating IS NOT NULL",
            "GROUP BY OverallRating",
            "ORDER BY 數量 DESC;",
        ))

    if intent == "amount_ranking":
        return comment + "\n".join((
            "-- 查詢工作單金額排名",
            "WITH JobTotalAmount AS (",
            "    SELECT JobNo, SUM(LocalAmount) AS TotalAmount",
            "    FROM (SELECT DISTINCT JobNo, InvoiceCreditNoteNo, LocalAmount"
            " FROM TSR53Invoice WHERE LocalAmount IS NOT NULL)",
            "    GROUP BY JobNo",
            ")",
            "SELECT",
            "    jta.JobNo as 工作單號,",
            "    sd.ApplicantName as 申請方,",
            "    jta.TotalAmount as 總金額",
            "FROM JobTotalAmount jta",
            "JOIN TSR53SampleDescription sd ON sd.JobNo = jta.JobNo",
            "ORDER BY jta.TotalAmount DESC",
            f"LIMIT {requested_limit(10)};",
        ))

    if intent == "company_statistics":
        return comment + "\n".join((
            "-- 查詢申請方工作單統計",
            "SELECT",
            "    ApplicantName as 申請方名稱,",
            "    COUNT(*) as 工作單數量",
            "FROM TSR53SampleDescription",
            "WHERE ApplicantName IS NOT NULL",
            "GROUP BY ApplicantName",
            "ORDER BY 工作單數量 DESC",
            f"LIMIT {requested_limit(20)};",
        ))

    # No specific intent recognised: use the generic template.
    return comment + self.generate_sql_from_question(user_question, analysis)
250
 
251
  def generate_sql(self, user_question: str) -> Tuple[str, str]:
252
+ """主流程:生成SQL查詢 (改進版本)"""
253
  log_messages = [f"⏰ {get_current_time()} 開始處理"]
254
 
255
  if not user_question or not user_question.strip():
256
  return "請輸入您的問題。", "錯誤: 問題為空"
257
 
258
  # 1. 檢索最相似的問題
259
+ hits = self.retrieval_system.retrieve_similar(user_question)
260
+
261
+ if hits:
262
+ best_hit = hits[0]
263
+ similarity_score = best_hit['score']
264
 
265
+ log_messages.append(f"🔍 檢索到最相似問題 (相似度: {similarity_score:.3f})")
266
+
267
+ if similarity_score > SIMILARITY_THRESHOLD:
268
  corpus_id = best_hit['corpus_id']
269
  similar_question = self.data_loader.questions[corpus_id]
270
  original_sql = self.data_loader.sql_answers[corpus_id]
271
 
272
+ validation = validate_sql(original_sql)
273
+ if validation["valid"] and validation["is_safe"]:
274
+ log_messages.append("✅ 相似度高,且原SQL有效,直接採用。")
275
+ return original_sql, "\n".join(log_messages)
 
 
 
276
  else:
277
+ log_messages.append(f"⚠️ 相似度高,但原SQL無效 ({', '.join(validation['issues'])})。")
278
+ log_messages.append("🛠️ 啟用智能修復...")
279
+ repaired_sql = self.intelligent_repair_sql(user_question, similar_question)
280
+ return repaired_sql, "\n".join(log_messages)
281
+
282
+ log_messages.append("🤖 未找到高相似度或有效的範本,根據問題直接生成。")
283
  analysis = analyze_question_type(user_question)
284
+ # 直接使用修復邏輯來生成,因為它本身就是基於意圖的生成器
285
+ intelligent_sql = self.intelligent_repair_sql(user_question, "無相似問題")
286
 
287
+ log_messages.append(f"📋 問題意圖分析: {analysis['specific_intent']}")
288
  log_messages.append("✅ 智能生成完成。")
 
289
  return intelligent_sql, "\n".join(log_messages)
290
 
291
  # ==================== 初始化系統 ====================
 
 
292
  if HF_TOKEN is None:
293
+ print("\n" + "="*60 + "\n⚠️ 警告: Hugging Face Token 未設置。\n" + "="*60 + "\n")
 
 
 
 
294
  text_to_sql_system = None
295
  else:
296
  text_to_sql_system = CompleteTextToSQLSystem(HF_TOKEN)
 
302
  return "系統未初始化", error_msg, error_msg
303
 
304
  sql_result, log_message = text_to_sql_system.generate_sql(user_question)
305
+ return sql_result, "✅ 處理完成", log_message
306
 
307
  with gr.Blocks(title="智慧Text-to-SQL系統", theme=gr.themes.Soft()) as demo:
308
+ gr.Markdown("# 🚀 智慧 Text-to-SQL 系統 (進階修復版)")
309
+ gr.Markdown("📊 **模式**: 結合「檢索驗證」與「意圖導向生成」,即使資料庫範本有誤也能提供準確查詢。")
310
 
311
  with gr.Row():
312
+ question_input = gr.Textbox(label="📝 請在此輸入您的問題", placeholder="例如:2024年每月完成多少份報告?", lines=3, scale=4)
 
 
 
 
 
313
  submit_btn = gr.Button("🚀 生成SQL", variant="primary", scale=1)
314
 
315
  with gr.Accordion("🔍 結果與日誌", open=True):
316
+ sql_output = gr.Code(label="📊 生成的SQL查詢", language="sql", lines=10)
 
 
 
 
317
  status_output = gr.Textbox(label="🔍 執行狀態", interactive=False)
318
+ log_output = gr.Textbox(label="📋 詳細日誌", lines=6, interactive=False)
319
 
 
320
  gr.Examples(
321
  examples=[
322
+ "2023 年每月完成多少份報告?",
323
+ "統計一下各種評級的分佈",
324
+ "找出總金額最高的5筆訂單來自哪個申請方",
325
+ "哪個客戶的工作單數量最多?"
 
326
  ],
327
  inputs=question_input
328
  )
329
 
 
 
 
 
 
 
330
  if __name__ == "__main__":
331
+ if text_to_sql_system:
332
+ print("Gradio 介面啟動中...")
 
 
333
  demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
334
+ else:
335
+ print("無法啟動 Gradio,因為系統初始化失敗。")
336
+