Paul720810 commited on
Commit
7adb5ab
·
verified ·
1 Parent(s): 0481392

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +260 -827
app.py CHANGED
@@ -12,9 +12,15 @@ from typing import List, Dict, Tuple, Optional
12
  import numpy as np
13
 
14
  # ==================== 配置區 ====================
15
- HF_TOKEN = os.environ.get("HF_TOKEN", None) # 建議從環境變數讀取
16
  DATASET_REPO_ID = "Paul720810/Text-to-SQL-Softline"
17
- SIMILARITY_THRESHOLD = 0.65 # 適度提高閾值,確保檢索到的問題意圖更一致
 
 
 
 
 
 
18
 
19
  # 雲端環境檢測
20
  IS_SPACES = os.environ.get("SPACE_ID") is not None
@@ -30,831 +36,271 @@ print("=" * 60)
30
  # ==================== 獨立工具函數 (不依賴類別實例) ====================
31
  def get_current_time():
32
  """獲取當前時間字串"""
33
- return datetime.now().strftime("%H:%M:%S")
34
-
35
- def validate_sql(sql_query: str) -> Dict:
36
- """驗證SQL語句的語法和安全性"""
37
- if not sql_query or not sql_query.strip():
38
- return {"valid": False, "issues": ["SQL語句為空"], "is_safe": False, "empty": True}
39
-
40
- sql_clean = sql_query.strip()
41
- if len(sql_clean) < 5:
42
- return {"valid": False, "issues": ["SQL過短"], "is_safe": False, "empty": True}
43
-
44
- security_issues = []
45
- sql_upper = sql_clean.upper()
46
-
47
- dangerous_keywords = ['DROP', 'DELETE', 'INSERT', 'UPDATE', 'ALTER', 'TRUNCATE', 'EXEC', 'EXECUTE']
48
- for keyword in dangerous_keywords:
49
- if f" {keyword} " in f" {sql_upper} ":
50
- security_issues.append(f"危險操作: {keyword}")
51
-
52
- if "SELECT" not in sql_upper:
53
- security_issues.append("缺少SELECT")
54
- if "FROM" not in sql_upper:
55
- security_issues.append("缺少FROM")
56
-
57
- is_valid = not security_issues
58
- is_safe = all('危險' not in issue for issue in security_issues)
59
-
60
- return {"valid": is_valid, "issues": security_issues, "is_safe": is_safe, "empty": False}
61
-
62
- def analyze_question_type(question: str) -> Dict:
63
- """增強的問題分析 - 更精確的意圖識別"""
64
- question_lower = question.lower()
65
-
66
- analysis = {
67
- "type": "unknown",
68
- "keywords": [],
69
- "has_count": "多少" in question_lower or "幾個" in question_lower or "數量" in question_lower or "count" in question_lower,
70
- "has_date": "時間" in question_lower or "日期" in question_lower or "月份" in question_lower or "年" in question_lower or "yesterday" in question_lower or "昨天" in question_lower,
71
- "has_group": "每" in question_lower or "各" in question_lower or "分組" in question_lower or "group" in question_lower,
72
- "specific_intent": "general_query" # 新增:具體意圖,預設為通用查詢
73
- }
74
-
75
- # **更精確的意圖識別 - 增加更多模式**
76
- if ("每月" in question_lower or "monthly" in question_lower) and ("完成" in question_lower or "completed" in question_lower or "報告" in question_lower or "工作單" in question_lower):
77
- analysis["specific_intent"] = "monthly_completion_count"
78
- analysis["type"] = "time_series"
79
- elif ("評級" in question_lower or "pass" in question_lower or "fail" in question_lower or "rating" in question_lower) and ("統計" in question_lower or "分佈" in question_lower or "多少" in question_lower or "distribution" in question_lower):
80
- analysis["specific_intent"] = "rating_distribution"
81
- analysis["type"] = "statistics"
82
- elif ("金額" in question_lower or "amount" in question_lower or "價格" in question_lower or "費用" in question_lower) and ("最高" in question_lower or "top" in question_lower or "排名" in question_lower or "highest" in question_lower):
83
- analysis["specific_intent"] = "amount_ranking"
84
- analysis["type"] = "ranking"
85
- elif ("公司" in question_lower or "客戶" in question_lower or "申請方" in question_lower or "company" in question_lower or "client" in question_lower) and ("統計" in question_lower or "數量" in question_lower or "排名" in question_lower or "count" in question_lower):
86
- analysis["specific_intent"] = "company_statistics"
87
- analysis["type"] = "statistics"
88
- elif ("實驗室" in question_lower or "lab" in question_lower or "組" in question_lower) and ("完成" in question_lower or "completed" in question_lower):
89
- analysis["specific_intent"] = "lab_completion"
90
- analysis["type"] = "lab_specific"
91
- elif ("異常" in question_lower or "超過" in question_lower or "延遲" in question_lower or "slow" in question_lower or "long" in question_lower):
92
- analysis["specific_intent"] = "anomaly_detection"
93
- analysis["type"] = "analysis"
94
- elif ("買方" in question_lower or "buyer" in question_lower) and ("完成" in question_lower or "completed" in question_lower):
95
- analysis["specific_intent"] = "buyer_specific"
96
- analysis["type"] = "buyer_analysis"
97
- elif ("耗時" in question_lower or "時間" in question_lower or "duration" in question_lower or "time" in question_lower) and ("最久" in question_lower or "longest" in question_lower):
98
- analysis["specific_intent"] = "duration_analysis"
99
- analysis["type"] = "time_analysis"
100
-
101
- # 提取關鍵詞以供後續使用
102
- keywords = []
103
- # 公司/品牌名稱
104
- brand_patterns = [r"puma", r"under armour", r"skechers", r"nike", r"adidas"]
105
- for pattern in brand_patterns:
106
- if re.search(pattern, question_lower):
107
- keywords.append(pattern.replace(" ", "_"))
108
-
109
- # 實驗室組別
110
- lab_patterns = [r"[a-e]組", r"ta", r"tb", r"tc", r"td", r"te"]
111
- for pattern in lab_patterns:
112
- if re.search(pattern, question_lower):
113
- keywords.append(pattern)
114
-
115
- analysis["keywords"] = keywords
116
- return analysis
117
-
118
- # ==================== 完整數據加載模塊 ====================
119
- class CompleteDataLoader:
120
- def __init__(self, hf_token: str):
121
- self.hf_token = hf_token
122
- self.questions = []
123
- self.sql_answers = []
124
- self.sql_quality = []
125
- self.schema_data = {}
126
-
127
- def preview_dataset_structure(self, sample_size: int = 5) -> None:
128
- """預覽數據集結構以幫助調試"""
129
- try:
130
- print(f"📋 預覽數據集結構 (前 {sample_size} 個範例)...")
131
- raw_dataset = load_dataset(DATASET_REPO_ID, token=self.hf_token)['train']
132
-
133
- for i in range(min(sample_size, len(raw_dataset))):
134
- item = raw_dataset[i]
135
- print(f"\n--- 範例 {i+1} ---")
136
- if 'messages' in item:
137
- user_content = item['messages'][0]['content']
138
- assistant_content = item['messages'][1]['content']
139
- print(f"User: {user_content[:120]}...")
140
- print(f"Assistant: {assistant_content[:120]}...")
141
-
142
- # 檢查SQL代碼塊
143
- sql_block_match = re.search(r'```sql\s*(.*?)\s*```', assistant_content, re.DOTALL)
144
- if sql_block_match:
145
- sql_content = sql_block_match.group(1).strip()
146
- print(f"✅ 找到SQL代碼塊: {sql_content[:60]}...")
147
- else:
148
- print("❌ 未找到SQL代碼塊")
149
-
150
- # 檢查是否有其他SQL格式
151
- if 'SELECT' in assistant_content.upper():
152
- print("⚠️ 但包含SELECT關鍵字")
153
- if 'SQL查詢:' in assistant_content:
154
- print("⚠️ 但包含'SQL查詢:'標記")
155
-
156
- # 檢查是否為JSON格式
157
- if assistant_content.strip().startswith('{'):
158
- try:
159
- json_data = json.loads(assistant_content)
160
- print(f"JSON Keys: {list(json_data.keys())}")
161
- except:
162
- print("JSON解析失敗")
163
- else:
164
- print(f"無messages字段: {list(item.keys())}")
165
-
166
- print(f"\n總數據量: {len(raw_dataset)} 項")
167
- except Exception as e:
168
- print(f"預覽失敗: {e}")
169
-
170
- def diagnose_data_issues(self, sample_size: int = 20) -> None:
171
- """診斷數據問題"""
172
  try:
173
- print(f"🔍 診斷數據問題 (檢查前 {sample_size} 個可能有問題的項目)...")
174
- raw_dataset = load_dataset(DATASET_REPO_ID, token=self.hf_token)['train']
175
-
176
- issues_found = {"no_sql_block": 0, "empty_assistant": 0, "parsing_error": 0, "other": 0}
177
-
178
- for i in range(min(sample_size, len(raw_dataset))):
179
- item = raw_dataset[i]
180
- try:
181
- if 'messages' in item and len(item['messages']) >= 2:
182
- assistant_content = item['messages'][1]['content']
183
-
184
- # 檢查SQL代碼塊
185
- sql_block_match = re.search(r'```sql\s*(.*?)\s*```', assistant_content, re.DOTALL)
186
- if not sql_block_match:
187
- issues_found["no_sql_block"] += 1
188
- if issues_found["no_sql_block"] <= 3:
189
- print(f"\n❌ 無SQL代碼塊 #{i}: {assistant_content[:200]}...")
190
-
191
- if not assistant_content.strip():
192
- issues_found["empty_assistant"] += 1
193
-
194
- except Exception as e:
195
- issues_found["parsing_error"] += 1
196
- if issues_found["parsing_error"] <= 2:
197
- print(f"\n💥 解析錯誤 #{i}: {e}")
198
-
199
- print(f"\n📊 診斷結果:")
200
- for issue, count in issues_found.items():
201
- print(f" {issue}: {count}")
202
  except Exception as e:
203
- print(f"診斷失敗: {e}")
204
-
205
- def load_complete_dataset(self) -> bool:
206
- try:
207
- print(f"[{get_current_time()}] 正在加載完整數據集 '{DATASET_REPO_ID}'...")
208
- raw_dataset = load_dataset(DATASET_REPO_ID, token=self.hf_token)['train']
209
-
210
- successful_loads = 0
211
- total_items = len(raw_dataset)
212
- skipped_reasons = {"empty_question": 0, "empty_sql": 0, "parse_error": 0, "invalid_format": 0, "json_parse_error": 0}
213
-
214
- for idx, item in enumerate(raw_dataset):
215
- try:
216
- if 'messages' in item and len(item['messages']) >= 2:
217
- user_content = item['messages'][0]['content']
218
- assistant_content = item['messages'][1]['content']
219
-
220
- # 多種問題提取策略
221
- question = None
222
-
223
- # 策略1: 檢查是否為JSON格式的回應
224
- try:
225
- if assistant_content.strip().startswith('{'):
226
- json_data = json.loads(assistant_content)
227
- if 'sql' in json_data:
228
- sql_query = json_data['sql']
229
- elif 'query' in json_data:
230
- sql_query = json_data['query']
231
- else:
232
- sql_query = None
233
-
234
- # 從JSON中提取問題 (如果有的話)
235
- if 'question' in json_data:
236
- question = json_data['question']
237
- elif 'user_query' in json_data:
238
- question = json_data['user_query']
239
- else:
240
- sql_query = None
241
- except json.JSONDecodeError:
242
- sql_query = None
243
-
244
- # 策略2: 標準「指令:」格式
245
- if not question:
246
- question_match = re.search(r'指令:\s*(.*?)(?:\n|$)', user_content)
247
- if question_match:
248
- question = question_match.group(1).strip()
249
-
250
- # 策略3: 如果沒找到,嘗試提取最後一行非空內容
251
- if not question:
252
- lines = [line.strip() for line in user_content.split('\n') if line.strip() and not line.startswith('#')]
253
- if lines:
254
- # 過濾掉看起來像標題的行
255
- for line in reversed(lines):
256
- if not line.startswith('###') and '?' in line and len(line) > 5:
257
- question = line
258
- break
259
- if not question and lines:
260
- question = lines[-1]
261
-
262
- # 策略4: 直接使用整個內容(作為最後手段)
263
- if not question:
264
- question = user_content.strip()
265
-
266
- # SQL提取邏輯(如果還沒從JSON中獲得)
267
- if not sql_query:
268
- # 策略1: SQL代碼塊格式(最常見)
269
- sql_block_match = re.search(r'```sql\s*(.*?)\s*```', assistant_content, re.DOTALL)
270
- if sql_block_match:
271
- sql_query = sql_block_match.group(1).strip()
272
-
273
- # 策略2: 標準「SQL查詢:」格式
274
- if not sql_query:
275
- sql_match = re.search(r'SQL查詢:\s*(.*?)(?:\n\n|$)', assistant_content, re.DOTALL)
276
- if sql_match:
277
- sql_query = sql_match.group(1).strip()
278
- # 清理可能的代碼塊標記
279
- sql_query = re.sub(r'```sql|```', '', sql_query).strip()
280
-
281
- # 策略3: 查找任何包含 SELECT 或 WITH 的多行內容
282
- if not sql_query:
283
- lines = assistant_content.split('\n')
284
- sql_lines = []
285
- in_sql_block = False
286
-
287
- for line in lines:
288
- line_upper = line.upper().strip()
289
- # 開始條件:找到SQL關鍵字
290
- if not in_sql_block and (line_upper.startswith('SELECT') or line_upper.startswith('WITH')):
291
- in_sql_block = True
292
- sql_lines.append(line)
293
- # 繼續條件:在SQL塊中
294
- elif in_sql_block:
295
- # 結束條件:空行或看起來不像SQL的行
296
- if not line.strip():
297
- break
298
- elif line.strip().startswith('```') and len(sql_lines) > 0:
299
- break
300
- elif line_upper.startswith('思考過程:') or line_upper.startswith('上下文:'):
301
- break
302
- else:
303
- sql_lines.append(line)
304
-
305
- if sql_lines:
306
- sql_query = '\n'.join(sql_lines).strip()
307
-
308
- # 策略4: 如果還是沒找到,嘗試更寬鬆的匹配
309
- if not sql_query:
310
- # 查找所有可能的SQL片段
311
- sql_patterns = [
312
- r'(SELECT.*?FROM.*?)(?:\n\n|$)',
313
- r'(WITH.*?SELECT.*?)(?:\n\n|$)',
314
- r'SQL查詢:\s*\n(.*?)(?:\n\n|$)'
315
- ]
316
-
317
- for pattern in sql_patterns:
318
- match = re.search(pattern, assistant_content, re.DOTALL | re.IGNORECASE)
319
- if match:
320
- candidate = match.group(1).strip()
321
- # 基本驗證
322
- if len(candidate) > 10 and ('SELECT' in candidate.upper() or 'WITH' in candidate.upper()):
323
- sql_query = candidate
324
- break
325
-
326
- # 清理SQL查詢
327
- if sql_query:
328
- # 移除各種標記
329
- sql_query = re.sub(r'```sql|```', '', sql_query).strip()
330
- sql_query = re.sub(r'^思考過程:.*?\n', '', sql_query, flags=re.MULTILINE).strip()
331
- sql_query = re.sub(r'^SQL查詢:\s*', '', sql_query, flags=re.MULTILINE).strip()
332
-
333
- # 移除多餘的空行
334
- sql_query = re.sub(r'\n\s*\n', '\n', sql_query).strip()
335
-
336
- # 確保SQL完整性 - 如果以分號結尾且內容合理,保留
337
- if not sql_query.endswith(';') and len(sql_query) > 20:
338
- # 檢查是否看起來像完整的SQL
339
- if 'FROM' in sql_query.upper() and sql_query.count('(') == sql_query.count(')'):
340
- sql_query += ';'
341
-
342
- # 清理問題文本
343
- if question:
344
- question = re.sub(r'^###\s*', '', question).strip()
345
- question = re.sub(r'Your JSON Response.*', '', question).strip()
346
- # 移除多餘的上下文���息
347
- question = re.sub(r'\n上下文:.*', '', question, flags=re.DOTALL).strip()
348
-
349
- # 數據質量驗證(降低標準以提高利用率)
350
- if not question or len(question.strip()) < 3:
351
- skipped_reasons["empty_question"] += 1
352
- continue
353
-
354
- if not sql_query or len(sql_query.strip()) < 8: # 進一步降低最小長度要求
355
- skipped_reasons["empty_sql"] += 1
356
- if idx < 10: # 調試:顯示前10個被跳過的SQL為空的案例
357
- print(f"SQL為空案例 {idx}: 原始助手回應前100字符: {assistant_content[:100]}...")
358
- continue
359
-
360
- # 更寬鬆的SQL驗證
361
- sql_upper = sql_query.upper()
362
- if "SELECT" not in sql_upper and "WITH" not in sql_upper and "CREATE" not in sql_upper:
363
- skipped_reasons["invalid_format"] += 1
364
- if idx < 5: # 調試:顯示前5個格式錯誤的案例
365
- print(f"格式錯誤案例 {idx}: SQL內容: {sql_query[:100]}...")
366
- continue
367
-
368
- self.questions.append(question)
369
- self.sql_answers.append(sql_query)
370
- successful_loads += 1
371
-
372
- # 調試:顯示前5個成功案例
373
- if successful_loads <= 5:
374
- print(f"✅ 成功案例 {successful_loads}:")
375
- print(f" 問題: {question[:80]}...")
376
- print(f" SQL: {sql_query[:80]}...")
377
-
378
- else:
379
- skipped_reasons["invalid_format"] += 1
380
-
381
- except json.JSONDecodeError as e:
382
- skipped_reasons["json_parse_error"] += 1
383
- continue
384
- except Exception as e:
385
- skipped_reasons["parse_error"] += 1
386
- if idx < 3: # 只顯示前3個錯誤
387
- print(f"跳過第 {idx} 項資料,錯誤: {e}")
388
- continue
389
-
390
- print(f"數據加載完成: 成功載入 {successful_loads}/{total_items} 項")
391
- print(f"跳過原因統計: 問題為空({skipped_reasons['empty_question']}) | SQL為空({skipped_reasons['empty_sql']}) | 格式錯誤({skipped_reasons['invalid_format']}) | JSON錯誤({skipped_reasons['json_parse_error']}) | 解析錯誤({skipped_reasons['parse_error']})")
392
- return successful_loads > 0
393
- except Exception as e:
394
- print(f"數據集加載失敗: {e}")
395
- return False
396
-
397
- def load_schema(self) -> bool:
398
  try:
399
- schema_file_path = hf_hub_download(repo_id=DATASET_REPO_ID, filename="sqlite_schema_FULL.json", repo_type='dataset', token=self.hf_token)
400
- with open(schema_file_path, 'r', encoding='utf-8') as f:
401
- self.schema_data = json.load(f)
402
- print("Schema加載成功")
403
- return True
 
 
 
 
404
  except Exception as e:
405
- print(f"Schema加載失敗: {e}")
406
- return False
407
 
408
- # ==================== 檢索系統 ====================
409
- class RetrievalSystem:
410
- def __init__(self):
411
- try:
412
- # 根據環境選擇設備
413
- device = DEVICE if 'DEVICE' in globals() else 'cpu'
414
- print(f"🔧 初始化 SentenceTransformer (設備: {device})...")
415
- self.embedder = SentenceTransformer('all-MiniLM-L6-v2', device=device)
416
- self.question_embeddings = None
417
- print("✅ SentenceTransformer 模型加載成功")
418
- except Exception as e:
419
- print(f"❌ SentenceTransformer 模型加載失敗: {e}")
420
- self.embedder = None
421
-
422
- def compute_embeddings(self, questions: List[str]):
423
- if self.embedder and questions:
424
- print(f"正在為 {len(questions)} 個問題計算向量...")
425
- try:
426
- # 雲端環境優化:分批處理以節省記憶體
427
- batch_size = 32 if IS_SPACES else 64
428
- self.question_embeddings = self.embedder.encode(
429
- questions,
430
- convert_to_tensor=True,
431
- show_progress_bar=True,
432
- batch_size=batch_size
433
- )
434
- print("向量計算完成")
435
- except Exception as e:
436
- print(f"向量計算失敗: {e}")
437
- # 降級處理:使用更小的批次大小
438
- try:
439
- print("嘗試使用較小批次大小重新計算...")
440
- self.question_embeddings = self.embedder.encode(
441
- questions,
442
- convert_to_tensor=True,
443
- show_progress_bar=True,
444
- batch_size=16
445
- )
446
- print("向量計算完成(降級模式)")
447
- except Exception as e2:
448
- print(f"向量計算徹底失敗: {e2}")
449
- self.question_embeddings = None
450
-
451
- def retrieve_similar(self, user_question: str, top_k: int = 1) -> List[Dict]:
452
- if self.embedder is None or self.question_embeddings is None: return []
453
- try:
454
- question_embedding = self.embedder.encode(user_question, convert_to_tensor=True)
455
- hits = util.semantic_search(question_embedding, self.question_embeddings, top_k=top_k)
456
- return hits[0] if hits else []
457
- except Exception as e:
458
- print(f"檢索錯誤: {e}")
459
  return []
460
-
461
- # ==================== 主系統 ====================
462
- class CompleteTextToSQLSystem:
463
- def __init__(self, hf_token: str):
464
- self.hf_token = hf_token
465
- self.data_loader = CompleteDataLoader(hf_token)
466
- self.retrieval_system = RetrievalSystem()
467
- self.initialize_system()
468
-
469
- def diagnose_data_issues(self, sample_size: int = 20) -> None:
470
- """診斷數據問題"""
471
- try:
472
- print(f"🔍 診斷數據問題 (檢查前 {sample_size} 個可能有問題的項目)...")
473
- raw_dataset = load_dataset(DATASET_REPO_ID, token=self.hf_token)['train']
474
-
475
- issues_found = {"no_sql_block": 0, "empty_assistant": 0, "parsing_error": 0, "other": 0}
476
-
477
- for i in range(min(sample_size, len(raw_dataset))):
478
- item = raw_dataset[i]
479
- try:
480
- if 'messages' in item and len(item['messages']) >= 2:
481
- assistant_content = item['messages'][1]['content']
482
-
483
- # 檢查SQL代碼塊
484
- sql_block_match = re.search(r'```sql\s*(.*?)\s*```', assistant_content, re.DOTALL)
485
- if not sql_block_match:
486
- issues_found["no_sql_block"] += 1
487
- if issues_found["no_sql_block"] <= 3:
488
- print(f"\n❌ 無SQL代碼塊 #{i}: {assistant_content[:200]}...")
489
-
490
- if not assistant_content.strip():
491
- issues_found["empty_assistant"] += 1
492
-
493
- except Exception as e:
494
- issues_found["parsing_error"] += 1
495
- if issues_found["parsing_error"] <= 2:
496
- print(f"\n💥 解析錯誤 #{i}: {e}")
497
-
498
- print(f"\n📊 診斷結果:")
499
- for issue, count in issues_found.items():
500
- print(f" {issue}: {count}")
501
- except Exception as e:
502
- print(f"診斷失敗: {e}")
503
-
504
- def initialize_system(self):
505
- print("正在初始化完整數據系統...")
506
-
507
- # 首先預覽數據結構
508
- self.data_loader.preview_dataset_structure(3)
509
-
510
- # 診斷數據問題
511
- self.data_loader.diagnose_data_issues(10)
512
-
513
- # 然後加載數據
514
- self.data_loader.load_complete_dataset()
515
- self.data_loader.load_schema()
516
- if self.data_loader.questions:
517
- self.retrieval_system.compute_embeddings(self.data_loader.questions)
518
- print(f"系統初始化完成,載入問題總數: {len(self.data_loader.questions)}")
519
-
520
- def extract_year(self, text: str) -> str:
521
- """從文字中提取年份,若無則返回當年"""
522
- year_match = re.search(r'(\d{4})', text)
523
- return year_match.group(1) if year_match else datetime.now().strftime('%Y')
524
-
525
- def call_free_cloud_ai(self, user_question: str) -> str:
526
- """調用免費雲端AI生成SQL - 當本地方法無法處理時的備選方案"""
527
  try:
528
- # 構建包含schema的prompt
529
- schema_info = json.dumps(self.data_loader.schema_data, ensure_ascii=False, indent=2)
530
-
531
- prompt = f"""你是一個SQL專家。根據以下資料庫schema和用戶問題,生成準確的SQL查詢。
532
-
533
- 資料庫Schema:
534
- {schema_info}
535
-
536
- 用戶問題: {user_question}
537
-
538
- 請分析問題並生成對應的SQL查詢。只回傳SQL代碼,不要額外解釋。
539
-
540
- SQL查詢:"""
541
-
542
- # 使用 Hugging Face 免費 Inference API
543
- headers = {"Authorization": f"Bearer {self.hf_token}"} if self.hf_token else {}
544
-
545
- # 嘗試多個免費模型
546
- models_to_try = [
547
- "microsoft/DialoGPT-medium", # 對話模型
548
- "google/flan-t5-large", # 指令跟隨模型
549
- "bigscience/bloom-560m" # 通用生成模型
550
- ]
551
-
552
- for model in models_to_try:
553
- try:
554
- url = f"https://api-inference.huggingface.co/models/{model}"
555
- response = requests.post(
556
- url,
557
- headers=headers,
558
- json={"inputs": prompt, "parameters": {"max_length": 512, "temperature": 0.1}},
559
- timeout=30
560
- )
561
-
562
- if response.status_code == 200:
563
- result = response.json()
564
- if isinstance(result, list) and len(result) > 0:
565
- generated_text = result[0].get('generated_text', '')
566
- # 提取SQL部分
567
- sql_match = re.search(r'SELECT.*?;', generated_text, re.DOTALL | re.IGNORECASE)
568
- if sql_match:
569
- return f"-- 由免費雲端AI ({model}) 生成\n{sql_match.group(0)}"
570
-
571
- except Exception as e:
572
- print(f"模型 {model} 調用失敗: {e}")
573
- continue
574
 
575
- # 如果所有模型都失敗,返回基於意圖的本地生成
576
- return self.generate_fallback_sql(user_question)
 
577
 
578
- except Exception as e:
579
- print(f"雲端AI調用失敗: {e}")
580
- return self.generate_fallback_sql(user_question)
581
-
582
- def generate_fallback_sql(self, user_question: str) -> str:
583
- """當所有方法都失敗時的後備SQL生成"""
584
- analysis = analyze_question_type(user_question)
585
-
586
- # 基於關鍵詞的簡單SQL生成
587
- question_lower = user_question.lower()
588
-
589
- if "工作單" in question_lower or "job" in question_lower:
590
- if "數量" in question_lower or "多少" in question_lower:
591
- return """-- 後備方案:工作單數量查詢
592
- SELECT COUNT(*) as 工作單總數
593
- FROM TSR53SampleDescription
594
- WHERE ApplicantName IS NOT NULL;"""
595
- else:
596
- return """-- 後備方案:工作單列表查詢
597
- SELECT JobNo, ApplicantName, BuyerName, OverallRating
598
- FROM TSR53SampleDescription
599
- WHERE ApplicantName IS NOT NULL
600
- LIMIT 20;"""
601
-
602
- elif "評級" in question_lower or "rating" in question_lower:
603
- return """-- 後備方案:評級統計查詢
604
- SELECT OverallRating, COUNT(*) as 數量
605
- FROM TSR53SampleDescription
606
- WHERE OverallRating IS NOT NULL
607
- GROUP BY OverallRating;"""
608
-
609
- elif "金額" in question_lower or "amount" in question_lower:
610
- return """-- 後備方案:金額統計查詢
611
- SELECT JobNo, LocalAmount
612
- FROM TSR53Invoice
613
- WHERE LocalAmount IS NOT NULL
614
- ORDER BY LocalAmount DESC
615
- LIMIT 10;"""
616
-
617
- # 默認通用查詢
618
- return """-- 後備方案:通用查詢
619
- SELECT JobNo, ApplicantName, BuyerName
620
- FROM TSR53SampleDescription
621
- LIMIT 10;"""
622
-
623
- def intelligent_repair_sql(self, user_question: str, similar_question: str) -> str:
624
- """智能修復SQL - 基於當前使用者問題的意圖 (擴展版本)"""
625
- analysis = analyze_question_type(user_question)
626
- intent = analysis["specific_intent"]
627
- keywords = analysis["keywords"]
628
-
629
- if similar_question != "無相似問題":
630
- comment = f"-- 根據類似問題 '{similar_question}' (原SQL無效) 進行智能修復\n"
631
  else:
632
- comment = f"-- 根據問題意圖 '{intent}' 智能生成SQL\n"
633
-
634
- if intent == "monthly_completion_count":
635
- year = self.extract_year(user_question)
636
- return comment + f"""-- 查詢 {year} 年每月完成的工作單數量
637
- SELECT
638
- strftime('%Y-%m', jt.ReportAuthorization) as 月份,
639
- COUNT(*) as 完成數量
640
- FROM JobTimeline jt
641
- WHERE strftime('%Y', jt.ReportAuthorization) = '{year}'
642
- AND jt.ReportAuthorization IS NOT NULL
643
- GROUP BY strftime('%Y-%m', jt.ReportAuthorization)
644
- ORDER BY 月份;"""
645
-
646
- elif intent == "lab_completion":
647
- # 實驗室特定查詢
648
- lab_mapping = {"a組": "TA", "b組": "TB", "c組": "TC", "d組": "TD", "e組": "TE"}
649
- lab_code = None
650
- for chinese, code in lab_mapping.items():
651
- if chinese in user_question.lower():
652
- lab_code = code
653
- break
654
-
655
- if lab_code:
656
- return comment + f"""-- 查詢{lab_code}實驗室完成的測試項目
657
- SELECT COUNT(*) as 完成數量
658
- FROM JobTimeline_{lab_code}
659
- WHERE DATE(end_time) = DATE('now','-1 day');"""
660
- else:
661
- return comment + """-- 通用實驗室查詢
662
- SELECT COUNT(*) as 總完成數量
663
- FROM JobTimeline
664
- WHERE ReportAuthorization IS NOT NULL;"""
665
-
666
- elif intent == "buyer_specific":
667
- # 買方特定查詢
668
- buyer_name = "Unknown"
669
- for keyword in keywords:
670
- if keyword in ["puma", "under_armour", "skechers", "nike", "adidas"]:
671
- buyer_name = keyword.replace("_", " ").title()
672
- break
673
-
674
- return comment + f"""-- 查詢買方 {buyer_name} 的已完成工作單
675
- SELECT sd.JobNo, sd.BuyerName, jt.ReportAuthorization
676
- FROM TSR53SampleDescription sd
677
- JOIN JobTimeline jt ON jt.JobNo = sd.JobNo
678
- WHERE sd.BuyerName LIKE '%{buyer_name}%'
679
- AND jt.ReportAuthorization IS NOT NULL
680
- ORDER BY jt.ReportAuthorization DESC;"""
681
-
682
- elif intent == "duration_analysis":
683
- return comment + """-- 查詢從 LabIn 到 LabOut 耗時最久的工作單
684
- SELECT JobNo,
685
- ROUND(julianday(LabOut) - julianday(LabIn), 2) AS 耗時天數
686
- FROM JobTimeline
687
- WHERE LabIn IS NOT NULL AND LabOut IS NOT NULL
688
- ORDER BY 耗時天數 DESC
689
- LIMIT 5;"""
690
-
691
- elif intent == "anomaly_detection":
692
- return comment + """-- 查詢從創建到授權超過 14 天的異常工單
693
- SELECT JobNo,
694
- ROUND(julianday(ReportAuthorization) - julianday(JobCreation), 2) AS 處理天數
695
- FROM JobTimeline
696
- WHERE JobCreation IS NOT NULL
697
- AND ReportAuthorization IS NOT NULL
698
- AND (julianday(ReportAuthorization) - julianday(JobCreation)) > 14
699
- ORDER BY 處理天數 DESC
700
- LIMIT 20;"""
701
-
702
- elif intent == "rating_distribution":
703
- return comment + """-- 查詢評級分佈統計
704
- SELECT
705
- OverallRating as 評級,
706
- COUNT(*) as 數量,
707
- ROUND(COUNT(*) * 100.0 / (
708
- SELECT COUNT(*)
709
- FROM TSR53SampleDescription
710
- WHERE OverallRating IS NOT NULL
711
- ), 2) as 百分比
712
- FROM TSR53SampleDescription
713
- WHERE OverallRating IS NOT NULL
714
- GROUP BY OverallRating
715
- ORDER BY 數量 DESC;"""
716
-
717
- elif intent == "amount_ranking":
718
- return comment + """-- 查詢工作單金額排名
719
- WITH JobTotalAmount AS (
720
- SELECT JobNo, SUM(LocalAmount) AS TotalAmount
721
- FROM (
722
- SELECT DISTINCT JobNo, InvoiceCreditNoteNo, LocalAmount
723
- FROM TSR53Invoice
724
- WHERE LocalAmount IS NOT NULL
725
- )
726
- GROUP BY JobNo
727
- )
728
- SELECT
729
- jta.JobNo as 工作單號,
730
- sd.ApplicantName as 申請方,
731
- jta.TotalAmount as 總金額
732
- FROM JobTotalAmount jta
733
- JOIN TSR53SampleDescription sd ON sd.JobNo = jta.JobNo
734
- WHERE sd.ApplicantName IS NOT NULL
735
- ORDER BY jta.TotalAmount DESC
736
- LIMIT 10;"""
737
-
738
- elif intent == "company_statistics":
739
- return comment + """-- 查詢申請方工作單統計
740
- SELECT
741
- ApplicantName as 申請方名稱,
742
- COUNT(*) as 工作單數量
743
- FROM TSR53SampleDescription
744
- WHERE ApplicantName IS NOT NULL
745
- GROUP BY ApplicantName
746
- ORDER BY 工作單數量 DESC
747
- LIMIT 20;"""
748
-
749
- # 通用查詢模板
750
- return comment + """-- 通用查詢範本
751
- SELECT
752
- JobNo as 工作單號,
753
- ApplicantName as 申請方,
754
- BuyerName as 買方,
755
- OverallRating as 評級
756
- FROM TSR53SampleDescription
757
- WHERE ApplicantName IS NOT NULL
758
- LIMIT 20;"""
759
-
760
- def generate_sql(self, user_question: str) -> Tuple[str, str]:
761
- """主流程:生成SQL查詢 (雲端AI增強版本)"""
762
- log_messages = [f"⏰ {get_current_time()} 開始處理問題: '{user_question[:50]}...'"]
763
-
764
- if not user_question or not user_question.strip():
765
- return "-- 錯誤: 請輸入有效問題\nSELECT '請輸入您的問題' as 錯誤信息;", "錯誤: 問題為空"
766
-
767
- # 1. 問題分析
768
- analysis = analyze_question_type(user_question)
769
- log_messages.append(f"📋 問題分析 - 意圖: {analysis['specific_intent']}, 類型: {analysis['type']}")
770
-
771
- # 2. 檢索最相似的問題
772
- hits = self.retrieval_system.retrieve_similar(user_question)
773
-
774
- if hits:
775
- best_hit = hits[0]
776
- similarity_score = best_hit['score']
777
- corpus_id = best_hit['corpus_id']
778
- similar_question = self.data_loader.questions[corpus_id]
779
-
780
- log_messages.append(f"🔍 找到相似問題 (相似度: {similarity_score:.3f}): '{similar_question[:50]}...'")
781
-
782
- # 降低相似度閾值,增加匹配機會
783
- if similarity_score > max(SIMILARITY_THRESHOLD - 0.1, 0.5):
784
- original_sql = self.data_loader.sql_answers[corpus_id]
785
- validation = validate_sql(original_sql)
786
-
787
- if validation["valid"] and validation["is_safe"]:
788
- log_messages.append("✅ 相似度較高且原SQL有效,直接採用")
789
- return original_sql, "\n".join(log_messages)
790
  else:
791
- log_messages.append(f"⚠️ 原SQL有問題: {', '.join(validation['issues'])}")
792
- log_messages.append("🛠️ 啟用智能修復...")
793
- repaired_sql = self.intelligent_repair_sql(user_question, similar_question)
794
- log_messages.append(" 智能修復完成")
795
- return repaired_sql, "\n".join(log_messages)
796
- else:
797
- log_messages.append(f"📉 相似度 ({similarity_score:.3f}) 較低,嘗試其他方法")
798
-
799
- # 3. 嘗試基於意圖的本地生成
800
- if analysis["specific_intent"] != "general_query":
801
- log_messages.append("🤖 使用意圖導向生成")
802
- intelligent_sql = self.intelligent_repair_sql(user_question, "無相似問題")
803
- validation = validate_sql(intelligent_sql)
804
-
805
- if validation["valid"]:
806
- log_messages.append("✅ 意圖導向生成成功")
807
- return intelligent_sql, "\n".join(log_messages)
808
- else:
809
- log_messages.append("⚠️ 意圖導向生成結果有問題,嘗試雲端AI")
810
-
811
- # 4. 調用免費雲端AI(針對未見過的問題)
812
- log_messages.append("🌐 調用免費雲端AI處理未見過的問題...")
813
- cloud_sql = self.call_free_cloud_ai(user_question)
814
- log_messages.append("✅ 雲端AI回應完成")
815
-
816
- return cloud_sql, "\n".join(log_messages)
817
-
818
- # ==================== 初始化系統 ====================
819
- if HF_TOKEN is None:
820
- print("\n" + "="*60 + "\n⚠️ 警告: Hugging Face Token 未設置。\n" + "="*60 + "\n")
821
- text_to_sql_system = None
822
- else:
823
- text_to_sql_system = CompleteTextToSQLSystem(HF_TOKEN)
824
-
825
- # ==================== Gradio界面 ====================
826
- def process_query(user_question: str) -> Tuple[str, str, str]:
827
- if text_to_sql_system is None:
828
- error_msg = "系統因缺少 Hugging Face Token 而未成功初始化。"
829
- return "系統未初始化", error_msg, error_msg
830
-
831
- sql_result, log_message = text_to_sql_system.generate_sql(user_question)
832
- return sql_result, "✅ 處理完成", log_message
833
-
834
- with gr.Blocks(title="智慧Text-to-SQL系統", theme=gr.themes.Soft()) as demo:
835
- # 環境資訊顯示
836
- env_info = f"🌐 運行環境: {'Hugging Face Spaces' if IS_SPACES else '本地環境'} | 💻 設備: {DEVICE}"
837
- system_status = f"📊 已載入 {len(text_to_sql_system.data_loader.questions) if text_to_sql_system else 0} 個問答範例"
838
-
839
- gr.Markdown("# 🚀 智慧 Text-to-SQL 系統 (雲端版)")
840
- gr.Markdown("📊 **模式**: 結合「檢索驗證」與「意圖導向生成」,即使資料庫範本有誤也能提供準確查詢。")
841
- gr.Markdown(f"ℹ️ {env_info} | {system_status}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
842
 
843
  with gr.Row():
844
- question_input = gr.Textbox(
845
- label="📝 請在此輸入您的問題",
846
- placeholder="例如:2024年每月完成多少份報告?",
847
- lines=3,
848
- scale=4
849
- )
850
- submit_btn = gr.Button("🚀 生成SQL", variant="primary", scale=1)
 
 
 
 
851
 
852
- with gr.Accordion("🔍 結果與日誌", open=True):
853
- sql_output = gr.Code(label="📊 生成的SQL查詢", language="sql", lines=10)
854
- status_output = gr.Textbox(label="🔍 執行狀態", interactive=False)
855
- log_output = gr.Textbox(label="📋 詳細日誌", lines=6, interactive=False)
856
 
857
- # 雲端環境優化的範例
858
  gr.Examples(
859
  examples=[
860
  "2024年每月完成多少份報告?",
@@ -892,21 +338,8 @@ if __name__ == "__main__":
892
  demo.launch(
893
  server_name="0.0.0.0",
894
  server_port=7860,
895
- share=False,
896
- show_error=True,
897
- quiet=False
898
  )
899
  else:
900
  # 本地環境
901
- print("🏠 在本地環境中啟動...")
902
- demo.launch(
903
- server_name="127.0.0.1",
904
- server_port=7860,
905
- share=True, # 本地環境可以選擇分享
906
- show_error=True
907
- )
908
- else:
909
- print("❌ 無法啟動 Gradio,因為系統初始化失敗。")
910
- if IS_SPACES:
911
- print("💡 請檢查 Hugging Face Spaces 的環境變數設置。")
912
-
 
12
  import numpy as np
13
 
14
  # ==================== 配置區 ====================
15
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
16
  DATASET_REPO_ID = "Paul720810/Text-to-SQL-Softline"
17
+
18
+ # === 修改開始 ===
19
+ # 我們不再需要硬性的相似度閾值,因為現在的策略是「參考」而非「直接採用」。
20
+ # SIMILARITY_THRESHOLD = 0.65
21
+ # 新增一個配置,決定要檢索多少個範例來當作參考
22
+ FEW_SHOT_EXAMPLES_COUNT = 2 # 檢索最相似的2個範例
23
+ # === 修改結束 ===
24
 
25
  # 雲端環境檢測
26
  IS_SPACES = os.environ.get("SPACE_ID") is not None
 
36
  # ==================== 獨立工具函數 (不依賴類別實例) ====================
37
  def get_current_time():
38
  """獲取當前時間字串"""
39
+ return datetime.now().strftime('%Y-%m-%d %H:%M:%S')
40
+
41
+ def format_log(message: str, level: str = "INFO") -> str:
42
+ """格式化日誌訊息"""
43
+ return f"[{get_current_time()}] [{level.upper()}] {message}"
44
+
45
+ def parse_sql_from_response(response_text: str) -> Optional[str]:
46
+ """從API回應中提取SQL代碼"""
47
+ match = re.search(r"```sql\n(.*?)\n```", response_text, re.DOTALL)
48
+ if match:
49
+ return match.group(1).strip()
50
+ # 新增備用解析:如果找不到```sql ...```,直接嘗試解析JSON中的SQL
51
+ try:
52
+ data = json.loads(response_text)
53
+ if "SQL查詢" in data and "```sql" in data["SQL查詢"]:
54
+ match = re.search(r"```sql\n(.*?)\n```", data["SQL查詢"], re.DOTALL)
55
+ if match:
56
+ return match.group(1).strip()
57
+ except json.JSONDecodeError:
58
+ pass # 不是合法的JSON,忽略
59
+ return None
60
+
61
+ # ==================== 核心 Text-to-SQL 系統類別 ====================
62
+ class TextToSQLSystem:
63
+ def __init__(self, model_name='sentence-transformers/paraphrase-multilingual-mpnet-base-v2'):
64
+ self.log_history = []
65
+ self._log("初始化系統...")
66
+ self.schema = self._load_schema()
67
+ self.model = SentenceTransformer(model_name, device=DEVICE)
68
+ self.dataset, self.corpus_embeddings = self._load_and_encode_dataset()
69
+ self._log(" 系統初始化完成,已準備就緒。")
70
+
71
+ def _log(self, message: str, level: str = "INFO"):
72
+ self.log_history.append(format_log(message, level))
73
+ print(format_log(message, level))
74
+
75
+ def _load_schema(self) -> Dict:
76
+ """從JSON檔案載入資料庫結構"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  try:
78
+ schema_path = hf_hub_download(repo_id=DATASET_REPO_ID, filename="sqlite_schema_FULL.json", repo_type="dataset")
79
+ with open(schema_path, 'r', encoding='utf-8') as f:
80
+ self._log("成功載入資料庫結構 (sqlite_schema_FULL.json)")
81
+ return json.load(f)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  except Exception as e:
83
+ self._log(f" 載入資料庫結構失敗: {e}", "ERROR")
84
+ return {}
85
+
86
+ def _format_schema_for_prompt(self) -> str:
87
+ """將 schema JSON 物件格式化為清晰的字串,用於提示"""
88
+ formatted_string = "資料庫結構 (Database Schema):\n"
89
+ for table_name, columns in self.schema.items():
90
+ formatted_string += f"Table: {table_name}\n"
91
+ for col in columns:
92
+ col_name = col.get('name', 'N/A')
93
+ col_type = col.get('type', 'N/A')
94
+ col_desc = col.get('description', '')
95
+ formatted_string += f" - {col_name} ({col_type}) # {col_desc}\n"
96
+ formatted_string += "\n"
97
+ return formatted_string
98
+
99
+ def _load_and_encode_dataset(self) -> Tuple[Optional[List[Dict]], Optional[torch.Tensor]]:
100
+ """載入訓練數據集並對問題進行編碼"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  try:
102
+ dataset = load_dataset(DATASET_REPO_ID, data_files="training_data.jsonl", split="train")
103
+
104
+ # 提取所有 "user" 的 "content" 作為語料庫
105
+ corpus = [item['messages'][0]['content'] for item in dataset]
106
+
107
+ self._log(f"正在對 {len(corpus)} 個範例問題進行編碼...")
108
+ embeddings = self.model.encode(corpus, convert_to_tensor=True, device=DEVICE)
109
+ self._log("✅ 範例問題編碼完成。")
110
+ return dataset, embeddings
111
  except Exception as e:
112
+ self._log(f" 載入或編碼數據集失敗: {e}", "ERROR")
113
+ return None, None
114
 
115
+ def find_most_similar(self, question: str, top_k: int) -> List[Dict]:
116
+ """尋找最相似的K個問題及其對應的SQL"""
117
+ if self.corpus_embeddings is None or self.dataset is None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  return []
119
+ question_embedding = self.model.encode(question, convert_to_tensor=True, device=DEVICE)
120
+ cos_scores = util.cos_sim(question_embedding, self.corpus_embeddings)[0]
121
+ top_results = torch.topk(cos_scores, k=min(top_k, len(self.corpus_embeddings)))
122
+
123
+ similar_examples = []
124
+ for score, idx in zip(top_results[0], top_results[1]):
125
+ item = self.dataset[idx.item()]
126
+ user_content = item['messages'][0]['content']
127
+ assistant_content = item['messages'][1]['content']
128
+
129
+ # 從 assistant_content 中提取純 SQL
130
+ sql_query = parse_sql_from_response(assistant_content)
131
+ if not sql_query:
132
+ # 如果解析失敗,可能是格式問題,這裡做個備份
133
+ sql_query = "無法解析範例SQL"
134
+
135
+ similar_examples.append({
136
+ "similarity": score.item(),
137
+ "question": user_content,
138
+ "sql": sql_query
139
+ })
140
+ return similar_examples
141
+
142
+ def huggingface_api_call(self, prompt: str) -> str:
143
+ """呼叫 Hugging Face Inference API"""
144
+ API_URL = "[https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1](https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1)"
145
+ headers = {"Authorization": f"Bearer {HF_TOKEN}"}
146
+ payload = {
147
+ "inputs": prompt,
148
+ "parameters": {
149
+ "max_new_tokens": 1024,
150
+ "return_full_text": False
151
+ }
152
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  try:
154
+ self._log("正在呼叫 Hugging Face API...")
155
+ response = requests.post(API_URL, headers=headers, json=payload, timeout=60)
156
+ response.raise_for_status()
157
+ self._log("✅ API 成功回應。")
158
+ return response.json()[0]['generated_text']
159
+ except requests.exceptions.RequestException as e:
160
+ self._log(f"❌ API 呼叫失敗: {e}", "ERROR")
161
+ return f"API 錯誤: {e}"
162
+
163
+ # === 修改開始: 重寫核心處理邏輯 ===
164
+ def _build_prompt_for_generation(self, user_question: str, examples: List[Dict]) -> str:
165
+ """
166
+ **新增的函數**
167
+ 根據我們的「檢索-增強-生成」策略,建立一個豐富的提示(Prompt)。
168
+ """
169
+ # 1. 任務指令 (System Instruction)
170
+ # 明確告訴 AI 它的角色和目標。
171
+ system_instruction = (
172
+ "你是一位頂尖的資料庫專家,精通 SQLite。你的任務是根據使用者提出的問題,"
173
+ "參考提供的資料庫結構和相似的 SQL 查詢範例,生成��個精確、高效的 SQLite 查詢語法。\n"
174
+ "請將最終的 SQL 查詢語法包裝在 ```sql ... ``` 區塊中。"
175
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
 
177
+ # 2. 資料庫結構 (Database Schema)
178
+ # 讓 AI 了解有哪些資料表和欄位可用。
179
+ schema_string = self._format_schema_for_prompt()
180
 
181
+ # 3. 參考範例 (Few-shot Examples)
182
+ # 給 AI 看「過去的優良作業」,讓它學習語法風格和邏輯。
183
+ examples_string = "--- 參考範例 ---\n"
184
+ if not examples:
185
+ examples_string += "無\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  else:
187
+ for i, example in enumerate(examples, 1):
188
+ # 為了讓提示更清晰,我們只取範例中的 `指令` 部分
189
+ clean_question = re.search(r"指令:\s*(.*)", example['question'])
190
+ if clean_question:
191
+ question_to_show = clean_question.group(1).strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  else:
193
+ question_to_show = example['question'] # 如果格式不符,顯示原文
194
+
195
+ examples_string += f"範例 {i}:\n"
196
+ examples_string += f" - 使用者問題: \"{question_to_show}\"\n"
197
+ examples_string += f" - SQL 查詢:\n```sql\n{example['sql']}\n```\n\n"
198
+
199
+ # 4. 新的使用者問題 (User's New Question)
200
+ # 這是 AI 這次需要解決的核心問題。
201
+ final_question_section = (
202
+ "--- 任務開始 ---\n"
203
+ f"請根據以上的資料庫結構和參考範例,為以下使用者問題生成 SQL 查詢:\n"
204
+ f"使用者問題: \"{user_question}\""
205
+ )
206
+
207
+ # 組合完整的提示
208
+ full_prompt = (
209
+ f"{system_instruction}\n\n"
210
+ f"{schema_string}\n"
211
+ f"{examples_string}"
212
+ f"{final_question_section}"
213
+ )
214
+
215
+ self._log("已建立給 AI 的完整提示 (Prompt):\n" + "="*20 + f"\n{full_prompt}\n" + "="*20)
216
+ return full_prompt
217
+
218
+ def process_question(self, question: str) -> Tuple[str, str]:
219
+ """
220
+ 處理使用者問題的核心函數。
221
+ 採用「檢索-增強-生成」(RAG) 流程。
222
+ """
223
+ self.log_history = [] # 清空上次日誌
224
+ self._log(f"⏰ 開始處理問題: '{question}'")
225
+
226
+ # 步驟 1: 檢索 (Retrieval)
227
+ # 無論如何,都先尋找最相似的範例作為參考資料。
228
+ self._log(f"🔍 正在從 {len(self.dataset)} 個範例中尋找最相似的 {FEW_SHOT_EXAMPLES_COUNT} 個參考...")
229
+ similar_examples = self.find_most_similar(question, top_k=FEW_SHOT_EXAMPLES_COUNT)
230
+
231
+ if similar_examples:
232
+ for ex in similar_examples:
233
+ self._log(f" - 找到相似範例 (相似度: {ex['similarity']:.3f}): '{ex['question'][:50]}...'")
234
+ else:
235
+ self._log(" - 未找到相似範例。", "WARNING")
236
+
237
+ # 步驟 2: 增強 (Augmentation)
238
+ # 建立一個包含所有必要資訊的豐富提示。
239
+ self._log("📝 正在建立給 AI 的完整提示 (Prompt)...")
240
+ prompt = self._build_prompt_for_generation(question, similar_examples)
241
+
242
+ # 步驟 3: 生成 (Generation)
243
+ # 將判斷權交給 AI,讓它根據完整的上下文生成 SQL。
244
+ self._log("🧠 將判斷權交給 AI,開始生成 SQL...")
245
+ api_response = self.huggingface_api_call(prompt)
246
+
247
+ # 處理並回傳結果
248
+ sql_query = parse_sql_from_response(api_response)
249
+
250
+ if sql_query:
251
+ self._log(f"✅ 成功從 AI 回應中解析出 SQL!")
252
+ status = "生成成功"
253
+ return sql_query, status
254
+ else:
255
+ self._log("❌ 未能從 AI 回應中解析出有效的 SQL。", "ERROR")
256
+ self._log(f" - AI 原始回應: {api_response}", "DEBUG")
257
+ status = "生成失敗"
258
+ return f"無法從 AI 的回應中提取 SQL。\n\n原始回應:\n{api_response}", status
259
+ # === 修改結束 ===
260
+
261
+
262
+ # ==================== Gradio 介面設定 ====================
263
+ text_to_sql_system = None
264
+ try:
265
+ text_to_sql_system = TextToSQLSystem()
266
+ except Exception as e:
267
+ print(f"初始化 TextToSQLSystem 失敗: {e}")
268
+
269
+ def process_query(question: str) -> Tuple[str, str, str]:
270
+ """Gradio 的處理函數"""
271
+ if not text_to_sql_system:
272
+ error_msg = "系統初始化失敗,無法處理請求。"
273
+ return error_msg, "失敗", error_msg
274
+
275
+ if not question.strip():
276
+ return "", "等待輸入", "請輸入您的問題。"
277
+
278
+ sql_result, status = text_to_sql_system.process_question(question)
279
+ log_output = "\n".join(text_to_sql_system.log_history)
280
+ return sql_result, status, log_output
281
+
282
+ # Gradio 介面佈局
283
+ with gr.Blocks(theme=gr.themes.Soft(), title="Text-to-SQL 智能查詢系統") as demo:
284
+ gr.Markdown("# 📊 Text-to-SQL 智能查詢系統")
285
+ gr.Markdown("輸入您的自然語言問題,系統將自動轉換為 SQL 查詢語法。")
286
 
287
  with gr.Row():
288
+ with gr.Column(scale=2):
289
+ question_input = gr.Textbox(
290
+ lines=3,
291
+ label="💬 您的問題",
292
+ placeholder="例如:2024年每月完成了多少份報告?"
293
+ )
294
+ submit_btn = gr.Button("🚀 生成 SQL", variant="primary")
295
+ status_output = gr.Textbox(label="處理狀態", interactive=False)
296
+
297
+ with gr.Column(scale=3):
298
+ sql_output = gr.Code(label="🤖 生成的 SQL 查詢", language="sql")
299
 
300
+ with gr.Accordion("🔍 顯示詳細處理日誌", open=False):
301
+ log_output = gr.Textbox(lines=15, label="日誌", interactive=False)
 
 
302
 
303
+ # 優化的範例
304
  gr.Examples(
305
  examples=[
306
  "2024年每月完成多少份報告?",
 
338
  demo.launch(
339
  server_name="0.0.0.0",
340
  server_port=7860,
 
 
 
341
  )
342
  else:
343
  # 本地環境
344
+ print("🏠 在本地環境中啟動 ([http://127.0.0.1:7860](http://127.0.0.1:7860))...")
345
+ demo.launch()