Paul720810 committed on
Commit
5f5667d
·
verified ·
1 Parent(s): a487546

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +726 -155
app.py CHANGED
@@ -1,53 +1,92 @@
1
- # ==============================================================================
2
- # Text-to-SQL 智能助手 - Hugging Face CPU 最终版 v7
3
- # (修复所有已知 Bug)
4
- # ==============================================================================
5
  import gradio as gr
6
  import os
7
  import re
8
  import json
9
  import torch
10
  import numpy as np
11
- import gc
12
- import tempfile
13
  from datetime import datetime
14
  from datasets import load_dataset
15
  from huggingface_hub import hf_hub_download
16
  from llama_cpp import Llama
17
  from typing import List, Dict, Tuple, Optional
18
  import faiss
19
- import traceback
20
 
 
21
  from transformers import AutoModel, AutoTokenizer
22
  import torch.nn.functional as F
23
 
24
- # ==================== 配置參數 ====================
25
- GGUF_FILENAME = "qwen2.5-coder-1.5b-sql-finetuned.q8_0.gguf"
26
- N_GPU_LAYERS = 0 # CPU 环境
27
  DATASET_REPO_ID = "Paul720810/Text-to-SQL-Softline"
28
  GGUF_REPO_ID = "Paul720810/gguf-models"
 
 
 
 
 
 
29
  FEW_SHOT_EXAMPLES_COUNT = 1
30
- DEVICE = "cpu"
31
  EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
32
- TEMP_DIR = tempfile.gettempdir()
33
- os.makedirs(os.path.join(TEMP_DIR, 'text_to_sql_cache'), exist_ok=True)
34
 
35
  print("=" * 60)
36
- print("🤖 Text-to-SQL 智能助手 v7.0 (Hugging Face CPU 版)...")
37
- print(f"🚀 模型: {GGUF_FILENAME}")
38
- print(f"💻 設備: {DEVICE} (GPU Layers: {N_GPU_LAYERS})")
 
39
  print("=" * 60)
40
 
41
  # ==================== 工具函數 ====================
42
- def get_current_time(): return datetime.now().strftime('%Y-%m-%d %H:%M:%S')
 
 
43
  def format_log(message: str, level: str = "INFO") -> str:
44
- log_entry = f"[{get_current_time()}] [{level.upper()}] {message}"; print(log_entry); return log_entry
 
45
  def parse_sql_from_response(response_text: str) -> Optional[str]:
46
- if not response_text: return None
 
 
 
 
47
  response_text = response_text.strip()
 
 
48
  match = re.search(r"```sql\s*\n(.*?)\n```", response_text, re.DOTALL | re.IGNORECASE)
49
- if match: return match.group(1).strip()
50
- return response_text if response_text.upper().startswith("SELECT") else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  # ==================== Text-to-SQL 核心類 ====================
53
  class TextToSQLSystem:
@@ -55,227 +94,759 @@ class TextToSQLSystem:
55
  self.log_history = []
56
  self._log("初始化系統...")
57
  self.query_cache = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  try:
59
- self._log(f"載入嵌入模型: {embed_model_name}")
60
- self.embed_tokenizer = AutoTokenizer.from_pretrained(embed_model_name)
61
- self.embed_model = AutoModel.from_pretrained(embed_model_name)
62
- self.schema = self._load_schema()
63
- self.dataset, self.faiss_index = self._load_and_index_dataset()
64
- self._load_gguf_model()
65
- self._log("✅ 系統初始化完成")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  except Exception as e:
67
- self._log(f"❌ 系統初始化過程中發生嚴重錯誤: {e}", "CRITICAL"); self._log(traceback.format_exc(), "DEBUG"); self.llm = None
 
 
68
 
69
- def _log(self, message: str, level: str = "INFO"): self.log_history.append(format_log(message, level))
70
-
71
- def _load_gguf_model(self):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  try:
73
- model_path = hf_hub_download(repo_id=GGUF_REPO_ID, filename=GGUF_FILENAME, repo_type="dataset", cache_dir=TEMP_DIR)
74
- self._log(f"模型路徑: {model_path}")
75
- self._log(f"載入 GGUF 模型 (GPU Layers: {N_GPU_LAYERS})...")
76
- self.llm = Llama(model_path=model_path, n_ctx=2048, n_threads=4, n_batch=512, verbose=False, n_gpu_layers=N_GPU_LAYERS)
77
- self._log("✅ GGUF 模型成功載入")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  except Exception as e:
79
- self._log(f"❌ GGUF 載入失敗: {e}", "CRITICAL"); self.llm = None
 
80
 
81
  def huggingface_api_call(self, prompt: str) -> str:
82
- if self.llm is None: return ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  try:
84
- output = self.llm(prompt, max_tokens=150, temperature=0.1, top_p=0.9, echo=False, stop=["```", ";", "\n\n", "</s>", "###", "Q:"], repeat_penalty=1.1)
85
- generated_text = output["choices"]["text"] if output and "choices" in output and len(output["choices"]) > 0 else ""
86
- self._log(f"🧠 模型原始輸出: {generated_text.strip()}", "DEBUG")
87
- return generated_text.strip()
 
 
 
 
 
 
 
 
88
  except Exception as e:
89
- self._log(f"❌ 模型生成錯誤: {e}", "CRITICAL"); return ""
90
-
 
 
 
 
 
91
  def _load_schema(self) -> Dict:
 
92
  try:
93
- schema_path = hf_hub_download(repo_id=DATASET_REPO_ID, filename="sqlite_schema_FULL.json", repo_type="dataset")
94
- with open(schema_path, "r", encoding="utf-8") as f: schema_data = json.load(f)
95
- self._log(f"📊 Schema 載入成功,包含 {len(schema_data)} 個表格。")
96
- return schema_data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  except Exception as e:
98
- self._log(f"❌ 載入 schema 失敗: {e}", "ERROR"); return {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
  def _encode_texts(self, texts):
101
- if isinstance(texts, str): texts = [texts]
102
- inputs = self.embed_tokenizer(texts, padding=True, truncation=True, return_tensors="pt", max_length=512)
 
 
 
 
 
 
 
103
  with torch.no_grad():
104
  outputs = self.embed_model(**inputs)
105
- return outputs.last_hidden_state.mean(dim=1).cpu()
 
 
 
106
 
107
  def _load_and_index_dataset(self):
 
108
  try:
109
  dataset = load_dataset(DATASET_REPO_ID, data_files="training_data.jsonl", split="train")
110
- dataset = dataset.filter(lambda ex: isinstance(ex.get("messages"), list) and len(ex["messages"]) >= 2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  corpus = [item['messages'][0]['content'] for item in dataset]
112
  self._log(f"正在編碼 {len(corpus)} 個問題...")
113
- all_embeddings = torch.cat([self._encode_texts(corpus[i:i+32]) for i in range(0, len(corpus), 32)], dim=0).numpy()
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  index = faiss.IndexFlatIP(all_embeddings.shape[1])
115
  index.add(all_embeddings.astype('float32'))
 
116
  self._log("✅ 向量索引建立完成")
117
  return dataset, index
 
118
  except Exception as e:
119
- self._log(f"❌ 載入數據失敗: {e}", "ERROR"); self._log(traceback.format_exc(), "DEBUG"); return None, None
120
-
 
121
  def _identify_relevant_tables(self, question: str) -> List[str]:
 
122
  question_lower = question.lower()
123
  relevant_tables = []
124
- keyword_to_table = {'TSR53SampleDescription': ['客戶', '買方', '申請', '發票對象'], 'JobsInProgress': ['進行中', '買家', '申請方'], 'JobTimeline': ['時間', '完成', '創建', '實驗室'], 'TSR53Invoice': ['發票', '金額', '費用']}
 
 
 
 
 
 
 
 
 
 
125
  for table, keywords in keyword_to_table.items():
126
- if any(keyword in question_lower for keyword in keywords): relevant_tables.append(table)
127
- if not relevant_tables: return ['TSR53SampleDescription', 'JobsInProgress', 'JobTimeline']
128
- return relevant_tables[:3]
 
 
 
 
 
 
 
 
 
 
129
 
130
  def _format_relevant_schema(self, table_names: List[str]) -> str:
131
- if not self.schema: return "No schema available.\n"
132
- formatted = ""
 
 
 
 
 
 
133
  for table in table_names:
 
 
 
 
 
 
 
 
 
 
 
 
134
  if table in self.schema:
 
135
  formatted += f"Table: {table}\n"
136
  cols_str = []
 
137
  for col in self.schema[table][:10]:
138
- col_name, col_type, col_desc = col['name'], col['type'], col.get('description', '').replace('\n', ' ')
139
- if col_desc: cols_str.append(f"{col_name} ({col_type}, {col_desc})")
140
- else: cols_str.append(f"{col_name} ({col_type})")
 
 
 
 
 
141
  formatted += f"Columns: {', '.join(cols_str)}\n\n"
142
- return formatted.strip()
143
-
144
- def find_most_similar(self, question: str, top_k: int) -> List[Dict]:
145
- if self.faiss_index is None: return []
146
- try:
147
- q_embedding = self._encode_texts([question]).numpy().astype('float32')
148
- distances, indices = self.faiss_index.search(q_embedding, min(top_k + 2, len(self.dataset)))
149
- results, seen_questions = [], set()
150
- for i, idx in enumerate(indices):
151
- if len(results) >= top_k: break
152
- idx = int(idx)
153
- if idx >= len(self.dataset): continue
154
- item = self.dataset[idx]
155
- if not (isinstance(item.get('messages'), list) and len(item['messages']) >= 2): continue
156
- q_content = (item['messages'][0].get('content') or '').strip()
157
- a_content = (item['messages'][1].get('content') or '').strip()
158
- if not q_content or not a_content: continue
159
- clean_q = re.sub(r"以下是一個SQL查詢任務:\s*指令:\s*", "", q_content).strip()
160
- if clean_q in seen_questions: continue
161
- seen_questions.add(clean_q)
162
- sql = parse_sql_from_response(a_content) or "無法解析範例SQL"
163
- results.append({"similarity": float(distances[i]), "question": clean_q, "sql": sql})
164
- return results
165
- except Exception as e:
166
- self._log(f"❌ 檢索失敗: {e}", "ERROR"); return []
167
-
168
- def _build_prompt(self, user_q: str, examples: List[Dict]) -> str:
169
- schema_str = self._format_relevant_schema(self._identify_relevant_tables(user_q))
170
- example_str = "\n---\n".join([f"Q: {ex['question']}\nA: ```sql\n{ex['sql']}\n```" for ex in examples]) if examples else "No examples provided."
171
- prompt = f"""You are an expert SQLite programmer. Your task is to generate a SQL query based on the database schema and a user's question.
172
-
173
- ## Database Schema
174
- {schema_str.strip()}
175
 
176
- ## Examples
177
- {example_str.strip()}
178
 
179
- ## Task
180
- Based on the schema and examples, generate the SQL query for the following question.
181
- Q: {user_q}
182
- A: ```sql
183
- """
184
- return prompt
185
-
186
- def _finalize_sql(self, sql: str, log_message: str) -> Tuple[str, str]:
187
- final_sql = re.sub(r'\s+', ' ', sql.strip())
188
- if not final_sql.endswith(';'): final_sql += ';'
189
- self._log(f"✅ SQL 已生成 ({log_message})", "INFO")
190
- self._log(f" - 最終 SQL: {final_sql}", "DEBUG")
191
- return final_sql, "生成成功"
192
 
193
  def _validate_and_fix_sql(self, question: str, raw_response: str) -> Tuple[Optional[str], str]:
 
 
 
 
 
 
194
  q_lower = question.lower()
 
 
 
 
 
 
195
  entity_match_data = None
 
 
196
  entity_patterns = [
 
197
  {'pattern': r"(买家|buyer)\s*(?:id|代號|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.BuyerID', 'type': '买家ID'},
198
  {'pattern': r"(申请方|申请厂商|applicant)\s*(?:id|代號|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.ApplicantID', 'type': '申请方ID'},
 
 
 
 
199
  {'pattern': r"(买家|buyer|客戶)\s*'\"?([a-zA-Z0-9&.\s-]+?)(?:\s*的|\s+|$|有)", 'column': 'sd.BuyerName', 'type': '买家'},
200
  {'pattern': r"(申请方|申请厂商|applicant)\s*'\"?([a-zA-Z0-9&.\s-]+?)(?:\s*的|\s+|$|有)", 'column': 'sd.ApplicantName', 'type': '申请方'},
 
 
 
 
201
  {'pattern': r"\b([A-Z]\d{4}[A-Z])\b", 'column': 'sd.ApplicantID', 'type': 'ID'}
202
  ]
 
203
  for p in entity_patterns:
204
  match = re.search(p['pattern'], question, re.IGNORECASE)
205
  if match:
206
  entity_value = match.group(2) if len(match.groups()) > 1 else match.group(1)
207
- entity_match_data = {"type": p['type'], "name": entity_value.strip().upper(), "column": p['column']}
 
 
 
 
208
  break
209
 
210
- if any(kw in q_lower for kw in ['報告號碼', '報告清單', '列出報告']):
211
- year_match, month_match = re.search(r'(\d{4})\s*年?', question), re.search(r'(\d{1,2})\s*月', question)
212
- from_clause = "FROM JobTimeline AS jt JOIN TSR53SampleDescription AS sd ON jt.JobNo = sd.JobNo"
213
- select_clause = "SELECT jt.JobNo, sd.BuyerName, jt.ReportAuthorization"
214
- where, log_parts = ["jt.ReportAuthorization IS NOT NULL"], []
215
- if year_match: where.append(f"strftime('%Y', jt.ReportAuthorization) = '{year_match.group(1)}'"); log_parts.append(f"{year_match.group(1)}年")
216
- if month_match: where.append(f"strftime('%m', jt.ReportAuthorization) = '{month_match.group(1).zfill(2)}'"); log_parts.append(f"{month_match.group(1)}月")
217
- if 'fail' in q_lower or '失敗' in q_lower: where.append("sd.OverallRating = 'Fail'"); log_parts.append("Fail")
218
- elif 'pass' in q_lower or '通過' in q_lower: where.append("sd.OverallRating = 'Pass'"); log_parts.append("Pass")
 
 
 
 
 
 
 
 
 
 
 
 
 
219
  if entity_match_data:
220
- op = "=" if entity_match_data["column"].endswith("ID") else "LIKE"
221
- val = f"'{entity_match_data['name']}'" if op == "=" else f"'%{entity_match_data['name']}%'"
222
- where.append(f"{entity_match_data['column']} {op} {val}"); log_parts.append(entity_match_data["name"])
223
- final_where = "WHERE " + " AND ".join(where) if where else ""
224
- log_msg = " ".join(log_parts) if log_parts else "全部"
225
- self._log(f"🔄 檢測到【{log_msg} 報告列表】意圖,啟用模板。", "INFO")
226
- template_sql = f"{select_clause} {from_clause} {final_where} ORDER BY jt.ReportAuthorization DESC;"
227
- return self._finalize_sql(template_sql, f"模板覆寫: {log_msg}")
228
-
229
- if '報告' in q_lower and any(kw in q_lower for kw in ['幾份', '多少', '數量', '總數']) and not entity_match_data:
 
 
 
 
 
 
230
  year_match = re.search(r'(\d{4})\s*年?', question)
231
- time_cond = f"WHERE strftime('%Y', ReportAuthorization) = '{year_match.group(1)}'" if year_match else ""
232
- time_log = f"{year_match.group(1)}年" if year_match else "总"
233
- self._log(f"🔄 檢測到【{time_log}报告总数】意图,启用模板。", "INFO")
234
- template_sql = f"SELECT COUNT(DISTINCT JobNo) AS report_count FROM JobTimeline {time_cond};"
235
- return self._finalize_sql(template_sql, f"模板覆寫: {time_log}")
 
 
 
 
 
 
 
 
 
 
236
 
237
- # Fallback to AI
238
- self._log("未觸發模板,由 AI 生成...", "INFO")
239
  parsed_sql = parse_sql_from_response(raw_response)
240
- if not parsed_sql: return None, f"無法解析SQL: {raw_response}"
241
- return self._finalize_sql(parsed_sql, "AI 生成")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
 
243
  def process_question(self, question: str) -> Tuple[str, str]:
244
- if question in self.query_cache: self._log("⚡ 使用緩存結果"); return self.query_cache[question]
 
 
 
 
 
245
  self.log_history = []
246
  self._log(f"⏰ 處理問題: {question}")
 
 
 
247
  examples = self.find_most_similar(question, FEW_SHOT_EXAMPLES_COUNT)
248
  if examples: self._log(f"✅ 找到 {len(examples)} 個相似範例")
 
 
 
249
  prompt = self._build_prompt(question, examples)
250
- self._log(f"📏 Prompt 長度: {len(prompt)} 字符")
 
251
  self._log("🧠 開始生成 AI 回應...")
252
  response = self.huggingface_api_call(prompt)
 
 
253
  final_sql, status_message = self._validate_and_fix_sql(question, response)
254
- result = (final_sql, status_message) if final_sql else (status_message, "生成失敗")
 
 
 
 
 
 
255
  self.query_cache[question] = result
256
  return result
257
 
258
  # ==================== Gradio 介面 ====================
259
  text_to_sql_system = TextToSQLSystem()
 
260
  def process_query(q: str):
261
- if not q.strip(): return "", "等待輸入", "請輸入問題"
262
- if text_to_sql_system.llm is None: return "模型未能成功載入...", "模型載入失敗", "\n".join(text_to_sql_system.log_history)
 
263
  sql, status = text_to_sql_system.process_question(q)
264
- return sql, status, "\n".join(text_to_sql_system.log_history[-15:])
 
 
 
 
 
 
 
 
 
 
 
265
 
266
- examples = ["2024年7月買家 Gap 的 Fail 報告號碼", "列出2023年所有失败的报告", "找出总金额最高的10个工作单", "哪些客户的工作单数量最多?", "A組2024年完成了多少個測試項目?", "2024年每月完成多少份報告?"]
267
  with gr.Blocks(theme=gr.themes.Soft(), title="Text-to-SQL 智能助手") as demo:
268
- gr.Markdown("# ⚡ Text-to-SQL 智能助手 (终极版)")
 
 
269
  with gr.Row():
270
  with gr.Column(scale=2):
271
  inp = gr.Textbox(lines=3, label="💬 您的問題", placeholder="例如:2024年每月完成多少份報告?")
272
  btn = gr.Button("🚀 生成 SQL", variant="primary")
273
  status = gr.Textbox(label="狀態", interactive=False)
274
- with gr.Column(scale=3): sql_out = gr.Code(label="🤖 生成的 SQL", language="sql", lines=8)
275
- with gr.Accordion("📋 處理日誌", open=False): logs = gr.Textbox(lines=10, label="日誌", interactive=False)
276
- gr.Examples(examples=examples, inputs=inp, label="💡 點擊試用範例問題")
 
 
 
 
 
 
 
 
 
 
 
 
277
  btn.click(process_query, inputs=[inp], outputs=[sql_out, status, logs])
278
  inp.submit(process_query, inputs=[inp], outputs=[sql_out, status, logs])
279
 
280
  if __name__ == "__main__":
281
- demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import os
3
  import re
4
  import json
5
  import torch
6
  import numpy as np
 
 
7
  from datetime import datetime
8
  from datasets import load_dataset
9
  from huggingface_hub import hf_hub_download
10
  from llama_cpp import Llama
11
  from typing import List, Dict, Tuple, Optional
12
  import faiss
13
+ from functools import lru_cache
14
 
15
+ # 使用 transformers 替代 sentence-transformers
16
  from transformers import AutoModel, AutoTokenizer
17
  import torch.nn.functional as F
18
 
19
+ # ==================== 配置區 ====================
 
 
20
  DATASET_REPO_ID = "Paul720810/Text-to-SQL-Softline"
21
  GGUF_REPO_ID = "Paul720810/gguf-models"
22
+ #GGUF_FILENAME = "qwen2.5-coder-1.5b-sql-finetuned.q4_k_m.gguf"
23
+ GGUF_FILENAME = "qwen2.5-coder-1.5b-sql-finetuned.q8_0.gguf"
24
+
25
+ # 添加這一行:你的原始微調模型路徑
26
+ FINETUNED_MODEL_PATH = "Paul720810/qwen2.5-coder-1.5b-sql-finetuned" # ← 新增這行
27
+
28
  FEW_SHOT_EXAMPLES_COUNT = 1
29
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
30
  EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
 
 
31
 
32
  print("=" * 60)
33
+ print("🤖 Text-to-SQL 系統啟動中...")
34
+ print(f"📊 數據集: {DATASET_REPO_ID}")
35
+ print(f"🤖 嵌入模型: {EMBED_MODEL_NAME}")
36
+ print(f"💻 設備: {DEVICE}")
37
  print("=" * 60)
38
 
39
  # ==================== 工具函數 ====================
40
def get_current_time():
    """Return the current local time formatted as 'YYYY-MM-DD HH:MM:SS'."""
    return f"{datetime.now():%Y-%m-%d %H:%M:%S}"
42
+
43
def format_log(message: str, level: str = "INFO") -> str:
    """Build one log line of the form '[timestamp] [LEVEL] message'."""
    stamp = get_current_time()
    return "[{}] [{}] {}".format(stamp, level.upper(), message)
45
+
46
def parse_sql_from_response(response_text: str) -> Optional[str]:
    """Extract a SQL statement from raw model output.

    Tries progressively looser strategies:
      1. a ```sql fenced block (whitespace-tolerant);
      2. any ``` fenced block whose content starts with SELECT;
      3. a semicolon-terminated "SELECT ...;" anywhere in the text;
      4. an unterminated SELECT up to a blank line / fence / end of text
         (a semicolon is appended);
      5. the first whole line beginning with SELECT.

    Returns the SQL string, or None when nothing SQL-like is found.
    """
    if not response_text:
        return None

    response_text = response_text.strip()

    # 1. ```sql ... ``` fence. Use \b\s* instead of requiring literal
    #    newlines so "```sql SELECT 1```" parses here too; the old pattern
    #    demanded \n on both sides and fell through to looser rules that
    #    could return text with stray backticks attached.
    match = re.search(r"```sql\b\s*(.*?)\s*```", response_text, re.DOTALL | re.IGNORECASE)
    if match:
        return match.group(1).strip()

    # 2. Any fenced block that looks like a SELECT statement.
    match = re.search(r"```\s*\n?(.*?)\n?```", response_text, re.DOTALL)
    if match:
        sql_candidate = match.group(1).strip()
        if sql_candidate.upper().startswith('SELECT'):
            return sql_candidate

    # 3. A complete, semicolon-terminated SELECT anywhere in the text.
    match = re.search(r"(SELECT\s+.*?;)", response_text, re.DOTALL | re.IGNORECASE)
    if match:
        return match.group(1).strip()

    # 4. A SELECT without a semicolon; stop at a blank line, a fence, or
    #    the end of the text, terminate it ourselves, and drop any
    #    trailing backticks the lazy match may have swallowed.
    match = re.search(r"(SELECT\s+.*?)(?=\n\n|\n```|$|\n[^,\s])",
                      response_text, re.DOTALL | re.IGNORECASE)
    if match:
        sql = match.group(1).strip().rstrip('`').strip()
        if not sql.endswith(';'):
            sql += ';'
        return sql

    # 5. Last resort: the first whole line beginning with SELECT.
    if 'SELECT' in response_text.upper():
        for line in response_text.split('\n'):
            line = line.strip()
            if line.upper().startswith('SELECT'):
                if not line.endswith(';'):
                    line += ';'
                return line

    return None
90
 
91
  # ==================== Text-to-SQL 核心類 ====================
92
  class TextToSQLSystem:
 
94
  self.log_history = []
95
  self._log("初始化系統...")
96
  self.query_cache = {}
97
+
98
+ # 1. 載入嵌入模型
99
+ self._log(f"載入嵌入模型: {embed_model_name}")
100
+ self.embed_tokenizer = AutoTokenizer.from_pretrained(embed_model_name)
101
+ self.embed_model = AutoModel.from_pretrained(embed_model_name)
102
+ if DEVICE == "cuda":
103
+ self.embed_model = self.embed_model.cuda()
104
+
105
+ # 2. 載入數據庫結構
106
+ self.schema = self._load_schema()
107
+
108
+ # 3. 載入數據集並建立索引
109
+ self.dataset, self.faiss_index = self._load_and_index_dataset()
110
+
111
+ # 4. 載入 GGUF 模型(添加錯誤處理)
112
+ self._load_gguf_model()
113
+
114
+ self._log("✅ 系統初始化完成")
115
+ # 載入數據庫結構
116
+ self.schema = self._load_schema()
117
+
118
+ # 暫時添加:打印 schema 信息
119
+ if self.schema:
120
+ print("=" * 50)
121
+ print("數據庫 Schema 信息:")
122
+ for table_name, columns in self.schema.items():
123
+ print(f"\n表格: {table_name}")
124
+ print(f"欄位數: {len(columns)}")
125
+ print("欄位列表:")
126
+ for col in columns[:5]: # 只顯示前5個
127
+ print(f" - {col['name']} ({col['type']})")
128
+ print("=" * 50)
129
+
130
+ # in class TextToSQLSystem:
131
+
132
def _load_gguf_model(self):
    """Download and initialise the GGUF model via llama.cpp.

    On any failure self.llm is left as None so callers can detect that
    SQL generation is unavailable. A tiny generation call is used as a
    smoke test to confirm the model actually responds.
    """
    llama_kwargs = dict(
        n_ctx=2048,      # large enough that the prompt never overflows
        n_threads=4,
        n_batch=512,
        verbose=False,   # silence llama.cpp's own logging
        n_gpu_layers=0,  # CPU-only deployment
    )
    try:
        self._log("載入 GGUF 模型 (使用穩定性參數)...")
        gguf_path = hf_hub_download(
            repo_id=GGUF_REPO_ID,
            filename=GGUF_FILENAME,
            repo_type="dataset",
        )
        self.llm = Llama(model_path=gguf_path, **llama_kwargs)
        # Smoke test: a 3-token generation proves the weights loaded.
        self.llm("你好", max_tokens=3)
        self._log("✅ GGUF 模型載入成功")
    except Exception as exc:
        self._log(f"❌ GGUF 載入失敗: {exc}", "ERROR")
        self._log("系統將無法生成 SQL。請檢查模型檔案或 llama-cpp-python 安裝。", "CRITICAL")
        self.llm = None
160
 
161
def _try_gguf_loading(self):
    """Attempt a minimal GGUF load; return True on success, False otherwise.

    Uses a small 512-token context. A 5-token probe generation verifies
    the model can actually run before success is reported.
    """
    try:
        path = hf_hub_download(
            repo_id=GGUF_REPO_ID,
            filename=GGUF_FILENAME,
            repo_type="dataset",
        )
        self.llm = Llama(
            model_path=path,
            n_ctx=512,
            n_threads=4,
            verbose=False,
            n_gpu_layers=0,
        )
        # Probe generation; the output itself is not needed.
        self.llm("SELECT", max_tokens=5)
        self._log("✅ GGUF 模型載入成功")
        return True
    except Exception as exc:
        self._log(f"GGUF 載入失敗: {exc}", "WARNING")
        return False
186
+
187
def _load_transformers_model(self):
    """Fallback loader: run the fine-tuned model through transformers on CPU.

    On success sets self.llm to the sentinel string "transformers" so the
    rest of the system knows which backend is active; on failure sets it
    to None.
    """
    try:
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

        self._log(f"載入 Transformers 模型: {FINETUNED_MODEL_PATH}")

        self.transformers_tokenizer = AutoTokenizer.from_pretrained(FINETUNED_MODEL_PATH)
        self.transformers_model = AutoModelForCausalLM.from_pretrained(
            FINETUNED_MODEL_PATH,
            torch_dtype=torch.float32,  # CPU inference requires float32
            device_map="cpu",
            trust_remote_code=True,     # Qwen checkpoints may ship custom code
        )

        self.generation_pipeline = pipeline(
            "text-generation",
            model=self.transformers_model,
            tokenizer=self.transformers_tokenizer,
            device=-1,                  # -1 selects CPU in transformers pipelines
            max_length=512,
            do_sample=True,
            temperature=0.1,
            top_p=0.9,
            pad_token_id=self.transformers_tokenizer.eos_token_id,
        )

        self.llm = "transformers"       # sentinel marker, not a model object
        self._log("✅ Transformers 模型載入成功")
    except Exception as exc:
        self._log(f"❌ Transformers 載入也失敗: {exc}", "ERROR")
        self.llm = None
223
 
224
def huggingface_api_call(self, prompt: str) -> str:
    """Generate raw text for *prompt* with whichever LLM backend is loaded.

    Handles three states of self.llm:
      * None — nothing loaded; try the optional fallback generator;
      * the sentinel string "transformers" (set by _load_transformers_model)
        — route through self.generation_pipeline;
      * a llama.cpp Llama instance — call it directly.

    Returns the stripped generated text, or "" on any failure.
    """
    if self.llm is None:
        self._log("模型未載入,返回 fallback SQL。", "ERROR")
        # _generate_fallback_sql is not guaranteed to exist on this class;
        # guard so a missing model doesn't escalate into an AttributeError.
        fallback = getattr(self, "_generate_fallback_sql", None)
        return fallback(prompt) if callable(fallback) else ""

    try:
        if self.llm == "transformers":
            # Sentinel backend: the string itself is not callable, so the
            # original code raised TypeError here and silently returned "".
            generated = self.generation_pipeline(prompt, return_full_text=False)
            text = generated[0]["generated_text"] if generated else ""
            self._log(f"📝 提取出的生成文本: {text.strip()}", "DEBUG")
            return text.strip()

        output = self.llm(
            prompt,
            max_tokens=150,
            temperature=0.1,
            top_p=0.9,
            echo=False,
            # Cut generation at a fence, statement end, or paragraph break.
            stop=["```", ";", "\n\n", "</s>"],
        )

        self._log(f"🧠 模型原始輸出 (Raw Output): {output}", "DEBUG")

        if output and "choices" in output and len(output["choices"]) > 0:
            generated_text = output["choices"][0]["text"]
            self._log(f"📝 提取出的生成文本: {generated_text.strip()}", "DEBUG")
            return generated_text.strip()

        self._log("❌ 模型的原始輸出格式不正確或為空。", "ERROR")
        return ""

    except Exception as e:
        self._log(f"❌ 模型生成過程中發生嚴重錯誤: {e}", "CRITICAL")
        import traceback
        self._log(traceback.format_exc(), "DEBUG")
        return ""
256
+
257
def _load_gguf_model_fallback(self, model_path):
    """Last-ditch GGUF load with conservative parameters.

    Small context and batch, memory-mapped weights, verbose llama.cpp
    logging for diagnosis. Sets self.llm to None if even this fails.
    """
    conservative = {
        "n_ctx": 512,        # much smaller context
        "n_threads": 4,
        "n_batch": 128,
        "vocab_only": False,
        "use_mmap": True,
        "use_mlock": False,
        "verbose": True,
    }
    try:
        self.llm = Llama(model_path=model_path, **conservative)
        self._log("✅ 備用方式載入成功")
    except Exception as exc:
        self._log(f"❌ 備用方式也失敗: {exc}", "ERROR")
        self.llm = None
275
+
276
def _log(self, message: str, level: str = "INFO"):
    """Record *message* in the in-memory log history and echo it to stdout.

    The entry is formatted exactly once so the history line and the printed
    line carry an identical timestamp; the original formatted the message
    twice, doing the work twice and risking two different timestamps.
    """
    entry = format_log(message, level)
    self.log_history.append(entry)
    print(entry)
279
+
280
def _load_schema(self) -> Dict:
    """Fetch sqlite_schema_FULL.json from the dataset repo.

    Returns {table_name: [column dicts]} on success, or {} on any failure.
    Logs a short per-table summary for debugging.
    """
    try:
        path = hf_hub_download(
            repo_id=DATASET_REPO_ID,
            filename="sqlite_schema_FULL.json",
            repo_type="dataset",
        )
        with open(path, "r", encoding="utf-8") as fh:
            schema_data = json.load(fh)

        self._log(f"📊 Schema 載入成功,包含 {len(schema_data)} 個表格:")
        for table_name, columns in schema_data.items():
            self._log(f" - {table_name}: {len(columns)} 個欄位")
            # Show the first three column names as a sample.
            preview = ', '.join(col['name'] for col in columns[:3])
            self._log(f" 範例欄位: {preview}")

        self._log("✅ 數據庫結構載入完成")
        return schema_data

    except Exception as exc:
        self._log(f"❌ 載入 schema 失敗: {exc}", "ERROR")
        return {}
305
+
306
+ # 也可以添加一個方法來檢查生成的 SQL 是否使用了正確的表格和欄位
307
+ def _analyze_sql_correctness(self, sql: str) -> Dict:
308
+ """分析 SQL 的正確性"""
309
+ analysis = {
310
+ 'valid_tables': [],
311
+ 'invalid_tables': [],
312
+ 'valid_columns': [],
313
+ 'invalid_columns': [],
314
+ 'suggestions': []
315
+ }
316
+
317
+ if not self.schema:
318
+ return analysis
319
+
320
+ # 提取 SQL 中的表格名稱
321
+ table_pattern = r'FROM\s+(\w+)|JOIN\s+(\w+)'
322
+ table_matches = re.findall(table_pattern, sql, re.IGNORECASE)
323
+ used_tables = [match[0] or match[1] for match in table_matches]
324
+
325
+ # 檢查表格是否存在
326
+ valid_tables = list(self.schema.keys())
327
+ for table in used_tables:
328
+ if table in valid_tables:
329
+ analysis['valid_tables'].append(table)
330
+ else:
331
+ analysis['invalid_tables'].append(table)
332
+ # 尋找相似的表格名稱
333
+ for valid_table in valid_tables:
334
+ if table.lower() in valid_table.lower() or valid_table.lower() in table.lower():
335
+ analysis['suggestions'].append(f"{table} -> {valid_table}")
336
+
337
+ # 提取欄位名稱(簡單版本)
338
+ column_pattern = r'SELECT\s+(.*?)\s+FROM|WHERE\s+(\w+)\s*[=<>]|GROUP BY\s+(\w+)|ORDER BY\s+(\w+)'
339
+ column_matches = re.findall(column_pattern, sql, re.IGNORECASE)
340
+
341
+ return analysis
342
 
343
def _encode_texts(self, texts):
    """Embed *texts* with the sentence-embedding model via mean pooling.

    Accepts a single string or a list of strings; returns a CPU tensor of
    shape (len(texts), hidden_dim).
    """
    batch = [texts] if isinstance(texts, str) else texts

    encoded = self.embed_tokenizer(
        batch, padding=True, truncation=True,
        return_tensors="pt", max_length=512,
    )
    if DEVICE == "cuda":
        encoded = {name: tensor.cuda() for name, tensor in encoded.items()}

    with torch.no_grad():
        model_out = self.embed_model(**encoded)

    # Mean-pool over the sequence dimension, then move back to CPU.
    return model_out.last_hidden_state.mean(dim=1).cpu()
359
 
360
def _load_and_index_dataset(self):
    """Load training_data.jsonl, drop malformed rows, embed every question
    and build a FAISS inner-product index.

    Returns (dataset, index), or (None, None) when loading fails or the
    cleaned dataset is empty.
    """
    try:
        dataset = load_dataset(DATASET_REPO_ID, data_files="training_data.jsonl", split="train")

        # Record the pre-filter size when the dataset supports len().
        try:
            original_count = len(dataset)
        except Exception:
            original_count = None

        # Keep only rows whose first two messages both carry non-blank
        # string content; anything else breaks indexing/retrieval later.
        def _row_ok(ex):
            msgs = ex.get("messages")
            if not isinstance(msgs, list) or len(msgs) < 2:
                return False
            return all(
                isinstance(m.get("content"), str) and m.get("content") and m["content"].strip()
                for m in msgs[:2]
            )

        dataset = dataset.filter(_row_ok)

        if original_count is not None:
            self._log(
                f"資料集清理: 原始 {original_count} 筆, 過濾後 {len(dataset)} 筆, 移除 {original_count - len(dataset)} 筆"
            )

        if len(dataset) == 0:
            self._log("清理後資料集為空,無法建立索引。", "ERROR")
            return None, None

        corpus = [item['messages'][0]['content'] for item in dataset]
        self._log(f"正在編碼 {len(corpus)} 個問題...")

        # Encode in batches of 32, logging progress after each batch.
        batch_size = 32
        embeddings_list = []
        for start in range(0, len(corpus), batch_size):
            embeddings_list.append(self._encode_texts(corpus[start:start + batch_size]))
            self._log(f"已編碼 {min(start + batch_size, len(corpus))}/{len(corpus)}")

        all_embeddings = torch.cat(embeddings_list, dim=0).numpy()

        # Inner-product index over the raw (unnormalised) embeddings.
        index = faiss.IndexFlatIP(all_embeddings.shape[1])
        index.add(all_embeddings.astype('float32'))

        self._log("✅ 向量索引建立完成")
        return dataset, index

    except Exception as exc:
        self._log(f"❌ 載入數據失敗: {exc}", "ERROR")
        return None, None
414
+
415
  def _identify_relevant_tables(self, question: str) -> List[str]:
416
+ """根據實際 Schema 識別相關表格"""
417
  question_lower = question.lower()
418
  relevant_tables = []
419
+
420
+ # 根據實際表格的關鍵詞映射
421
+ keyword_to_table = {
422
+ 'TSR53SampleDescription': ['客戶', '買方', '申請', '發票對象', 'customer', 'invoice', 'sample'],
423
+ 'JobsInProgress': ['進行中', '買家', '申請方', 'buyer', 'applicant', 'progress', '工作狀態'],
424
+ 'JobTimeline': ['時間', '完成', '創建', '實驗室', 'timeline', 'creation', 'lab'],
425
+ 'TSR53Invoice': ['發票', '金額', '費用', 'invoice', 'credit', 'amount'],
426
+ 'JobEventsLog': ['事件', '操作', '用戶', 'event', 'log', 'user'],
427
+ 'calendar_days': ['工作日', '假期', 'workday', 'holiday', 'calendar']
428
+ }
429
+
430
  for table, keywords in keyword_to_table.items():
431
+ if any(keyword in question_lower for keyword in keywords):
432
+ relevant_tables.append(table)
433
+
434
+ # 預設重要表格
435
+ if not relevant_tables:
436
+ if any(word in question_lower for word in ['客戶', '買家', '申請', '工作單', '數量']):
437
+ return ['TSR53SampleDescription', 'JobsInProgress']
438
+ else:
439
+ return ['JobTimeline', 'TSR53SampleDescription']
440
+
441
+ return relevant_tables[:3] # 最多返回3個相關表格
442
+
443
+ # 請將這整個函數複製到您的 TextToSQLSystem class 內部
444
 
445
  def _format_relevant_schema(self, table_names: List[str]) -> str:
446
+ """
447
+ 生成一個簡化的、不易被模型錯誤模仿的 Schema 字符串。
448
+ """
449
+ if not self.schema:
450
+ return "No schema available.\n"
451
+
452
+ actual_table_names_map = {name.lower(): name for name in self.schema.keys()}
453
+ real_table_names = []
454
  for table in table_names:
455
+ actual_name = actual_table_names_map.get(table.lower())
456
+ if actual_name:
457
+ real_table_names.append(actual_name)
458
+ elif table in self.schema:
459
+ real_table_names.append(table)
460
+
461
+ if not real_table_names:
462
+ self._log("未識別到相關表格,使用預設核心表格。", "WARNING")
463
+ real_table_names = ['TSR53SampleDescription', 'JobTimeline', 'JobsInProgress']
464
+
465
+ formatted = ""
466
+ for table in real_table_names:
467
  if table in self.schema:
468
+ # 使用簡單的 "Table: ..." 和 "Columns: ..." 格式
469
  formatted += f"Table: {table}\n"
470
  cols_str = []
471
+ # 只顯示前 10 個關鍵欄位
472
  for col in self.schema[table][:10]:
473
+ col_name = col['name']
474
+ col_type = col['type']
475
+ col_desc = col.get('description', '').replace('\n', ' ')
476
+ # 將描述信息放在括號裡
477
+ if col_desc:
478
+ cols_str.append(f"{col_name} ({col_type}, {col_desc})")
479
+ else:
480
+ cols_str.append(f"{col_name} ({col_type})")
481
  formatted += f"Columns: {', '.join(cols_str)}\n\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
482
 
483
+ return formatted.strip()
 
484
 
485
+ # in class TextToSQLSystem:
 
 
 
 
 
 
 
 
 
 
 
 
486
 
487
  def _validate_and_fix_sql(self, question: str, raw_response: str) -> Tuple[Optional[str], str]:
488
+ """
489
+ (V23 / 统一实体识别版)
490
+ 一個全面、多層次的 SQL 驗證與生成引擎。
491
+ 引入了全新的、统一的实体识别引擎,能够准确解析 "买家 Gap", "c0761n",
492
+ "买家ID c0761n" 等多种复杂的实体提问模式。
493
+ """
494
  q_lower = question.lower()
495
+
496
+ # ==============================================================================
497
+ # 第一層:高價值意圖識別與模板覆寫 (Intent Recognition & Templating)
498
+ # ==============================================================================
499
+
500
+ # --- **全新的统一实体识别引擎** ---
501
  entity_match_data = None
502
+
503
+ # 定义多种识别模式,【优先级从高到低】
504
  entity_patterns = [
505
+ # 模式1: 匹配 "类型 + ID" (e.g., "买家ID C0761N") - 最高优先级
506
  {'pattern': r"(买家|buyer)\s*(?:id|代號|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.BuyerID', 'type': '买家ID'},
507
  {'pattern': r"(申请方|申请厂商|applicant)\s*(?:id|代號|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.ApplicantID', 'type': '申请方ID'},
508
+ {'pattern': r"(付款方|付款厂商|invoiceto)\s*(?:id|代號|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.InvoiceToID', 'type': '付款方ID'},
509
+ {'pattern': r"(代理商|agent)\s*(?:id|代號|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.AgentID', 'type': '代理商ID'},
510
+
511
+ # 模式2: 匹配 "类型 + 名称" (e.g., "买家 Gap")
512
  {'pattern': r"(买家|buyer|客戶)\s*'\"?([a-zA-Z0-9&.\s-]+?)(?:\s*的|\s+|$|有)", 'column': 'sd.BuyerName', 'type': '买家'},
513
  {'pattern': r"(申请方|申请厂商|applicant)\s*'\"?([a-zA-Z0-9&.\s-]+?)(?:\s*的|\s+|$|有)", 'column': 'sd.ApplicantName', 'type': '申请方'},
514
+ {'pattern': r"(付款方|付款厂商|invoiceto)\s*'\"?([a-zA-Z0-9&.\s-]+?)(?:\s*的|\s+|$|有)", 'column': 'sd.InvoiceToName', 'type': '付款方'},
515
+ {'pattern': r"(代理商|agent)\s*'\"?([a-zA-Z0-9&.\s-]+?)(?:\s*的|\s+|$|有)", 'column': 'sd.AgentName', 'type': '代理商'},
516
+
517
+ # 模式3: 单独匹配一个 ID (e.g., "c0761n") - 较低优先级
518
  {'pattern': r"\b([A-Z]\d{4}[A-Z])\b", 'column': 'sd.ApplicantID', 'type': 'ID'}
519
  ]
520
+
521
  for p in entity_patterns:
522
  match = re.search(p['pattern'], question, re.IGNORECASE)
523
  if match:
524
  entity_value = match.group(2) if len(match.groups()) > 1 else match.group(1)
525
+ entity_match_data = {
526
+ "type": p['type'],
527
+ "name": entity_value.strip().upper(),
528
+ "column": p['column']
529
+ }
530
  break
531
 
532
+ # --- 预先检测其他意图 ---
533
+ job_no_match = re.search(r"(?:工單|jobno)\s*'\"?([A-Z]{2,3}\d+)'\"?", question, re.IGNORECASE)
534
+
535
+ # --- 判断逻辑: 依优先级进入对应的模板 ---
536
+ if any(kw in q_lower for kw in ['報告號碼', '報告清單', '列出報告', 'report number', 'list of reports']):
537
+ year_match = re.search(r'(\d{4})\s*年?', question)
538
+ month_match = re.search(r'(\d{1,2})\s*月', question)
539
+ from_clause = "FROM JobTimeline AS jt"
540
+ select_clause = "SELECT jt.JobNo, jt.ReportAuthorization"
541
+ where_conditions = ["jt.ReportAuthorization IS NOT NULL"]
542
+ log_parts = []
543
+
544
+ if year_match: year = year_match.group(1); where_conditions.append(f"strftime('%Y', jt.ReportAuthorization) = '{year}'"); log_parts.append(f"{year}年")
545
+ if month_match: month = month_match.group(1).zfill(2); where_conditions.append(f"strftime('%m', jt.ReportAuthorization) = '{month}'"); log_parts.append(f"{month}月")
546
+
547
+ if 'fail' in q_lower or '失敗' in q_lower:
548
+ if "JOIN TSR53SampleDescription" not in from_clause: from_clause = "FROM JobTimeline AS jt JOIN TSR53SampleDescription AS sd ON jt.JobNo = sd.JobNo"
549
+ where_conditions.append("sd.OverallRating = 'Fail'"); log_parts.append("Fail")
550
+ elif 'pass' in q_lower or '通過' in q_lower:
551
+ if "JOIN TSR53SampleDescription" not in from_clause: from_clause = "FROM JobTimeline AS jt JOIN TSR53SampleDescription AS sd ON jt.JobNo = sd.JobNo"
552
+ where_conditions.append("sd.OverallRating = 'Pass'"); log_parts.append("Pass")
553
+
554
  if entity_match_data:
555
+ entity_name, column_name = entity_match_data["name"], entity_match_data["column"]
556
+ if "JOIN TSR53SampleDescription" not in from_clause: from_clause = "FROM JobTimeline AS jt JOIN TSR53SampleDescription AS sd ON jt.JobNo = sd.JobNo"
557
+ match_operator = "=" if column_name.endswith("ID") else "LIKE"
558
+ entity_value = f"'{entity_name}'" if match_operator == "=" else f"'%{entity_name}%'"
559
+ where_conditions.append(f"{column_name} {match_operator} {entity_value}")
560
+ log_parts.append(entity_name)
561
+ select_clause = "SELECT jt.JobNo, sd.BuyerName, jt.ReportAuthorization"
562
+
563
+ final_where_clause = "WHERE " + " AND ".join(where_conditions)
564
+ time_log = " ".join(log_parts) if log_parts else "全部"
565
+ self._log(f"🔄 檢測到查詢【{time_log} 報告列表】意圖,啟用智能模板。", "INFO")
566
+ template_sql = f"{select_clause} {from_clause} {final_where_clause} ORDER BY jt.ReportAuthorization DESC;"
567
+ return self._finalize_sql(template_sql, f"模板覆寫: {time_log} 報告列表查詢")
568
+
569
+ # ... (此处可以继续添加 V17 版本中的其他所有 if/elif 模板)
570
+ elif '報告' in q_lower and any(kw in q_lower for kw in ['幾份', '多少', '數量', '總數']) and not entity_match_data:
571
  year_match = re.search(r'(\d{4})\s*年?', question)
572
+ time_condition, time_log = "", ""
573
+ if year_match:
574
+ year = year_match.group(1)
575
+ time_condition = f"WHERE ReportAuthorization IS NOT NULL AND strftime('%Y', ReportAuthorization) = '{year}'"
576
+ time_log = f"{year}"
577
+ else:
578
+ time_condition = "WHERE ReportAuthorization IS NOT NULL"
579
+ self._log(f"🔄 檢測到查詢【{time_log}全局報告總數】意圖,啟用模板。", "INFO")
580
+ template_sql = f"SELECT COUNT(DISTINCT JobNo) AS report_count FROM JobTimeline {time_condition};"
581
+ return self._finalize_sql(template_sql, f"模板覆寫: {time_log}全局報告總數查詢")
582
+
583
+ # ==============================================================================
584
+ # 第二层:常规修正流程 (Fallback Corrections)
585
+ # ==============================================================================
586
+ self._log("未觸發任何模板,嘗試解析並修正 AI 輸出...", "INFO")
587
 
 
 
588
  parsed_sql = parse_sql_from_response(raw_response)
589
+ if not parsed_sql:
590
+ self._log(f"❌ 未能從模型回應中解析出任何 SQL。原始回應: {raw_response}", "ERROR")
591
+ return None, f"無法解析SQL。原始回應:\n{raw_response}"
592
+
593
+ self._log(f"📊 解析出的原始 SQL: {parsed_sql}", "DEBUG")
594
+
595
+ fixed_sql = " " + parsed_sql.strip() + " "
596
+ fixes_applied_fallback = []
597
+
598
+ dialect_corrections = {r'YEAR\s*\(([^)]+)\)': r"strftime('%Y', \1)"}
599
+ for pattern, replacement in dialect_corrections.items():
600
+ if re.search(pattern, fixed_sql, re.IGNORECASE):
601
+ fixed_sql = re.sub(pattern, replacement, fixed_sql, flags=re.IGNORECASE)
602
+ fixes_applied_fallback.append(f"修正方言: {pattern}")
603
+
604
+ schema_corrections = {'TSR53Report':'TSR53SampleDescription', 'TSR53InvoiceReportNo':'JobNo', 'TSR53ReportNo':'JobNo', 'TSR53InvoiceNo':'JobNo', 'TSR53InvoiceCreditNoteNo':'InvoiceCreditNoteNo', 'TSR53InvoiceLocalAmount':'LocalAmount', 'Status':'OverallRating', 'ReportStatus':'OverallRating'}
605
+ for wrong, correct in schema_corrections.items():
606
+ pattern = r'\b' + re.escape(wrong) + r'\b'
607
+ if re.search(pattern, fixed_sql, re.IGNORECASE):
608
+ fixed_sql = re.sub(pattern, correct, fixed_sql, flags=re.IGNORECASE)
609
+ fixes_applied_fallback.append(f"映射 Schema: '{wrong}' -> '{correct}'")
610
+
611
+ log_msg = "AI 生成並成功修正" if fixes_applied_fallback else "AI 生成且無需修正"
612
+ return self._finalize_sql(fixed_sql, log_msg)
613
+
614
+ def _finalize_sql(self, sql: str, log_message: str) -> Tuple[str, str]:
615
+ """一個輔助函數,用於清理最終的SQL並記錄成功日誌。"""
616
+ final_sql = sql.strip()
617
+ if not final_sql.endswith(';'):
618
+ final_sql += ';'
619
+ final_sql = re.sub(r'\s+', ' ', final_sql).strip()
620
+ self._log(f"✅ SQL 已生成 ({log_message})", "INFO")
621
+ self._log(f" - 最終 SQL: {final_sql}", "DEBUG")
622
+ return final_sql, "生成成功"
623
+
624
+ def find_most_similar(self, question: str, top_k: int) -> List[Dict]:
625
+ """使用 FAISS 快速檢索相似問題"""
626
+ if self.faiss_index is None or self.dataset is None:
627
+ return []
628
+
629
+ try:
630
+ # 編碼問題
631
+ q_embedding = self._encode_texts([question]).numpy().astype('float32')
632
+
633
+ # FAISS 搜索
634
+ distances, indices = self.faiss_index.search(q_embedding, min(top_k + 2, len(self.dataset)))
635
+
636
+ results = []
637
+ seen_questions = set()
638
+
639
+ for i, idx in enumerate(indices[0]):
640
+ if len(results) >= top_k:
641
+ break
642
+
643
+ # 修復:將 numpy.int64 轉換為 Python int
644
+ idx = int(idx) # ← 添加這行轉換
645
+
646
+ if idx >= len(self.dataset): # 確保索引有效
647
+ continue
648
+
649
+ item = self.dataset[idx]
650
+ # 防呆:若樣本不完整則跳過
651
+ if not isinstance(item.get('messages'), list) or len(item['messages']) < 2:
652
+ continue
653
+ q_content = (item['messages'][0].get('content') or '').strip()
654
+ a_content = (item['messages'][1].get('content') or '').strip()
655
+ if not q_content or not a_content:
656
+ continue
657
+
658
+ # 提取純淨問題
659
+ clean_q = re.sub(r"以下是一個SQL查詢任務:\s*指令:\s*", "", q_content).strip()
660
+ if clean_q in seen_questions:
661
+ continue
662
+
663
+ seen_questions.add(clean_q)
664
+ sql = parse_sql_from_response(a_content) or "無法解析範例SQL"
665
+
666
+ results.append({
667
+ "similarity": float(distances[0][i]),
668
+ "question": clean_q,
669
+ "sql": sql
670
+ })
671
+
672
+ return results
673
+
674
+ except Exception as e:
675
+ self._log(f"❌ 檢索失敗: {e}", "ERROR")
676
+ return []
677
+
678
+ # in class TextToSQLSystem:
679
+
680
+ def _build_prompt(self, user_q: str, examples: List[Dict]) -> str:
681
+ """
682
+ 建立一個高度結構化、以任務為導向的提示詞,使用清晰的標��分隔符。
683
+ """
684
+ relevant_tables = self._identify_relevant_tables(user_q)
685
+
686
+ # 使用我們新的、更簡單的 schema 格式化函數
687
+ schema_str = self._format_relevant_schema(relevant_tables)
688
+
689
+ example_str = "No example available."
690
+ if examples:
691
+ best_example = examples[0]
692
+ example_str = f"Question: {best_example['question']}\nSQL:\n```sql\n{best_example['sql']}\n```"
693
+
694
+ # 使用強分隔符和清晰的標題來構建 prompt
695
+ prompt = f"""### INSTRUCTIONS ###
696
+ You are a SQLite expert. Your only job is to generate a single, valid SQLite query based on the provided schema and question.
697
+ - ONLY use the tables and columns from the schema below.
698
+ - ALWAYS use SQLite syntax (e.g., `strftime('%Y', date_column)` for years).
699
+ - The report completion date is the `ReportAuthorization` column in the `JobTimeline` table.
700
+ - Your output MUST be ONLY the SQL query inside a ```sql code block.
701
+
702
+ ### SCHEMA ###
703
+ {schema_str}
704
+
705
+ ### EXAMPLE ###
706
+ {example_str}
707
+
708
+ ### TASK ###
709
+ Generate a SQLite query for the following question.
710
+ Question: {user_q}
711
+ SQL:
712
+ ```sql
713
+ """
714
+ self._log(f"📏 Prompt 長度: {len(prompt)} 字符")
715
+ # 不再需要複雜的長度截斷邏輯,因為 schema 已經被簡化
716
+ return prompt
717
+
718
+
719
+ def _generate_fallback_sql(self, prompt: str) -> str:
720
+ """當模型不可用時的備用 SQL 生成"""
721
+ prompt_lower = prompt.lower()
722
+
723
+ # 簡單的關鍵詞匹配生成基本 SQL
724
+ if "統計" in prompt or "數量" in prompt or "多少" in prompt:
725
+ if "月" in prompt:
726
+ return "SELECT strftime('%Y-%m', completed_time) as month, COUNT(*) as count FROM jobtimeline GROUP BY month ORDER BY month;"
727
+ elif "客戶" in prompt:
728
+ return "SELECT applicant, COUNT(*) as count FROM tsr53sampledescription GROUP BY applicant ORDER BY count DESC;"
729
+ else:
730
+ return "SELECT COUNT(*) as total_count FROM jobtimeline WHERE completed_time IS NOT NULL;"
731
+
732
+ elif "金額" in prompt or "總額" in prompt:
733
+ return "SELECT SUM(amount) as total_amount FROM tsr53invoice;"
734
+
735
+ elif "評級" in prompt or "pass" in prompt_lower or "fail" in prompt_lower:
736
+ return "SELECT rating, COUNT(*) as count FROM tsr53sampledescription GROUP BY rating;"
737
+
738
+ else:
739
+ return "SELECT * FROM jobtimeline LIMIT 10;"
740
+
741
+ def _validate_model_file(self, model_path):
742
+ """驗證模型檔案完整性"""
743
+ try:
744
+ if not os.path.exists(model_path):
745
+ return False
746
+
747
+ # 檢查檔案大小(至少應該有幾MB)
748
+ file_size = os.path.getsize(model_path)
749
+ if file_size < 10 * 1024 * 1024: # 小於 10MB 可能有問題
750
+ return False
751
+
752
+ # 檢查 GGUF 檔案頭部
753
+ with open(model_path, 'rb') as f:
754
+ header = f.read(8)
755
+ if not header.startswith(b'GGUF'):
756
+ return False
757
+
758
+ return True
759
+ except Exception:
760
+ return False
761
+
762
+ # in class TextToSQLSystem:
763
 
764
  def process_question(self, question: str) -> Tuple[str, str]:
765
+ """處理使用者問題 (V2 / 最終版)"""
766
+ # 檢查緩存
767
+ if question in self.query_cache:
768
+ self._log("⚡ 使用緩存結果")
769
+ return self.query_cache[question]
770
+
771
  self.log_history = []
772
  self._log(f"⏰ 處理問題: {question}")
773
+
774
+ # 1. 檢索相似範例
775
+ self._log("🔍 尋找相似範例...")
776
  examples = self.find_most_similar(question, FEW_SHOT_EXAMPLES_COUNT)
777
  if examples: self._log(f"✅ 找到 {len(examples)} 個相似範例")
778
+
779
+ # 2. 建立提示詞
780
+ self._log("📝 建立 Prompt...")
781
  prompt = self._build_prompt(question, examples)
782
+
783
+ # 3. 生成 AI 回應
784
  self._log("🧠 開始生成 AI 回應...")
785
  response = self.huggingface_api_call(prompt)
786
+
787
+ # 4. **新的核心步驟**: 呼叫決策引擎來生成最終 SQL
788
  final_sql, status_message = self._validate_and_fix_sql(question, response)
789
+
790
+ if final_sql:
791
+ result = (final_sql, status_message)
792
+ else:
793
+ result = (status_message, "生成失敗")
794
+
795
+ # 緩存結果
796
  self.query_cache[question] = result
797
  return result
798
 
799
# ==================== Gradio UI ====================
# Module-level singleton: constructing it runs the full pipeline setup once at import time.
text_to_sql_system = TextToSQLSystem()
802
def process_query(q: str):
    """Gradio handler: map a user question to (generated SQL, status, recent log text)."""
    if not q.strip():
        return "", "等待輸入", "請輸入問題"

    sql, status = text_to_sql_system.process_question(q)
    # Keep the log panel readable: show only the last 10 entries.
    recent = text_to_sql_system.log_history[-10:]
    return sql, status, "\n".join(recent)
+
811
# Sample questions shown as clickable examples in the UI.
examples = [
    "2024年每月完成多少份報告?",
    "統計各種評級(Pass/Fail)的分布情況",
    "找出總金額最高的10個工作單",
    "哪些客戶的工作單數量最多?",
    "A組昨天完成了多少個測試項目?"
]


with gr.Blocks(theme=gr.themes.Soft(), title="Text-to-SQL 智能助手") as demo:
    gr.Markdown("# ⚡ Text-to-SQL 智能助手")
    gr.Markdown("輸入自然語言問題,自動生成SQL查詢語句")

    with gr.Row():
        with gr.Column(scale=2):
            # Left column: question input, submit button, status line.
            inp = gr.Textbox(lines=3, label="💬 您的問題", placeholder="例如:2024年每月完成多少份報告?")
            btn = gr.Button("🚀 生成 SQL", variant="primary")
            status = gr.Textbox(label="狀態", interactive=False)

        with gr.Column(scale=3):
            # Right column: the generated SQL with syntax highlighting.
            sql_out = gr.Code(label="🤖 生成的 SQL", language="sql", lines=8)

    # Collapsible processing log.
    with gr.Accordion("📋 處理日誌", open=False):
        logs = gr.Textbox(lines=8, label="日誌", interactive=False)

    # Example questions — clicking one fills the input box.
    gr.Examples(
        examples=examples,
        inputs=inp,
        label="💡 點擊試用範例問題"
    )

    # Both the button click and pressing Enter in the textbox run the query.
    btn.click(process_query, inputs=[inp], outputs=[sql_out, status, logs])
    inp.submit(process_query, inputs=[inp], outputs=[sql_out, status, logs])

if __name__ == "__main__":
    # Bind to all interfaces on port 7860 (Hugging Face Spaces convention); no public share link.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )