Paul720810 commited on
Commit
7371ddd
·
verified ·
1 Parent(s): 17fd648

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +123 -93
app.py CHANGED
@@ -46,27 +46,27 @@ def parse_sql_from_response(response_text: str) -> Optional[str]:
46
  """從模型輸出提取 SQL,增強版"""
47
  if not response_text:
48
  return None
49
-
50
  # 清理回應文本
51
  response_text = response_text.strip()
52
-
53
  # 1. 先找 ```sql ... ```
54
  match = re.search(r"```sql\s*\n(.*?)\n```", response_text, re.DOTALL | re.IGNORECASE)
55
  if match:
56
  return match.group(1).strip()
57
-
58
  # 2. 找任何 ``` 包圍的內容
59
  match = re.search(r"```\s*\n?(.*?)\n?```", response_text, re.DOTALL)
60
  if match:
61
  sql_candidate = match.group(1).strip()
62
  if sql_candidate.upper().startswith('SELECT'):
63
  return sql_candidate
64
-
65
  # 3. 找 SQL 語句(更寬鬆的匹配)
66
  match = re.search(r"(SELECT\s+.*?;)", response_text, re.DOTALL | re.IGNORECASE)
67
  if match:
68
  return match.group(1).strip()
69
-
70
  # 4. 找沒有分號的 SQL
71
  match = re.search(r"(SELECT\s+.*?)(?=\n\n|\n```|$|\n[^,\s])", response_text, re.DOTALL | re.IGNORECASE)
72
  if match:
@@ -74,7 +74,7 @@ def parse_sql_from_response(response_text: str) -> Optional[str]:
74
  if not sql.endswith(';'):
75
  sql += ';'
76
  return sql
77
-
78
  # 5. 如果包含 SELECT,嘗試提取整行
79
  if 'SELECT' in response_text.upper():
80
  lines = response_text.split('\n')
@@ -84,7 +84,7 @@ def parse_sql_from_response(response_text: str) -> Optional[str]:
84
  if not line.endswith(';'):
85
  line += ';'
86
  return line
87
-
88
  return None
89
 
90
  # ==================== Text-to-SQL 核心類 ====================
@@ -113,7 +113,7 @@ class TextToSQLSystem:
113
  self._log("✅ 系統初始化完成")
114
  # 載入數據庫結構
115
  self.schema = self._load_schema()
116
-
117
  # 暫時添加:打印 schema 信息
118
  if self.schema:
119
  print("=" * 50)
@@ -125,7 +125,7 @@ class TextToSQLSystem:
125
  for col in columns[:5]: # 只顯示前5個
126
  print(f" - {col['name']} ({col['type']})")
127
  print("=" * 50)
128
-
129
  # in class TextToSQLSystem:
130
 
131
  def _load_gguf_model(self):
@@ -137,7 +137,7 @@ class TextToSQLSystem:
137
  filename=GGUF_FILENAME,
138
  repo_type="dataset"
139
  )
140
-
141
  # 使用一組更基礎、更穩定的參數來載入模型
142
  self.llm = Llama(
143
  model_path=model_path,
@@ -147,16 +147,16 @@ class TextToSQLSystem:
147
  verbose=False, # 設為 False 避免 llama.cpp 本身的日誌干擾
148
  n_gpu_layers=0 # 確認在 CPU 上運行
149
  )
150
-
151
  # 簡單測試模型是否能回應
152
  self.llm("你好", max_tokens=3)
153
  self._log("✅ GGUF 模型載入成功")
154
-
155
  except Exception as e:
156
  self._log(f"❌ GGUF 載入失敗: {e}", "ERROR")
157
  self._log("系統將無法生成 SQL。請檢查模型檔案或 llama-cpp-python 安裝。", "CRITICAL")
158
  self.llm = None
159
-
160
  def _try_gguf_loading(self):
161
  """嘗試載入 GGUF"""
162
  try:
@@ -165,7 +165,7 @@ class TextToSQLSystem:
165
  filename=GGUF_FILENAME,
166
  repo_type="dataset"
167
  )
168
-
169
  self.llm = Llama(
170
  model_path=model_path,
171
  n_ctx=512,
@@ -173,24 +173,24 @@ class TextToSQLSystem:
173
  verbose=False,
174
  n_gpu_layers=0
175
  )
176
-
177
  # 測試生成
178
  test_result = self.llm("SELECT", max_tokens=5)
179
  self._log("✅ GGUF 模型載入成功")
180
  return True
181
-
182
  except Exception as e:
183
  self._log(f"GGUF 載入失敗: {e}", "WARNING")
184
  return False
185
-
186
  def _load_transformers_model(self):
187
  """使用 Transformers 載入你的微調模型"""
188
  try:
189
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
190
  import torch
191
-
192
  self._log(f"載入 Transformers 模型: {FINETUNED_MODEL_PATH}")
193
-
194
  # 載入你的微調模型
195
  self.transformers_tokenizer = AutoTokenizer.from_pretrained(FINETUNED_MODEL_PATH)
196
  self.transformers_model = AutoModelForCausalLM.from_pretrained(
@@ -199,7 +199,7 @@ class TextToSQLSystem:
199
  device_map="cpu", # 強制使用 CPU
200
  trust_remote_code=True # Qwen 模型可能需要
201
  )
202
-
203
  # 創建生成管道
204
  self.generation_pipeline = pipeline(
205
  "text-generation",
@@ -212,20 +212,20 @@ class TextToSQLSystem:
212
  top_p=0.9,
213
  pad_token_id=self.transformers_tokenizer.eos_token_id
214
  )
215
-
216
  self.llm = "transformers" # 標記使用 transformers
217
  self._log("✅ Transformers 模型載入成功")
218
-
219
  except Exception as e:
220
  self._log(f"❌ Transformers 載入也失敗: {e}", "ERROR")
221
  self.llm = None
222
-
223
  def huggingface_api_call(self, prompt: str) -> str:
224
  """調用 GGUF 模型,並加入詳細的原始輸出日誌"""
225
  if self.llm is None:
226
  self._log("模型未載入,返回 fallback SQL。", "ERROR")
227
  return self._generate_fallback_sql(prompt)
228
-
229
  try:
230
  output = self.llm(
231
  prompt,
@@ -236,9 +236,9 @@ class TextToSQLSystem:
236
  # --- 將 stop 參數加回來 ---
237
  stop=["```", ";", "\n\n", "</s>"],
238
  )
239
-
240
  self._log(f"🧠 模型原始輸出 (Raw Output): {output}", "DEBUG")
241
-
242
  if output and "choices" in output and len(output["choices"]) > 0:
243
  generated_text = output["choices"][0]["text"]
244
  self._log(f"📝 提取出的生成文本: {generated_text.strip()}", "DEBUG")
@@ -246,13 +246,13 @@ class TextToSQLSystem:
246
  else:
247
  self._log("❌ 模型的原始輸出格式不正確或為空。", "ERROR")
248
  return ""
249
-
250
  except Exception as e:
251
  self._log(f"❌ 模型生成過程中發生嚴重錯誤: {e}", "CRITICAL")
252
  import traceback
253
  self._log(traceback.format_exc(), "DEBUG")
254
  return ""
255
-
256
  def _load_gguf_model_fallback(self, model_path):
257
  """備用載入方式"""
258
  try:
@@ -286,7 +286,7 @@ class TextToSQLSystem:
286
  )
287
  with open(schema_path, "r", encoding="utf-8") as f:
288
  schema_data = json.load(f)
289
-
290
  # 添加調試信息
291
  self._log(f"📊 Schema 載入成功,包含 {len(schema_data)} 個表格:")
292
  for table_name, columns in schema_data.items():
@@ -294,14 +294,14 @@ class TextToSQLSystem:
294
  # 顯示前3個欄位作為範例
295
  sample_cols = [col['name'] for col in columns[:3]]
296
  self._log(f" 範例欄位: {', '.join(sample_cols)}")
297
-
298
  self._log("✅ 數據庫結構載入完成")
299
  return schema_data
300
-
301
  except Exception as e:
302
  self._log(f"❌ 載入 schema 失敗: {e}", "ERROR")
303
  return {}
304
-
305
  # 也可以添加一個方法來檢查生成的 SQL 是否使用了正確的表格和欄位
306
  def _analyze_sql_correctness(self, sql: str) -> Dict:
307
  """分析 SQL 的正確性"""
@@ -312,15 +312,15 @@ class TextToSQLSystem:
312
  'invalid_columns': [],
313
  'suggestions': []
314
  }
315
-
316
  if not self.schema:
317
  return analysis
318
-
319
  # 提取 SQL 中的表格名稱
320
  table_pattern = r'FROM\s+(\w+)|JOIN\s+(\w+)'
321
  table_matches = re.findall(table_pattern, sql, re.IGNORECASE)
322
  used_tables = [match[0] or match[1] for match in table_matches]
323
-
324
  # 檢查表格是否存在
325
  valid_tables = list(self.schema.keys())
326
  for table in used_tables:
@@ -332,26 +332,26 @@ class TextToSQLSystem:
332
  for valid_table in valid_tables:
333
  if table.lower() in valid_table.lower() or valid_table.lower() in table.lower():
334
  analysis['suggestions'].append(f"{table} -> {valid_table}")
335
-
336
  # 提取欄位名稱(簡單版本)
337
  column_pattern = r'SELECT\s+(.*?)\s+FROM|WHERE\s+(\w+)\s*[=<>]|GROUP BY\s+(\w+)|ORDER BY\s+(\w+)'
338
  column_matches = re.findall(column_pattern, sql, re.IGNORECASE)
339
-
340
  return analysis
341
 
342
  def _encode_texts(self, texts):
343
  """編碼文本為嵌入向量"""
344
  if isinstance(texts, str):
345
  texts = [texts]
346
-
347
- inputs = self.embed_tokenizer(texts, padding=True, truncation=True,
348
  return_tensors="pt", max_length=512)
349
  if DEVICE == "cuda":
350
  inputs = {k: v.cuda() for k, v in inputs.items()}
351
-
352
  with torch.no_grad():
353
  outputs = self.embed_model(**inputs)
354
-
355
  # 使用平均池化
356
  embeddings = outputs.last_hidden_state.mean(dim=1)
357
  return embeddings.cpu()
@@ -360,28 +360,53 @@ class TextToSQLSystem:
360
  """載入數據集並建立 FAISS 索引"""
361
  try:
362
  dataset = load_dataset(DATASET_REPO_ID, data_files="training_data.jsonl", split="train")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
  corpus = [item['messages'][0]['content'] for item in dataset]
364
  self._log(f"正在編碼 {len(corpus)} 個問題...")
365
-
366
  # 批量編碼
367
  embeddings_list = []
368
  batch_size = 32
369
-
370
  for i in range(0, len(corpus), batch_size):
371
  batch_texts = corpus[i:i+batch_size]
372
  batch_embeddings = self._encode_texts(batch_texts)
373
  embeddings_list.append(batch_embeddings)
374
  self._log(f"已編碼 {min(i+batch_size, len(corpus))}/{len(corpus)}")
375
-
376
  all_embeddings = torch.cat(embeddings_list, dim=0).numpy()
377
-
378
  # 建立 FAISS 索引
379
  index = faiss.IndexFlatIP(all_embeddings.shape[1])
380
  index.add(all_embeddings.astype('float32'))
381
-
382
  self._log("✅ 向量索引建立完成")
383
  return dataset, index
384
-
385
  except Exception as e:
386
  self._log(f"❌ 載入數據失敗: {e}", "ERROR")
387
  return None, None
@@ -390,7 +415,7 @@ class TextToSQLSystem:
390
  """根據實際 Schema 識別相關表格"""
391
  question_lower = question.lower()
392
  relevant_tables = []
393
-
394
  # 根據實際表格的關鍵詞映射
395
  keyword_to_table = {
396
  'TSR53SampleDescription': ['客戶', '買方', '申請', '發票對象', 'customer', 'invoice', 'sample'],
@@ -400,18 +425,18 @@ class TextToSQLSystem:
400
  'JobEventsLog': ['事件', '操作', '用戶', 'event', 'log', 'user'],
401
  'calendar_days': ['工作日', '假期', 'workday', 'holiday', 'calendar']
402
  }
403
-
404
  for table, keywords in keyword_to_table.items():
405
  if any(keyword in question_lower for keyword in keywords):
406
  relevant_tables.append(table)
407
-
408
  # 預設重要表格
409
  if not relevant_tables:
410
  if any(word in question_lower for word in ['客戶', '買家', '申請', '工作單', '數量']):
411
  return ['TSR53SampleDescription', 'JobsInProgress']
412
  else:
413
  return ['JobTimeline', 'TSR53SampleDescription']
414
-
415
  return relevant_tables[:3] # 最多返回3個相關表格
416
 
417
  # 請將這整個函數複製到您的 TextToSQLSystem class 內部
@@ -422,7 +447,7 @@ class TextToSQLSystem:
422
  """
423
  if not self.schema:
424
  return "No schema available.\n"
425
-
426
  actual_table_names_map = {name.lower(): name for name in self.schema.keys()}
427
  real_table_names = []
428
  for table in table_names:
@@ -453,7 +478,7 @@ class TextToSQLSystem:
453
  else:
454
  cols_str.append(f"{col_name} ({col_type})")
455
  formatted += f"Columns: {', '.join(cols_str)}\n\n"
456
-
457
  return formatted.strip()
458
 
459
 
@@ -470,14 +495,14 @@ class TextToSQLSystem:
470
  返回一個元組 (SQL字符串或None, 狀態消息)。
471
  """
472
  q_lower = question.lower()
473
-
474
  # ==============================================================================
475
  # 第一層:高價值意圖識別與模板覆寫 (Intent Recognition & Templating)
476
  # ==============================================================================
477
-
478
  # --- 預先檢測所有可能的意圖和實體 ---
479
  job_no_match = re.search(r"(?:工單|jobno)\s*'\"?([A-Z]{2,3}\d+)'\"?", question, re.IGNORECASE)
480
-
481
  entity_match_data = None
482
  ENTITY_TO_COLUMN_MAP = {
483
  '申請廠商': 'sd.ApplicantName', '申請方': 'sd.ApplicantName', 'applicant': 'sd.ApplicantName',
@@ -491,7 +516,7 @@ class TextToSQLSystem:
491
  if match:
492
  entity_match_data = {"type": keyword, "name": match.group(1).strip(), "column": column}
493
  break
494
-
495
  lab_group_match_data = None
496
  LAB_GROUP_MAP = {'A':'TA','B':'TB','C':'TC','D':'TD','E':'TE','Y':'TY','TA':'TA','TB':'TB','TC':'TC','TD':'TD','TE':'TE','TY':'TY','WC':'WC','EO':'EO','GCI':'GCI','GCO':'GCO','MI':'MI'}
497
  lab_group_match = re.findall(r"([A-Z]+)\s*組", question, re.IGNORECASE)
@@ -528,7 +553,7 @@ class TextToSQLSystem:
528
  self._log(f"🔄 檢測到查詢【{entity_type} '{entity_name}' 在 {year} 年的總業績】意圖,啟用模板。", "INFO")
529
  template_sql = f"WITH JobTotalAmount AS (SELECT JobNo, SUM(LocalAmount) AS TotalAmount FROM (SELECT DISTINCT JobNo, InvoiceCreditNoteNo, LocalAmount FROM TSR53Invoice) GROUP BY JobNo) SELECT SUM(jta.TotalAmount) AS total_revenue FROM TSR53SampleDescription AS sd JOIN JobTotalAmount AS jta ON sd.JobNo = jta.JobNo WHERE {column_name} LIKE '%{entity_name}%' AND strftime('%Y', sd.FirstReportAuthorizedDate) = '{year}';"
530
  return self._finalize_sql(template_sql, f"模板覆寫: 查詢 {entity_type}='{entity_name}' ({year}年) 的總業績")
531
-
532
  if not entity_match_data and any(kw in q_lower for kw in ['業績', '營收', '金額', 'sales', 'revenue']):
533
  year_match, month_match = re.search(r'(\d{4})\s*年?', question), re.search(r'(\d{1,2})\s*月', question)
534
  time_condition, time_log = "", "總"
@@ -571,17 +596,17 @@ class TextToSQLSystem:
571
  # 第二層:常規修正流程 (Fallback Corrections)
572
  # ==============================================================================
573
  self._log("未觸發任何模板,嘗試解析並修正 AI 輸出...", "INFO")
574
-
575
  parsed_sql = parse_sql_from_response(raw_response)
576
  if not parsed_sql:
577
  self._log(f"❌ 未能從模型回應中解析出任何 SQL。原始回應: {raw_response}", "ERROR")
578
  return None, f"無法解析SQL。原始回應:\n{raw_response}"
579
 
580
  self._log(f"📊 解析出的原始 SQL: {parsed_sql}", "DEBUG")
581
-
582
  fixed_sql = " " + parsed_sql.strip() + " "
583
  fixes_applied_fallback = []
584
-
585
  dialect_corrections = {
586
  r'YEAR\s*\(([^)]+)\)': r"strftime('%Y', \1)",
587
  r"(strftime\('%Y',\s*[^)]+\))\s*=\s*(\d{4})": r"\1 = '\2'",
@@ -632,47 +657,52 @@ class TextToSQLSystem:
632
  """使用 FAISS 快速檢索相似問題"""
633
  if self.faiss_index is None or self.dataset is None:
634
  return []
635
-
636
  try:
637
  # 編碼問題
638
  q_embedding = self._encode_texts([question]).numpy().astype('float32')
639
-
640
  # FAISS 搜索
641
  distances, indices = self.faiss_index.search(q_embedding, min(top_k + 2, len(self.dataset)))
642
-
643
  results = []
644
  seen_questions = set()
645
-
646
  for i, idx in enumerate(indices[0]):
647
  if len(results) >= top_k:
648
  break
649
-
650
  # 修復:將 numpy.int64 轉換為 Python int
651
  idx = int(idx) # ← 添加這行轉換
652
-
653
  if idx >= len(self.dataset): # 確保索引有效
654
  continue
655
-
656
  item = self.dataset[idx]
657
- q_content = item['messages'][0]['content']
658
- a_content = item['messages'][1]['content']
659
-
 
 
 
 
 
660
  # 提取純淨問題
661
  clean_q = re.sub(r"以下是一個SQL查詢任務:\s*指令:\s*", "", q_content).strip()
662
  if clean_q in seen_questions:
663
  continue
664
-
665
  seen_questions.add(clean_q)
666
  sql = parse_sql_from_response(a_content) or "無法解析範例SQL"
667
-
668
  results.append({
669
  "similarity": float(distances[0][i]),
670
  "question": clean_q,
671
  "sql": sql
672
  })
673
-
674
  return results
675
-
676
  except Exception as e:
677
  self._log(f"❌ 檢索失敗: {e}", "ERROR")
678
  return []
@@ -684,7 +714,7 @@ class TextToSQLSystem:
684
  建立一個高度結構化、以任務為導向的提示詞,使用清晰的標題分隔符。
685
  """
686
  relevant_tables = self._identify_relevant_tables(user_q)
687
-
688
  # 使用我們新的、更簡單的 schema 格式化函數
689
  schema_str = self._format_relevant_schema(relevant_tables)
690
 
@@ -721,7 +751,7 @@ SQL:
721
  def _generate_fallback_sql(self, prompt: str) -> str:
722
  """當模型不可用時的備用 SQL 生成"""
723
  prompt_lower = prompt.lower()
724
-
725
  # 簡單的關鍵詞匹配生成基本 SQL
726
  if "統計" in prompt or "數量" in prompt or "多少" in prompt:
727
  if "月" in prompt:
@@ -730,13 +760,13 @@ SQL:
730
  return "SELECT applicant, COUNT(*) as count FROM tsr53sampledescription GROUP BY applicant ORDER BY count DESC;"
731
  else:
732
  return "SELECT COUNT(*) as total_count FROM jobtimeline WHERE completed_time IS NOT NULL;"
733
-
734
  elif "金額" in prompt or "總額" in prompt:
735
  return "SELECT SUM(amount) as total_amount FROM tsr53invoice;"
736
-
737
  elif "評級" in prompt or "pass" in prompt_lower or "fail" in prompt_lower:
738
  return "SELECT rating, COUNT(*) as count FROM tsr53sampledescription GROUP BY rating;"
739
-
740
  else:
741
  return "SELECT * FROM jobtimeline LIMIT 10;"
742
 
@@ -745,22 +775,22 @@ SQL:
745
  try:
746
  if not os.path.exists(model_path):
747
  return False
748
-
749
  # 檢查檔案大小(至少應該有幾MB)
750
  file_size = os.path.getsize(model_path)
751
  if file_size < 10 * 1024 * 1024: # 小於 10MB 可能有問題
752
  return False
753
-
754
  # 檢查 GGUF 檔案頭部
755
  with open(model_path, 'rb') as f:
756
  header = f.read(8)
757
  if not header.startswith(b'GGUF'):
758
  return False
759
-
760
  return True
761
  except Exception:
762
  return False
763
-
764
  # in class TextToSQLSystem:
765
 
766
  def process_question(self, question: str) -> Tuple[str, str]:
@@ -769,7 +799,7 @@ SQL:
769
  if question in self.query_cache:
770
  self._log("⚡ 使用緩存結果")
771
  return self.query_cache[question]
772
-
773
  self.log_history = []
774
  self._log(f"⏰ 處理問題: {question}")
775
 
@@ -788,12 +818,12 @@ SQL:
788
 
789
  # 4. **新的核心步驟**: 呼叫決策引擎來生成最終 SQL
790
  final_sql, status_message = self._validate_and_fix_sql(question, response)
791
-
792
  if final_sql:
793
  result = (final_sql, status_message)
794
  else:
795
  result = (status_message, "生成失敗")
796
-
797
  # 緩存結果
798
  self.query_cache[question] = result
799
  return result
@@ -804,10 +834,10 @@ text_to_sql_system = TextToSQLSystem()
804
  def process_query(q: str):
805
  if not q.strip():
806
  return "", "等待輸入", "請輸入問題"
807
-
808
  sql, status = text_to_sql_system.process_question(q)
809
  logs = "\n".join(text_to_sql_system.log_history[-10:]) # 只顯示最後10條日誌
810
-
811
  return sql, status, logs
812
 
813
  # 範例問題
@@ -822,19 +852,19 @@ examples = [
822
  with gr.Blocks(theme=gr.themes.Soft(), title="Text-to-SQL 智能助手") as demo:
823
  gr.Markdown("# ⚡ Text-to-SQL 智能助手")
824
  gr.Markdown("輸入自然語言問題,自動生成SQL查詢語句")
825
-
826
  with gr.Row():
827
  with gr.Column(scale=2):
828
  inp = gr.Textbox(lines=3, label="💬 您的問題", placeholder="例如:2024年每月完成多少份報告?")
829
  btn = gr.Button("🚀 生成 SQL", variant="primary")
830
  status = gr.Textbox(label="狀態", interactive=False)
831
-
832
  with gr.Column(scale=3):
833
  sql_out = gr.Code(label="🤖 生成的 SQL", language="sql", lines=8)
834
-
835
  with gr.Accordion("📋 處理日誌", open=False):
836
  logs = gr.Textbox(lines=8, label="日誌", interactive=False)
837
-
838
  # 範例區
839
  gr.Examples(
840
  examples=examples,
 
46
  """從模型輸出提取 SQL,增強版"""
47
  if not response_text:
48
  return None
49
+
50
  # 清理回應文本
51
  response_text = response_text.strip()
52
+
53
  # 1. 先找 ```sql ... ```
54
  match = re.search(r"```sql\s*\n(.*?)\n```", response_text, re.DOTALL | re.IGNORECASE)
55
  if match:
56
  return match.group(1).strip()
57
+
58
  # 2. 找任何 ``` 包圍的內容
59
  match = re.search(r"```\s*\n?(.*?)\n?```", response_text, re.DOTALL)
60
  if match:
61
  sql_candidate = match.group(1).strip()
62
  if sql_candidate.upper().startswith('SELECT'):
63
  return sql_candidate
64
+
65
  # 3. 找 SQL 語句(更寬鬆的匹配)
66
  match = re.search(r"(SELECT\s+.*?;)", response_text, re.DOTALL | re.IGNORECASE)
67
  if match:
68
  return match.group(1).strip()
69
+
70
  # 4. 找沒有分號的 SQL
71
  match = re.search(r"(SELECT\s+.*?)(?=\n\n|\n```|$|\n[^,\s])", response_text, re.DOTALL | re.IGNORECASE)
72
  if match:
 
74
  if not sql.endswith(';'):
75
  sql += ';'
76
  return sql
77
+
78
  # 5. 如果包含 SELECT,嘗試提取整行
79
  if 'SELECT' in response_text.upper():
80
  lines = response_text.split('\n')
 
84
  if not line.endswith(';'):
85
  line += ';'
86
  return line
87
+
88
  return None
89
 
90
  # ==================== Text-to-SQL 核心類 ====================
 
113
  self._log("✅ 系統初始化完成")
114
  # 載入數據庫結構
115
  self.schema = self._load_schema()
116
+
117
  # 暫時添加:打印 schema 信息
118
  if self.schema:
119
  print("=" * 50)
 
125
  for col in columns[:5]: # 只顯示前5個
126
  print(f" - {col['name']} ({col['type']})")
127
  print("=" * 50)
128
+
129
  # in class TextToSQLSystem:
130
 
131
  def _load_gguf_model(self):
 
137
  filename=GGUF_FILENAME,
138
  repo_type="dataset"
139
  )
140
+
141
  # 使用一組更基礎、更穩定的參數來載入模型
142
  self.llm = Llama(
143
  model_path=model_path,
 
147
  verbose=False, # 設為 False 避免 llama.cpp 本身的日誌干擾
148
  n_gpu_layers=0 # 確認在 CPU 上運行
149
  )
150
+
151
  # 簡單測試模型是否能回應
152
  self.llm("你好", max_tokens=3)
153
  self._log("✅ GGUF 模型載入成功")
154
+
155
  except Exception as e:
156
  self._log(f"❌ GGUF 載入失敗: {e}", "ERROR")
157
  self._log("系統將無法生成 SQL。請檢查模型檔案或 llama-cpp-python 安裝。", "CRITICAL")
158
  self.llm = None
159
+
160
  def _try_gguf_loading(self):
161
  """嘗試載入 GGUF"""
162
  try:
 
165
  filename=GGUF_FILENAME,
166
  repo_type="dataset"
167
  )
168
+
169
  self.llm = Llama(
170
  model_path=model_path,
171
  n_ctx=512,
 
173
  verbose=False,
174
  n_gpu_layers=0
175
  )
176
+
177
  # 測試生成
178
  test_result = self.llm("SELECT", max_tokens=5)
179
  self._log("✅ GGUF 模型載入成功")
180
  return True
181
+
182
  except Exception as e:
183
  self._log(f"GGUF 載入失敗: {e}", "WARNING")
184
  return False
185
+
186
  def _load_transformers_model(self):
187
  """使用 Transformers 載入你的微調模型"""
188
  try:
189
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
190
  import torch
191
+
192
  self._log(f"載入 Transformers 模型: {FINETUNED_MODEL_PATH}")
193
+
194
  # 載入你的微調模型
195
  self.transformers_tokenizer = AutoTokenizer.from_pretrained(FINETUNED_MODEL_PATH)
196
  self.transformers_model = AutoModelForCausalLM.from_pretrained(
 
199
  device_map="cpu", # 強制使用 CPU
200
  trust_remote_code=True # Qwen 模型可能需要
201
  )
202
+
203
  # 創建生成管道
204
  self.generation_pipeline = pipeline(
205
  "text-generation",
 
212
  top_p=0.9,
213
  pad_token_id=self.transformers_tokenizer.eos_token_id
214
  )
215
+
216
  self.llm = "transformers" # 標記使用 transformers
217
  self._log("✅ Transformers 模型載入成功")
218
+
219
  except Exception as e:
220
  self._log(f"❌ Transformers 載入也失敗: {e}", "ERROR")
221
  self.llm = None
222
+
223
  def huggingface_api_call(self, prompt: str) -> str:
224
  """調用 GGUF 模型,並加入詳細的原始輸出日誌"""
225
  if self.llm is None:
226
  self._log("模型未載入,返回 fallback SQL。", "ERROR")
227
  return self._generate_fallback_sql(prompt)
228
+
229
  try:
230
  output = self.llm(
231
  prompt,
 
236
  # --- 將 stop 參數加回來 ---
237
  stop=["```", ";", "\n\n", "</s>"],
238
  )
239
+
240
  self._log(f"🧠 模型原始輸出 (Raw Output): {output}", "DEBUG")
241
+
242
  if output and "choices" in output and len(output["choices"]) > 0:
243
  generated_text = output["choices"][0]["text"]
244
  self._log(f"📝 提取出的生成文本: {generated_text.strip()}", "DEBUG")
 
246
  else:
247
  self._log("❌ 模型的原始輸出格式不正確或為空。", "ERROR")
248
  return ""
249
+
250
  except Exception as e:
251
  self._log(f"❌ 模型生成過程中發生嚴重錯誤: {e}", "CRITICAL")
252
  import traceback
253
  self._log(traceback.format_exc(), "DEBUG")
254
  return ""
255
+
256
  def _load_gguf_model_fallback(self, model_path):
257
  """備用載入方式"""
258
  try:
 
286
  )
287
  with open(schema_path, "r", encoding="utf-8") as f:
288
  schema_data = json.load(f)
289
+
290
  # 添加調試信息
291
  self._log(f"📊 Schema 載入成功,包含 {len(schema_data)} 個表格:")
292
  for table_name, columns in schema_data.items():
 
294
  # 顯示前3個欄位作為範例
295
  sample_cols = [col['name'] for col in columns[:3]]
296
  self._log(f" 範例欄位: {', '.join(sample_cols)}")
297
+
298
  self._log("✅ 數據庫結構載入完成")
299
  return schema_data
300
+
301
  except Exception as e:
302
  self._log(f"❌ 載入 schema 失敗: {e}", "ERROR")
303
  return {}
304
+
305
  # 也可以添加一個方法來檢查生成的 SQL 是否使用了正確的表格和欄位
306
  def _analyze_sql_correctness(self, sql: str) -> Dict:
307
  """分析 SQL 的正確性"""
 
312
  'invalid_columns': [],
313
  'suggestions': []
314
  }
315
+
316
  if not self.schema:
317
  return analysis
318
+
319
  # 提取 SQL 中的表格名稱
320
  table_pattern = r'FROM\s+(\w+)|JOIN\s+(\w+)'
321
  table_matches = re.findall(table_pattern, sql, re.IGNORECASE)
322
  used_tables = [match[0] or match[1] for match in table_matches]
323
+
324
  # 檢查表格是否存在
325
  valid_tables = list(self.schema.keys())
326
  for table in used_tables:
 
332
  for valid_table in valid_tables:
333
  if table.lower() in valid_table.lower() or valid_table.lower() in table.lower():
334
  analysis['suggestions'].append(f"{table} -> {valid_table}")
335
+
336
  # 提取欄位名稱(簡單版本)
337
  column_pattern = r'SELECT\s+(.*?)\s+FROM|WHERE\s+(\w+)\s*[=<>]|GROUP BY\s+(\w+)|ORDER BY\s+(\w+)'
338
  column_matches = re.findall(column_pattern, sql, re.IGNORECASE)
339
+
340
  return analysis
341
 
342
  def _encode_texts(self, texts):
343
  """編碼文本為嵌入向量"""
344
  if isinstance(texts, str):
345
  texts = [texts]
346
+
347
+ inputs = self.embed_tokenizer(texts, padding=True, truncation=True,
348
  return_tensors="pt", max_length=512)
349
  if DEVICE == "cuda":
350
  inputs = {k: v.cuda() for k, v in inputs.items()}
351
+
352
  with torch.no_grad():
353
  outputs = self.embed_model(**inputs)
354
+
355
  # 使用平均池化
356
  embeddings = outputs.last_hidden_state.mean(dim=1)
357
  return embeddings.cpu()
 
360
  """載入數據集並建立 FAISS 索引"""
361
  try:
362
  dataset = load_dataset(DATASET_REPO_ID, data_files="training_data.jsonl", split="train")
363
+
364
+ # 先過濾不完整樣本,避免 messages 長度不足導致索引或檢索報錯
365
+ try:
366
+ original_count = len(dataset)
367
+ except Exception:
368
+ original_count = None
369
+
370
+ dataset = dataset.filter(
371
+ lambda ex: isinstance(ex.get("messages"), list)
372
+ and len(ex["messages"]) >= 2
373
+ and all(
374
+ isinstance(m.get("content"), str) and m.get("content") and m["content"].strip()
375
+ for m in ex["messages"][:2]
376
+ )
377
+ )
378
+
379
+ if original_count is not None:
380
+ self._log(
381
+ f"資料集清理: 原始 {original_count} 筆, 過濾後 {len(dataset)} 筆, 移除 {original_count - len(dataset)} 筆"
382
+ )
383
+
384
+ if len(dataset) == 0:
385
+ self._log("清理後資料集為空,無法建立索引。", "ERROR")
386
+ return None, None
387
+
388
  corpus = [item['messages'][0]['content'] for item in dataset]
389
  self._log(f"正在編碼 {len(corpus)} 個問題...")
390
+
391
  # 批量編碼
392
  embeddings_list = []
393
  batch_size = 32
394
+
395
  for i in range(0, len(corpus), batch_size):
396
  batch_texts = corpus[i:i+batch_size]
397
  batch_embeddings = self._encode_texts(batch_texts)
398
  embeddings_list.append(batch_embeddings)
399
  self._log(f"已編碼 {min(i+batch_size, len(corpus))}/{len(corpus)}")
400
+
401
  all_embeddings = torch.cat(embeddings_list, dim=0).numpy()
402
+
403
  # 建立 FAISS 索引
404
  index = faiss.IndexFlatIP(all_embeddings.shape[1])
405
  index.add(all_embeddings.astype('float32'))
406
+
407
  self._log("✅ 向量索引建立完成")
408
  return dataset, index
409
+
410
  except Exception as e:
411
  self._log(f"❌ 載入數據失敗: {e}", "ERROR")
412
  return None, None
 
415
  """根據實際 Schema 識別相關表格"""
416
  question_lower = question.lower()
417
  relevant_tables = []
418
+
419
  # 根據實際表格的關鍵詞映射
420
  keyword_to_table = {
421
  'TSR53SampleDescription': ['客戶', '買方', '申請', '發票對象', 'customer', 'invoice', 'sample'],
 
425
  'JobEventsLog': ['事件', '操作', '用戶', 'event', 'log', 'user'],
426
  'calendar_days': ['工作日', '假期', 'workday', 'holiday', 'calendar']
427
  }
428
+
429
  for table, keywords in keyword_to_table.items():
430
  if any(keyword in question_lower for keyword in keywords):
431
  relevant_tables.append(table)
432
+
433
  # 預設重要表格
434
  if not relevant_tables:
435
  if any(word in question_lower for word in ['客戶', '買家', '申請', '工作單', '數量']):
436
  return ['TSR53SampleDescription', 'JobsInProgress']
437
  else:
438
  return ['JobTimeline', 'TSR53SampleDescription']
439
+
440
  return relevant_tables[:3] # 最多返回3個相關表格
441
 
442
  # 請將這整個函數複製到您的 TextToSQLSystem class 內部
 
447
  """
448
  if not self.schema:
449
  return "No schema available.\n"
450
+
451
  actual_table_names_map = {name.lower(): name for name in self.schema.keys()}
452
  real_table_names = []
453
  for table in table_names:
 
478
  else:
479
  cols_str.append(f"{col_name} ({col_type})")
480
  formatted += f"Columns: {', '.join(cols_str)}\n\n"
481
+
482
  return formatted.strip()
483
 
484
 
 
495
  返回一個元組 (SQL字符串或None, 狀態消息)。
496
  """
497
  q_lower = question.lower()
498
+
499
  # ==============================================================================
500
  # 第一層:高價值意圖識別與模板覆寫 (Intent Recognition & Templating)
501
  # ==============================================================================
502
+
503
  # --- 預先檢測所有可能的意圖和實體 ---
504
  job_no_match = re.search(r"(?:工單|jobno)\s*'\"?([A-Z]{2,3}\d+)'\"?", question, re.IGNORECASE)
505
+
506
  entity_match_data = None
507
  ENTITY_TO_COLUMN_MAP = {
508
  '申請廠商': 'sd.ApplicantName', '申請方': 'sd.ApplicantName', 'applicant': 'sd.ApplicantName',
 
516
  if match:
517
  entity_match_data = {"type": keyword, "name": match.group(1).strip(), "column": column}
518
  break
519
+
520
  lab_group_match_data = None
521
  LAB_GROUP_MAP = {'A':'TA','B':'TB','C':'TC','D':'TD','E':'TE','Y':'TY','TA':'TA','TB':'TB','TC':'TC','TD':'TD','TE':'TE','TY':'TY','WC':'WC','EO':'EO','GCI':'GCI','GCO':'GCO','MI':'MI'}
522
  lab_group_match = re.findall(r"([A-Z]+)\s*組", question, re.IGNORECASE)
 
553
  self._log(f"🔄 檢測到查詢【{entity_type} '{entity_name}' 在 {year} 年的總業績】意圖,啟用模板。", "INFO")
554
  template_sql = f"WITH JobTotalAmount AS (SELECT JobNo, SUM(LocalAmount) AS TotalAmount FROM (SELECT DISTINCT JobNo, InvoiceCreditNoteNo, LocalAmount FROM TSR53Invoice) GROUP BY JobNo) SELECT SUM(jta.TotalAmount) AS total_revenue FROM TSR53SampleDescription AS sd JOIN JobTotalAmount AS jta ON sd.JobNo = jta.JobNo WHERE {column_name} LIKE '%{entity_name}%' AND strftime('%Y', sd.FirstReportAuthorizedDate) = '{year}';"
555
  return self._finalize_sql(template_sql, f"模板覆寫: 查詢 {entity_type}='{entity_name}' ({year}年) 的總業績")
556
+
557
  if not entity_match_data and any(kw in q_lower for kw in ['業績', '營收', '金額', 'sales', 'revenue']):
558
  year_match, month_match = re.search(r'(\d{4})\s*年?', question), re.search(r'(\d{1,2})\s*月', question)
559
  time_condition, time_log = "", "總"
 
596
  # 第二層:常規修正流程 (Fallback Corrections)
597
  # ==============================================================================
598
  self._log("未觸發任何模板,嘗試解析並修正 AI 輸出...", "INFO")
599
+
600
  parsed_sql = parse_sql_from_response(raw_response)
601
  if not parsed_sql:
602
  self._log(f"❌ 未能從模型回應中解析出任何 SQL。原始回應: {raw_response}", "ERROR")
603
  return None, f"無法解析SQL。原始回應:\n{raw_response}"
604
 
605
  self._log(f"📊 解析出的原始 SQL: {parsed_sql}", "DEBUG")
606
+
607
  fixed_sql = " " + parsed_sql.strip() + " "
608
  fixes_applied_fallback = []
609
+
610
  dialect_corrections = {
611
  r'YEAR\s*\(([^)]+)\)': r"strftime('%Y', \1)",
612
  r"(strftime\('%Y',\s*[^)]+\))\s*=\s*(\d{4})": r"\1 = '\2'",
 
657
  """使用 FAISS 快速檢索相似問題"""
658
  if self.faiss_index is None or self.dataset is None:
659
  return []
660
+
661
  try:
662
  # 編碼問題
663
  q_embedding = self._encode_texts([question]).numpy().astype('float32')
664
+
665
  # FAISS 搜索
666
  distances, indices = self.faiss_index.search(q_embedding, min(top_k + 2, len(self.dataset)))
667
+
668
  results = []
669
  seen_questions = set()
670
+
671
  for i, idx in enumerate(indices[0]):
672
  if len(results) >= top_k:
673
  break
674
+
675
  # 修復:將 numpy.int64 轉換為 Python int
676
  idx = int(idx) # ← 添加這行轉換
677
+
678
  if idx >= len(self.dataset): # 確保索引有效
679
  continue
680
+
681
  item = self.dataset[idx]
682
+ # 防呆:若樣本不完整則跳過
683
+ if not isinstance(item.get('messages'), list) or len(item['messages']) < 2:
684
+ continue
685
+ q_content = (item['messages'][0].get('content') or '').strip()
686
+ a_content = (item['messages'][1].get('content') or '').strip()
687
+ if not q_content or not a_content:
688
+ continue
689
+
690
  # 提取純淨問題
691
  clean_q = re.sub(r"以下是一個SQL查詢任務:\s*指令:\s*", "", q_content).strip()
692
  if clean_q in seen_questions:
693
  continue
694
+
695
  seen_questions.add(clean_q)
696
  sql = parse_sql_from_response(a_content) or "無法解析範例SQL"
697
+
698
  results.append({
699
  "similarity": float(distances[0][i]),
700
  "question": clean_q,
701
  "sql": sql
702
  })
703
+
704
  return results
705
+
706
  except Exception as e:
707
  self._log(f"❌ 檢索失敗: {e}", "ERROR")
708
  return []
 
714
  建立一個高度結構化、以任務為導向的提示詞,使用清晰的標題分隔符。
715
  """
716
  relevant_tables = self._identify_relevant_tables(user_q)
717
+
718
  # 使用我們新的、更簡單的 schema 格式化函數
719
  schema_str = self._format_relevant_schema(relevant_tables)
720
 
 
751
  def _generate_fallback_sql(self, prompt: str) -> str:
752
  """當模型不可用時的備用 SQL 生成"""
753
  prompt_lower = prompt.lower()
754
+
755
  # 簡單的關鍵詞匹配生成基本 SQL
756
  if "統計" in prompt or "數量" in prompt or "多少" in prompt:
757
  if "月" in prompt:
 
760
  return "SELECT applicant, COUNT(*) as count FROM tsr53sampledescription GROUP BY applicant ORDER BY count DESC;"
761
  else:
762
  return "SELECT COUNT(*) as total_count FROM jobtimeline WHERE completed_time IS NOT NULL;"
763
+
764
  elif "金額" in prompt or "總額" in prompt:
765
  return "SELECT SUM(amount) as total_amount FROM tsr53invoice;"
766
+
767
  elif "評級" in prompt or "pass" in prompt_lower or "fail" in prompt_lower:
768
  return "SELECT rating, COUNT(*) as count FROM tsr53sampledescription GROUP BY rating;"
769
+
770
  else:
771
  return "SELECT * FROM jobtimeline LIMIT 10;"
772
 
 
775
  try:
776
  if not os.path.exists(model_path):
777
  return False
778
+
779
  # 檢查檔案大小(至少應該有幾MB)
780
  file_size = os.path.getsize(model_path)
781
  if file_size < 10 * 1024 * 1024: # 小於 10MB 可能有問題
782
  return False
783
+
784
  # 檢查 GGUF 檔案頭部
785
  with open(model_path, 'rb') as f:
786
  header = f.read(8)
787
  if not header.startswith(b'GGUF'):
788
  return False
789
+
790
  return True
791
  except Exception:
792
  return False
793
+
794
  # in class TextToSQLSystem:
795
 
796
  def process_question(self, question: str) -> Tuple[str, str]:
 
799
  if question in self.query_cache:
800
  self._log("⚡ 使用緩存結果")
801
  return self.query_cache[question]
802
+
803
  self.log_history = []
804
  self._log(f"⏰ 處理問題: {question}")
805
 
 
818
 
819
  # 4. **新的核心步驟**: 呼叫決策引擎來生成最終 SQL
820
  final_sql, status_message = self._validate_and_fix_sql(question, response)
821
+
822
  if final_sql:
823
  result = (final_sql, status_message)
824
  else:
825
  result = (status_message, "生成失敗")
826
+
827
  # 緩存結果
828
  self.query_cache[question] = result
829
  return result
 
834
  def process_query(q: str):
835
  if not q.strip():
836
  return "", "等待輸入", "請輸入問題"
837
+
838
  sql, status = text_to_sql_system.process_question(q)
839
  logs = "\n".join(text_to_sql_system.log_history[-10:]) # 只顯示最後10條日誌
840
+
841
  return sql, status, logs
842
 
843
  # 範例問題
 
852
  with gr.Blocks(theme=gr.themes.Soft(), title="Text-to-SQL 智能助手") as demo:
853
  gr.Markdown("# ⚡ Text-to-SQL 智能助手")
854
  gr.Markdown("輸入自然語言問題,自動生成SQL查詢語句")
855
+
856
  with gr.Row():
857
  with gr.Column(scale=2):
858
  inp = gr.Textbox(lines=3, label="💬 您的問題", placeholder="例如:2024年每月完成多少份報告?")
859
  btn = gr.Button("🚀 生成 SQL", variant="primary")
860
  status = gr.Textbox(label="狀態", interactive=False)
861
+
862
  with gr.Column(scale=3):
863
  sql_out = gr.Code(label="🤖 生成的 SQL", language="sql", lines=8)
864
+
865
  with gr.Accordion("📋 處理日誌", open=False):
866
  logs = gr.Textbox(lines=8, label="日誌", interactive=False)
867
+
868
  # 範例區
869
  gr.Examples(
870
  examples=examples,