Paul720810 commited on
Commit
2b8ddf5
·
verified ·
1 Parent(s): f80782b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -146
app.py CHANGED
@@ -1,6 +1,6 @@
1
  # ==============================================================================
2
- # Text-to-SQL 智能助手 - Hugging Face CPU 最终版 v6
3
- # (融合模板引擎 + 强化 Prompt + 修复所有 Bug)
4
  # ==============================================================================
5
  import gradio as gr
6
  import os
@@ -22,57 +22,32 @@ from transformers import AutoModel, AutoTokenizer
22
  import torch.nn.functional as F
23
 
24
  # ==================== 配置參數 ====================
25
- # --- Hugging Face CPU 部署配置 ---
26
  GGUF_FILENAME = "qwen2.5-coder-1.5b-sql-finetuned.q8_0.gguf"
27
- N_GPU_LAYERS = 0 # 在 Hugging Face CPU 环境下设置为 0
28
-
29
  DATASET_REPO_ID = "Paul720810/Text-to-SQL-Softline"
30
  GGUF_REPO_ID = "Paul720810/gguf-models"
31
  FEW_SHOT_EXAMPLES_COUNT = 1
32
- DEVICE = "cuda" if torch.cuda.is_available() and N_GPU_LAYERS != 0 else "cpu"
33
  EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
34
-
35
  TEMP_DIR = tempfile.gettempdir()
36
  os.makedirs(os.path.join(TEMP_DIR, 'text_to_sql_cache'), exist_ok=True)
37
 
38
  print("=" * 60)
39
- print("🤖 Text-to-SQL 智能助手 v6.0 (Hugging Face CPU 版)...")
40
  print(f"🚀 模型: {GGUF_FILENAME}")
41
  print(f"💻 設備: {DEVICE} (GPU Layers: {N_GPU_LAYERS})")
42
  print("=" * 60)
43
 
44
  # ==================== 工具函數 ====================
45
- def get_current_time():
46
- return datetime.now().strftime('%Y-%m-%d %H:%M:%S')
47
-
48
  def format_log(message: str, level: str = "INFO") -> str:
49
- log_entry = f"[{get_current_time()}] [{level.upper()}] {message}"
50
- print(log_entry)
51
- return log_entry
52
-
53
  def parse_sql_from_response(response_text: str) -> Optional[str]:
54
  if not response_text: return None
55
  response_text = response_text.strip()
56
  match = re.search(r"```sql\s*\n(.*?)\n```", response_text, re.DOTALL | re.IGNORECASE)
57
  if match: return match.group(1).strip()
58
- match = re.search(r"```\s*\n?(.*?)\n?```", response_text, re.DOTALL)
59
- if match:
60
- sql_candidate = match.group(1).strip()
61
- if sql_candidate.upper().startswith('SELECT'): return sql_candidate
62
- match = re.search(r"(SELECT\s+.*?;)", response_text, re.DOTALL | re.IGNORECASE)
63
- if match: return match.group(1).strip()
64
- match = re.search(r"(SELECT\s+.*?)(?=\n\n|\n```|$|\n[^,\s])", response_text, re.DOTALL | re.IGNORECASE)
65
- if match:
66
- sql = match.group(1).strip()
67
- if not sql.endswith(';'): sql += ';'
68
- return sql
69
- if 'SELECT' in response_text.upper():
70
- for line in response_text.split('\n'):
71
- line = line.strip()
72
- if line.upper().startswith('SELECT'):
73
- if not line.endswith(';'): line += ';'
74
- return line
75
- return None
76
 
77
  # ==================== Text-to-SQL 核心類 ====================
78
  class TextToSQLSystem:
@@ -84,21 +59,15 @@ class TextToSQLSystem:
84
  self._log(f"載入嵌入模型: {embed_model_name}")
85
  self.embed_tokenizer = AutoTokenizer.from_pretrained(embed_model_name)
86
  self.embed_model = AutoModel.from_pretrained(embed_model_name)
87
- if DEVICE == "cuda":
88
- self.embed_model.to(DEVICE)
89
-
90
  self.schema = self._load_schema()
91
  self.dataset, self.faiss_index = self._load_and_index_dataset()
92
  self._load_gguf_model()
93
  self._log("✅ 系統初始化完成")
94
  except Exception as e:
95
- self._log(f"❌ 系統初始化過程中發生嚴重錯誤: {e}", "CRITICAL")
96
- self._log(traceback.format_exc(), "DEBUG")
97
- self.llm = None
98
-
99
- def _log(self, message: str, level: str = "INFO"):
100
- self.log_history.append(format_log(message, level))
101
 
 
 
102
  def _load_gguf_model(self):
103
  try:
104
  model_path = hf_hub_download(repo_id=GGUF_REPO_ID, filename=GGUF_FILENAME, repo_type="dataset", cache_dir=TEMP_DIR)
@@ -107,54 +76,47 @@ class TextToSQLSystem:
107
  self.llm = Llama(model_path=model_path, n_ctx=2048, n_threads=4, n_batch=512, verbose=False, n_gpu_layers=N_GPU_LAYERS)
108
  self._log("✅ GGUF 模型成功載入")
109
  except Exception as e:
110
- self._log(f"❌ GGUF 載入失敗: {e}", "CRITICAL")
111
- self.llm = None
112
 
113
  def huggingface_api_call(self, prompt: str) -> str:
114
  if self.llm is None: return ""
115
  try:
116
  output = self.llm(prompt, max_tokens=150, temperature=0.1, top_p=0.9, echo=False, stop=["```", ";", "\n\n", "</s>", "###", "Q:"], repeat_penalty=1.1)
117
- generated_text = output["choices"][0]["text"] if output and "choices" in output and len(output["choices"]) > 0 else ""
118
  self._log(f"🧠 模型原始輸出: {generated_text.strip()}", "DEBUG")
119
  return generated_text.strip()
120
  except Exception as e:
121
- self._log(f"❌ 模型生成錯誤: {e}", "CRITICAL")
122
- return ""
123
 
124
  def _load_schema(self) -> Dict:
125
  try:
126
  schema_path = hf_hub_download(repo_id=DATASET_REPO_ID, filename="sqlite_schema_FULL.json", repo_type="dataset")
127
- with open(schema_path, "r", encoding="utf-8") as f:
128
- schema_data = json.load(f)
129
  self._log(f"📊 Schema 載入成功,包含 {len(schema_data)} 個表格。")
130
  return schema_data
131
  except Exception as e:
132
- self._log(f"❌ 載入 schema 失敗: {e}", "ERROR")
133
- return {}
134
 
135
  def _encode_texts(self, texts):
136
  if isinstance(texts, str): texts = [texts]
137
- inputs = self.embed_tokenizer(texts, padding=True, truncation=True, return_tensors="pt", max_length=512).to(DEVICE)
138
  with torch.no_grad():
139
  outputs = self.embed_model(**inputs)
140
- embeddings = outputs.last_hidden_state.mean(dim=1)
141
- return embeddings.cpu()
142
 
143
  def _load_and_index_dataset(self):
144
  try:
145
  dataset = load_dataset(DATASET_REPO_ID, data_files="training_data.jsonl", split="train")
146
  dataset = dataset.filter(lambda ex: isinstance(ex.get("messages"), list) and len(ex["messages"]) >= 2)
147
- corpus = [item['messages'][0]['content'] for item in dataset]
148
  self._log(f"正在編碼 {len(corpus)} 個問題...")
149
  all_embeddings = torch.cat([self._encode_texts(corpus[i:i+32]) for i in range(0, len(corpus), 32)], dim=0).numpy()
150
- index = faiss.IndexFlatIP(all_embeddings.shape[1])
151
  index.add(all_embeddings.astype('float32'))
152
  self._log("✅ 向量索引建立完成")
153
  return dataset, index
154
  except Exception as e:
155
- self._log(f"❌ 載入數據失敗: {e}", "ERROR")
156
- self._log(traceback.format_exc(), "DEBUG")
157
- return None, None
158
 
159
  def _identify_relevant_tables(self, question: str) -> List[str]:
160
  question_lower = question.lower()
@@ -185,31 +147,27 @@ class TextToSQLSystem:
185
  q_embedding = self._encode_texts([question]).numpy().astype('float32')
186
  distances, indices = self.faiss_index.search(q_embedding, min(top_k + 2, len(self.dataset)))
187
  results, seen_questions = [], set()
188
- for i, idx in enumerate(indices[0]):
189
  if len(results) >= top_k: break
190
  idx = int(idx)
191
  if idx >= len(self.dataset): continue
192
  item = self.dataset[idx]
193
  if not (isinstance(item.get('messages'), list) and len(item['messages']) >= 2): continue
194
- q_content = (item['messages'][0].get('content') or '').strip()
195
- a_content = (item['messages'][1].get('content') or '').strip()
196
  if not q_content or not a_content: continue
197
  clean_q = re.sub(r"以下是一個SQL查詢任務:\s*指令:\s*", "", q_content).strip()
198
  if clean_q in seen_questions: continue
199
  seen_questions.add(clean_q)
200
  sql = parse_sql_from_response(a_content) or "無法解析範例SQL"
201
- results.append({"similarity": float(distances[0][i]), "question": clean_q, "sql": sql})
202
  return results
203
  except Exception as e:
204
- self._log(f"❌ 檢索失敗: {e}", "ERROR")
205
- return []
206
 
207
  def _build_prompt(self, user_q: str, examples: List[Dict]) -> str:
208
  schema_str = self._format_relevant_schema(self._identify_relevant_tables(user_q))
209
- example_str = ""
210
- if examples:
211
- example_prompts = [f"Q: {ex['question']}\nA: ```sql\n{ex['sql']}\n```" for ex in examples]
212
- example_str = "\n---\n".join(example_prompts)
213
  prompt = f"""You are an expert SQLite programmer. Your task is to generate a SQL query based on the database schema and a user's question.
214
 
215
  ## Database Schema
@@ -238,12 +196,8 @@ A: ```sql
238
  entity_patterns = [
239
  {'pattern': r"(买家|buyer)\s*(?:id|代號|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.BuyerID', 'type': '买家ID'},
240
  {'pattern': r"(申请方|申请厂商|applicant)\s*(?:id|代號|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.ApplicantID', 'type': '申请方ID'},
241
- {'pattern': r"(付款方|付款厂商|invoiceto)\s*(?:id|代號|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.InvoiceToID', 'type': '付款方ID'},
242
- {'pattern': r"(代理商|agent)\s*(?:id|代號|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.AgentID', 'type': '代理商ID'},
243
  {'pattern': r"(买家|buyer|客戶)\s*'\"?([a-zA-Z0-9&.\s-]+?)(?:\s*的|\s+|$|有)", 'column': 'sd.BuyerName', 'type': '买家'},
244
  {'pattern': r"(申请方|申请厂商|applicant)\s*'\"?([a-zA-Z0-9&.\s-]+?)(?:\s*的|\s+|$|有)", 'column': 'sd.ApplicantName', 'type': '申请方'},
245
- {'pattern': r"(付款方|付款厂商|invoiceto)\s*'\"?([a-zA-Z0-9&.\s-]+?)(?:\s*的|\s+|$|有)", 'column': 'sd.InvoiceToName', 'type': '付款方'},
246
- {'pattern': r"(代理商|agent)\s*'\"?([a-zA-Z0-9&.\s-]+?)(?:\s*的|\s+|$|有)", 'column': 'sd.AgentName', 'type': '代理商'},
247
  {'pattern': r"\b([A-Z]\d{4}[A-Z])\b", 'column': 'sd.ApplicantID', 'type': 'ID'}
248
  ]
249
  for p in entity_patterns:
@@ -253,64 +207,38 @@ A: ```sql
253
  entity_match_data = {"type": p['type'], "name": entity_value.strip().upper(), "column": p['column']}
254
  break
255
 
256
- if any(kw in q_lower for kw in ['報告號碼', '報告清單', '列出報告', 'report number', 'list of reports']):
257
- year_match = re.search(r'(\d{4})\s*年?', question)
258
- month_match = re.search(r'(\d{1,2})\s*月', question)
259
- from_clause = "FROM JobTimeline AS jt"
260
- select_clause = "SELECT jt.JobNo, jt.ReportAuthorization"
261
- where_conditions = ["jt.ReportAuthorization IS NOT NULL"]
262
- log_parts = []
263
- if year_match: where_conditions.append(f"strftime('%Y', jt.ReportAuthorization) = '{year_match.group(1)}'"); log_parts.append(f"{year_match.group(1)}年")
264
- if month_match: where_conditions.append(f"strftime('%m', jt.ReportAuthorization) = '{month_match.group(1).zfill(2)}'"); log_parts.append(f"{month_match.group(1)}月")
265
- if 'fail' in q_lower or '失敗' in q_lower:
266
- if "JOIN TSR53SampleDescription" not in from_clause: from_clause += " JOIN TSR53SampleDescription AS sd ON jt.JobNo = sd.JobNo"
267
- where_conditions.append("sd.OverallRating = 'Fail'"); log_parts.append("Fail")
268
- elif 'pass' in q_lower or '通過' in q_lower:
269
- if "JOIN TSR53SampleDescription" not in from_clause: from_clause += " JOIN TSR53SampleDescription AS sd ON jt.JobNo = sd.JobNo"
270
- where_conditions.append("sd.OverallRating = 'Pass'"); log_parts.append("Pass")
271
  if entity_match_data:
272
- entity_name, column_name = entity_match_data["name"], entity_match_data["column"]
273
- if "JOIN TSR53SampleDescription" not in from_clause: from_clause += " JOIN TSR53SampleDescription AS sd ON jt.JobNo = sd.JobNo"
274
- match_operator = "=" if column_name.endswith("ID") else "LIKE"
275
- entity_value = f"'{entity_name}'" if match_operator == "=" else f"'%{entity_name}%'"
276
- where_conditions.append(f"{column_name} {match_operator} {entity_value}")
277
- log_parts.append(entity_name)
278
- select_clause = "SELECT jt.JobNo, sd.BuyerName, jt.ReportAuthorization"
279
- final_where_clause = "WHERE " + " AND ".join(where_conditions) if where_conditions else ""
280
- time_log = " ".join(log_parts) if log_parts else "全部"
281
- self._log(f"🔄 檢測到查詢【{time_log} 報告列表】意圖,啟用智能模板。", "INFO")
282
- template_sql = f"{select_clause} {from_clause} {final_where_clause} ORDER BY jt.ReportAuthorization DESC;"
283
- return self._finalize_sql(template_sql, f"模板覆寫: {time_log} 報告列表查詢")
284
 
285
  if '報告' in q_lower and any(kw in q_lower for kw in ['幾份', '多少', '數量', '總數']) and not entity_match_data:
286
  year_match = re.search(r'(\d{4})\s*年?', question)
287
- time_condition, time_log = "", ""
288
- if year_match:
289
- time_condition = f"WHERE ReportAuthorization IS NOT NULL AND strftime('%Y', ReportAuthorization) = '{year_match.group(1)}'"
290
- time_log = f"{year_match.group(1)}"
291
- else:
292
- time_condition = "WHERE ReportAuthorization IS NOT NULL"
293
- self._log(f"🔄 檢測到查詢【{time_log}全局報告總數】意圖,啟用模板。", "INFO")
294
- template_sql = f"SELECT COUNT(DISTINCT JobNo) AS report_count FROM JobTimeline {time_condition};"
295
- return self._finalize_sql(template_sql, f"模板覆寫: {time_log}全局報告總數查詢")
296
-
297
- self._log("未觸發任何模板,嘗試解析並修正 AI 輸出...", "INFO")
298
  parsed_sql = parse_sql_from_response(raw_response)
299
- if not parsed_sql:
300
- return None, f"無法解析SQL。原始回應:\n{raw_response}"
301
- fixed_sql = " " + parsed_sql.strip() + " "
302
- fixes_applied_fallback = []
303
- dialect_corrections = {r'YEAR\s*\(([^)]+)\)': r"strftime('%Y', \1)"}
304
- for p, r in dialect_corrections.items():
305
- if re.search(p, fixed_sql, re.IGNORECASE):
306
- fixed_sql = re.sub(p, r, fixed_sql, flags=re.IGNORECASE); fixes_applied_fallback.append(f"修正方言: {p}")
307
- schema_corrections = {'TSR53Report':'TSR53SampleDescription', 'TSR53InvoiceReportNo':'JobNo', 'Status':'OverallRating'}
308
- for w, c in schema_corrections.items():
309
- pattern = r'\b' + re.escape(w) + r'\b'
310
- if re.search(pattern, fixed_sql, re.IGNORECASE):
311
- fixed_sql = re.sub(pattern, c, fixed_sql, flags=re.IGNORECASE); fixes_applied_fallback.append(f"映射 Schema: '{w}' -> '{c}'")
312
- log_msg = "AI 生成並成功修正" if fixes_applied_fallback else "AI 生成且無需修正"
313
- return self._finalize_sql(fixed_sql, log_msg)
314
 
315
  def process_question(self, question: str) -> Tuple[str, str]:
316
  if question in self.query_cache: self._log("⚡ 使用緩存結果"); return self.query_cache[question]
@@ -323,42 +251,28 @@ A: ```sql
323
  self._log("🧠 開始生成 AI 回應...")
324
  response = self.huggingface_api_call(prompt)
325
  final_sql, status_message = self._validate_and_fix_sql(question, response)
326
- if not final_sql: result = (status_message, "生成失敗")
327
- else: result = (final_sql, status_message)
328
  self.query_cache[question] = result
329
  return result
330
 
331
  # ==================== Gradio 介面 ====================
332
  text_to_sql_system = TextToSQLSystem()
333
-
334
  def process_query(q: str):
335
  if not q.strip(): return "", "等待輸入", "請輸入問題"
336
- if text_to_sql_system.llm is None:
337
- return "模型未能成功載入,請檢查終端日誌。", "模型載入失敗", "\n".join(text_to_sql_system.log_history)
338
  sql, status = text_to_sql_system.process_question(q)
339
- logs = "\n".join(text_to_sql_system.log_history[-15:])
340
- return sql, status, logs
341
 
342
- examples = [
343
- "2024年7月買家 Gap 的 Fail 報告號碼",
344
- "列出2023年所有失败的报告",
345
- "找出总金额最高的10个工作单",
346
- "哪些客户的工作单数量最多?",
347
- "A組2024年完成了多少個測試項目?",
348
- "2024年每月完成多少份報告?"
349
- ]
350
  with gr.Blocks(theme=gr.themes.Soft(), title="Text-to-SQL 智能助手") as demo:
351
  gr.Markdown("# ⚡ Text-to-SQL 智能助手 (终极版)")
352
- gr.Markdown("融合了模板引擎和 GGUF 模型的强大版本")
353
  with gr.Row():
354
  with gr.Column(scale=2):
355
  inp = gr.Textbox(lines=3, label="💬 您的問題", placeholder="例如:2024年每月完成多少份報告?")
356
  btn = gr.Button("🚀 生成 SQL", variant="primary")
357
  status = gr.Textbox(label="狀態", interactive=False)
358
- with gr.Column(scale=3):
359
- sql_out = gr.Code(label="🤖 生成的 SQL", language="sql", lines=8)
360
- with gr.Accordion("📋 處理日誌", open=False):
361
- logs = gr.Textbox(lines=10, label="日誌", interactive=False)
362
  gr.Examples(examples=examples, inputs=inp, label="💡 點擊試用範例問題")
363
  btn.click(process_query, inputs=[inp], outputs=[sql_out, status, logs])
364
  inp.submit(process_query, inputs=[inp], outputs=[sql_out, status, logs])
 
1
  # ==============================================================================
2
+ # Text-to-SQL 智能助手 - Hugging Face CPU 最终版 v7
3
+ # (修复所有已知 Bug)
4
  # ==============================================================================
5
  import gradio as gr
6
  import os
 
22
  import torch.nn.functional as F
23
 
24
  # ==================== 配置參數 ====================
 
25
  GGUF_FILENAME = "qwen2.5-coder-1.5b-sql-finetuned.q8_0.gguf"
26
+ N_GPU_LAYERS = 0 # CPU 环境
 
27
  DATASET_REPO_ID = "Paul720810/Text-to-SQL-Softline"
28
  GGUF_REPO_ID = "Paul720810/gguf-models"
29
  FEW_SHOT_EXAMPLES_COUNT = 1
30
+ DEVICE = "cpu"
31
  EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
 
32
  TEMP_DIR = tempfile.gettempdir()
33
  os.makedirs(os.path.join(TEMP_DIR, 'text_to_sql_cache'), exist_ok=True)
34
 
35
  print("=" * 60)
36
+ print("🤖 Text-to-SQL 智能助手 v7.0 (Hugging Face CPU 版)...")
37
  print(f"🚀 模型: {GGUF_FILENAME}")
38
  print(f"💻 設備: {DEVICE} (GPU Layers: {N_GPU_LAYERS})")
39
  print("=" * 60)
40
 
41
  # ==================== 工具函數 ====================
42
+ def get_current_time(): return datetime.now().strftime('%Y-%m-%d %H:%M:%S')
 
 
43
  def format_log(message: str, level: str = "INFO") -> str:
44
+ log_entry = f"[{get_current_time()}] [{level.upper()}] {message}"; print(log_entry); return log_entry
 
 
 
45
  def parse_sql_from_response(response_text: str) -> Optional[str]:
46
  if not response_text: return None
47
  response_text = response_text.strip()
48
  match = re.search(r"```sql\s*\n(.*?)\n```", response_text, re.DOTALL | re.IGNORECASE)
49
  if match: return match.group(1).strip()
50
+ return response_text if response_text.upper().startswith("SELECT") else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  # ==================== Text-to-SQL 核心類 ====================
53
  class TextToSQLSystem:
 
59
  self._log(f"載入嵌入模型: {embed_model_name}")
60
  self.embed_tokenizer = AutoTokenizer.from_pretrained(embed_model_name)
61
  self.embed_model = AutoModel.from_pretrained(embed_model_name)
 
 
 
62
  self.schema = self._load_schema()
63
  self.dataset, self.faiss_index = self._load_and_index_dataset()
64
  self._load_gguf_model()
65
  self._log("✅ 系統初始化完成")
66
  except Exception as e:
67
+ self._log(f"❌ 系統初始化過程中發生嚴重錯誤: {e}", "CRITICAL"); self._log(traceback.format_exc(), "DEBUG"); self.llm = None
 
 
 
 
 
68
 
69
+ def _log(self, message: str, level: str = "INFO"): self.log_history.append(format_log(message, level))
70
+
71
  def _load_gguf_model(self):
72
  try:
73
  model_path = hf_hub_download(repo_id=GGUF_REPO_ID, filename=GGUF_FILENAME, repo_type="dataset", cache_dir=TEMP_DIR)
 
76
  self.llm = Llama(model_path=model_path, n_ctx=2048, n_threads=4, n_batch=512, verbose=False, n_gpu_layers=N_GPU_LAYERS)
77
  self._log("✅ GGUF 模型成功載入")
78
  except Exception as e:
79
+ self._log(f"❌ GGUF 載入失敗: {e}", "CRITICAL"); self.llm = None
 
80
 
81
  def huggingface_api_call(self, prompt: str) -> str:
82
  if self.llm is None: return ""
83
  try:
84
  output = self.llm(prompt, max_tokens=150, temperature=0.1, top_p=0.9, echo=False, stop=["```", ";", "\n\n", "</s>", "###", "Q:"], repeat_penalty=1.1)
85
+ generated_text = output["choices"][0]["text"] if output and "choices" in output and len(output["choices"]) > 0 else ""
86
  self._log(f"🧠 模型原始輸出: {generated_text.strip()}", "DEBUG")
87
  return generated_text.strip()
88
  except Exception as e:
89
+ self._log(f"❌ 模型生成錯誤: {e}", "CRITICAL"); return ""
 
90
 
91
  def _load_schema(self) -> Dict:
92
  try:
93
  schema_path = hf_hub_download(repo_id=DATASET_REPO_ID, filename="sqlite_schema_FULL.json", repo_type="dataset")
94
+ with open(schema_path, "r", encoding="utf-8") as f: schema_data = json.load(f)
 
95
  self._log(f"📊 Schema 載入成功,包含 {len(schema_data)} 個表格。")
96
  return schema_data
97
  except Exception as e:
98
+ self._log(f"❌ 載入 schema 失敗: {e}", "ERROR"); return {}
 
99
 
100
  def _encode_texts(self, texts):
101
  if isinstance(texts, str): texts = [texts]
102
+ inputs = self.embed_tokenizer(texts, padding=True, truncation=True, return_tensors="pt", max_length=512)
103
  with torch.no_grad():
104
  outputs = self.embed_model(**inputs)
105
+ return outputs.last_hidden_state.mean(dim=1).cpu()
 
106
 
107
  def _load_and_index_dataset(self):
108
  try:
109
  dataset = load_dataset(DATASET_REPO_ID, data_files="training_data.jsonl", split="train")
110
  dataset = dataset.filter(lambda ex: isinstance(ex.get("messages"), list) and len(ex["messages"]) >= 2)
111
+ corpus = [item['messages'][0]['content'] for item in dataset]
112
  self._log(f"正在編碼 {len(corpus)} 個問題...")
113
  all_embeddings = torch.cat([self._encode_texts(corpus[i:i+32]) for i in range(0, len(corpus), 32)], dim=0).numpy()
114
+ index = faiss.IndexFlatIP(all_embeddings.shape[1])
115
  index.add(all_embeddings.astype('float32'))
116
  self._log("✅ 向量索引建立完成")
117
  return dataset, index
118
  except Exception as e:
119
+ self._log(f"❌ 載入數據失敗: {e}", "ERROR"); self._log(traceback.format_exc(), "DEBUG"); return None, None
 
 
120
 
121
  def _identify_relevant_tables(self, question: str) -> List[str]:
122
  question_lower = question.lower()
 
147
  q_embedding = self._encode_texts([question]).numpy().astype('float32')
148
  distances, indices = self.faiss_index.search(q_embedding, min(top_k + 2, len(self.dataset)))
149
  results, seen_questions = [], set()
150
+ for i, idx in enumerate(indices[0]):
151
  if len(results) >= top_k: break
152
  idx = int(idx)
153
  if idx >= len(self.dataset): continue
154
  item = self.dataset[idx]
155
  if not (isinstance(item.get('messages'), list) and len(item['messages']) >= 2): continue
156
+ q_content = (item['messages'][0].get('content') or '').strip()
157
+ a_content = (item['messages'][1].get('content') or '').strip()
158
  if not q_content or not a_content: continue
159
  clean_q = re.sub(r"以下是一個SQL查詢任務:\s*指令:\s*", "", q_content).strip()
160
  if clean_q in seen_questions: continue
161
  seen_questions.add(clean_q)
162
  sql = parse_sql_from_response(a_content) or "無法解析範例SQL"
163
+ results.append({"similarity": float(distances[0][i]), "question": clean_q, "sql": sql})
164
  return results
165
  except Exception as e:
166
+ self._log(f"❌ 檢索失敗: {e}", "ERROR"); return []
 
167
 
168
  def _build_prompt(self, user_q: str, examples: List[Dict]) -> str:
169
  schema_str = self._format_relevant_schema(self._identify_relevant_tables(user_q))
170
+ example_str = "\n---\n".join([f"Q: {ex['question']}\nA: ```sql\n{ex['sql']}\n```" for ex in examples]) if examples else "No examples provided."
 
 
 
171
  prompt = f"""You are an expert SQLite programmer. Your task is to generate a SQL query based on the database schema and a user's question.
172
 
173
  ## Database Schema
 
196
  entity_patterns = [
197
  {'pattern': r"(买家|buyer)\s*(?:id|代號|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.BuyerID', 'type': '买家ID'},
198
  {'pattern': r"(申请方|申请厂商|applicant)\s*(?:id|代號|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.ApplicantID', 'type': '申请方ID'},
 
 
199
  {'pattern': r"(买家|buyer|客戶)\s*'\"?([a-zA-Z0-9&.\s-]+?)(?:\s*的|\s+|$|有)", 'column': 'sd.BuyerName', 'type': '买家'},
200
  {'pattern': r"(申请方|申请厂商|applicant)\s*'\"?([a-zA-Z0-9&.\s-]+?)(?:\s*的|\s+|$|有)", 'column': 'sd.ApplicantName', 'type': '申请方'},
 
 
201
  {'pattern': r"\b([A-Z]\d{4}[A-Z])\b", 'column': 'sd.ApplicantID', 'type': 'ID'}
202
  ]
203
  for p in entity_patterns:
 
207
  entity_match_data = {"type": p['type'], "name": entity_value.strip().upper(), "column": p['column']}
208
  break
209
 
210
+ if any(kw in q_lower for kw in ['報告號碼', '報告清單', '列出報告']):
211
+ year_match, month_match = re.search(r'(\d{4})\s*年?', question), re.search(r'(\d{1,2})\s*月', question)
212
+ from_clause = "FROM JobTimeline AS jt JOIN TSR53SampleDescription AS sd ON jt.JobNo = sd.JobNo"
213
+ select_clause = "SELECT jt.JobNo, sd.BuyerName, jt.ReportAuthorization"
214
+ where, log_parts = ["jt.ReportAuthorization IS NOT NULL"], []
215
+ if year_match: where.append(f"strftime('%Y', jt.ReportAuthorization) = '{year_match.group(1)}'"); log_parts.append(f"{year_match.group(1)}年")
216
+ if month_match: where.append(f"strftime('%m', jt.ReportAuthorization) = '{month_match.group(1).zfill(2)}'"); log_parts.append(f"{month_match.group(1)}月")
217
+ if 'fail' in q_lower or '失敗' in q_lower: where.append("sd.OverallRating = 'Fail'"); log_parts.append("Fail")
218
+ elif 'pass' in q_lower or '通過' in q_lower: where.append("sd.OverallRating = 'Pass'"); log_parts.append("Pass")
 
 
 
 
 
 
219
  if entity_match_data:
220
+ op = "=" if entity_match_data["column"].endswith("ID") else "LIKE"
221
+ val = f"'{entity_match_data['name']}'" if op == "=" else f"'%{entity_match_data['name']}%'"
222
+ where.append(f"{entity_match_data['column']} {op} {val}"); log_parts.append(entity_match_data["name"])
223
+ final_where = "WHERE " + " AND ".join(where) if where else ""
224
+ log_msg = " ".join(log_parts) if log_parts else "全部"
225
+ self._log(f"🔄 檢測到【{log_msg} 報告列表】意圖,啟用模板。", "INFO")
226
+ template_sql = f"{select_clause} {from_clause} {final_where} ORDER BY jt.ReportAuthorization DESC;"
227
+ return self._finalize_sql(template_sql, f"模板覆寫: {log_msg}")
 
 
 
 
228
 
229
  if '報告' in q_lower and any(kw in q_lower for kw in ['幾份', '多少', '數量', '總數']) and not entity_match_data:
230
  year_match = re.search(r'(\d{4})\s*年?', question)
231
+ time_cond = f"WHERE strftime('%Y', ReportAuthorization) = '{year_match.group(1)}'" if year_match else ""
232
+ time_log = f"{year_match.group(1)}年" if year_match else "总"
233
+ self._log(f"🔄 檢測到【{time_log}报告总数】意图,启用模板。", "INFO")
234
+ template_sql = f"SELECT COUNT(DISTINCT JobNo) AS report_count FROM JobTimeline {time_cond};"
235
+ return self._finalize_sql(template_sql, f"模板覆寫: {time_log}")
236
+
237
+ # Fallback to AI
238
+ self._log("未觸發模板,由 AI 生成...", "INFO")
 
 
 
239
  parsed_sql = parse_sql_from_response(raw_response)
240
+ if not parsed_sql: return None, f"無法解析SQL: {raw_response}"
241
+ return self._finalize_sql(parsed_sql, "AI 生成")
 
 
 
 
 
 
 
 
 
 
 
 
 
242
 
243
  def process_question(self, question: str) -> Tuple[str, str]:
244
  if question in self.query_cache: self._log("⚡ 使用緩存結果"); return self.query_cache[question]
 
251
  self._log("🧠 開始生成 AI 回應...")
252
  response = self.huggingface_api_call(prompt)
253
  final_sql, status_message = self._validate_and_fix_sql(question, response)
254
+ result = (final_sql, status_message) if final_sql else (status_message, "生成失敗")
 
255
  self.query_cache[question] = result
256
  return result
257
 
258
  # ==================== Gradio 介面 ====================
259
  text_to_sql_system = TextToSQLSystem()
 
260
  def process_query(q: str):
261
  if not q.strip(): return "", "等待輸入", "請輸入問題"
262
+ if text_to_sql_system.llm is None: return "模型未能成功載入...", "模型載入失敗", "\n".join(text_to_sql_system.log_history)
 
263
  sql, status = text_to_sql_system.process_question(q)
264
+ return sql, status, "\n".join(text_to_sql_system.log_history[-15:])
 
265
 
266
+ examples = ["2024年7月買家 Gap 的 Fail 報告號碼", "列出2023年所有失败的报告", "找出总金额最高的10个工作单", "哪些客户的工作单数量最多?", "A組2024年完成了多少個測試項目?", "2024年每月完成多少份報告?"]
 
 
 
 
 
 
 
267
  with gr.Blocks(theme=gr.themes.Soft(), title="Text-to-SQL 智能助手") as demo:
268
  gr.Markdown("# ⚡ Text-to-SQL 智能助手 (终极版)")
 
269
  with gr.Row():
270
  with gr.Column(scale=2):
271
  inp = gr.Textbox(lines=3, label="💬 您的問題", placeholder="例如:2024年每月完成多少份報告?")
272
  btn = gr.Button("🚀 生成 SQL", variant="primary")
273
  status = gr.Textbox(label="狀態", interactive=False)
274
+ with gr.Column(scale=3): sql_out = gr.Code(label="🤖 生成的 SQL", language="sql", lines=8)
275
+ with gr.Accordion("📋 處理日誌", open=False): logs = gr.Textbox(lines=10, label="日誌", interactive=False)
 
 
276
  gr.Examples(examples=examples, inputs=inp, label="💡 點擊試用範例問題")
277
  btn.click(process_query, inputs=[inp], outputs=[sql_out, status, logs])
278
  inp.submit(process_query, inputs=[inp], outputs=[sql_out, status, logs])