Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -46,27 +46,27 @@ def parse_sql_from_response(response_text: str) -> Optional[str]:
|
|
| 46 |
"""從模型輸出提取 SQL,增強版"""
|
| 47 |
if not response_text:
|
| 48 |
return None
|
| 49 |
-
|
| 50 |
# 清理回應文本
|
| 51 |
response_text = response_text.strip()
|
| 52 |
-
|
| 53 |
# 1. 先找 ```sql ... ```
|
| 54 |
match = re.search(r"```sql\s*\n(.*?)\n```", response_text, re.DOTALL | re.IGNORECASE)
|
| 55 |
if match:
|
| 56 |
return match.group(1).strip()
|
| 57 |
-
|
| 58 |
# 2. 找任何 ``` 包圍的內容
|
| 59 |
match = re.search(r"```\s*\n?(.*?)\n?```", response_text, re.DOTALL)
|
| 60 |
if match:
|
| 61 |
sql_candidate = match.group(1).strip()
|
| 62 |
if sql_candidate.upper().startswith('SELECT'):
|
| 63 |
return sql_candidate
|
| 64 |
-
|
| 65 |
# 3. 找 SQL 語句(更寬鬆的匹配)
|
| 66 |
match = re.search(r"(SELECT\s+.*?;)", response_text, re.DOTALL | re.IGNORECASE)
|
| 67 |
if match:
|
| 68 |
return match.group(1).strip()
|
| 69 |
-
|
| 70 |
# 4. 找沒有分號的 SQL
|
| 71 |
match = re.search(r"(SELECT\s+.*?)(?=\n\n|\n```|$|\n[^,\s])", response_text, re.DOTALL | re.IGNORECASE)
|
| 72 |
if match:
|
|
@@ -74,7 +74,7 @@ def parse_sql_from_response(response_text: str) -> Optional[str]:
|
|
| 74 |
if not sql.endswith(';'):
|
| 75 |
sql += ';'
|
| 76 |
return sql
|
| 77 |
-
|
| 78 |
# 5. 如果包含 SELECT,嘗試提取整行
|
| 79 |
if 'SELECT' in response_text.upper():
|
| 80 |
lines = response_text.split('\n')
|
|
@@ -84,7 +84,7 @@ def parse_sql_from_response(response_text: str) -> Optional[str]:
|
|
| 84 |
if not line.endswith(';'):
|
| 85 |
line += ';'
|
| 86 |
return line
|
| 87 |
-
|
| 88 |
return None
|
| 89 |
|
| 90 |
# ==================== Text-to-SQL 核心類 ====================
|
|
@@ -113,7 +113,7 @@ class TextToSQLSystem:
|
|
| 113 |
self._log("✅ 系統初始化完成")
|
| 114 |
# 載入數據庫結構
|
| 115 |
self.schema = self._load_schema()
|
| 116 |
-
|
| 117 |
# 暫時添加:打印 schema 信息
|
| 118 |
if self.schema:
|
| 119 |
print("=" * 50)
|
|
@@ -125,7 +125,7 @@ class TextToSQLSystem:
|
|
| 125 |
for col in columns[:5]: # 只顯示前5個
|
| 126 |
print(f" - {col['name']} ({col['type']})")
|
| 127 |
print("=" * 50)
|
| 128 |
-
|
| 129 |
# in class TextToSQLSystem:
|
| 130 |
|
| 131 |
def _load_gguf_model(self):
|
|
@@ -137,7 +137,7 @@ class TextToSQLSystem:
|
|
| 137 |
filename=GGUF_FILENAME,
|
| 138 |
repo_type="dataset"
|
| 139 |
)
|
| 140 |
-
|
| 141 |
# 使用一組更基礎、更穩定的參數來載入模型
|
| 142 |
self.llm = Llama(
|
| 143 |
model_path=model_path,
|
|
@@ -147,16 +147,16 @@ class TextToSQLSystem:
|
|
| 147 |
verbose=False, # 設為 False 避免 llama.cpp 本身的日誌干擾
|
| 148 |
n_gpu_layers=0 # 確認在 CPU 上運行
|
| 149 |
)
|
| 150 |
-
|
| 151 |
# 簡單測試模型是否能回應
|
| 152 |
self.llm("你好", max_tokens=3)
|
| 153 |
self._log("✅ GGUF 模型載入成功")
|
| 154 |
-
|
| 155 |
except Exception as e:
|
| 156 |
self._log(f"❌ GGUF 載入失敗: {e}", "ERROR")
|
| 157 |
self._log("系統將無法生成 SQL。請檢查模型檔案或 llama-cpp-python 安裝。", "CRITICAL")
|
| 158 |
self.llm = None
|
| 159 |
-
|
| 160 |
def _try_gguf_loading(self):
|
| 161 |
"""嘗試載入 GGUF"""
|
| 162 |
try:
|
|
@@ -165,7 +165,7 @@ class TextToSQLSystem:
|
|
| 165 |
filename=GGUF_FILENAME,
|
| 166 |
repo_type="dataset"
|
| 167 |
)
|
| 168 |
-
|
| 169 |
self.llm = Llama(
|
| 170 |
model_path=model_path,
|
| 171 |
n_ctx=512,
|
|
@@ -173,24 +173,24 @@ class TextToSQLSystem:
|
|
| 173 |
verbose=False,
|
| 174 |
n_gpu_layers=0
|
| 175 |
)
|
| 176 |
-
|
| 177 |
# 測試生成
|
| 178 |
test_result = self.llm("SELECT", max_tokens=5)
|
| 179 |
self._log("✅ GGUF 模型載入成功")
|
| 180 |
return True
|
| 181 |
-
|
| 182 |
except Exception as e:
|
| 183 |
self._log(f"GGUF 載入失敗: {e}", "WARNING")
|
| 184 |
return False
|
| 185 |
-
|
| 186 |
def _load_transformers_model(self):
|
| 187 |
"""使用 Transformers 載入你的微調模型"""
|
| 188 |
try:
|
| 189 |
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
| 190 |
import torch
|
| 191 |
-
|
| 192 |
self._log(f"載入 Transformers 模型: {FINETUNED_MODEL_PATH}")
|
| 193 |
-
|
| 194 |
# 載入你的微調模型
|
| 195 |
self.transformers_tokenizer = AutoTokenizer.from_pretrained(FINETUNED_MODEL_PATH)
|
| 196 |
self.transformers_model = AutoModelForCausalLM.from_pretrained(
|
|
@@ -199,7 +199,7 @@ class TextToSQLSystem:
|
|
| 199 |
device_map="cpu", # 強制使用 CPU
|
| 200 |
trust_remote_code=True # Qwen 模型可能需要
|
| 201 |
)
|
| 202 |
-
|
| 203 |
# 創建生成管道
|
| 204 |
self.generation_pipeline = pipeline(
|
| 205 |
"text-generation",
|
|
@@ -212,20 +212,20 @@ class TextToSQLSystem:
|
|
| 212 |
top_p=0.9,
|
| 213 |
pad_token_id=self.transformers_tokenizer.eos_token_id
|
| 214 |
)
|
| 215 |
-
|
| 216 |
self.llm = "transformers" # 標記使用 transformers
|
| 217 |
self._log("✅ Transformers 模型��入成功")
|
| 218 |
-
|
| 219 |
except Exception as e:
|
| 220 |
self._log(f"❌ Transformers 載入也失敗: {e}", "ERROR")
|
| 221 |
self.llm = None
|
| 222 |
-
|
| 223 |
def huggingface_api_call(self, prompt: str) -> str:
|
| 224 |
"""調用 GGUF 模型,並加入詳細的原始輸出日誌"""
|
| 225 |
if self.llm is None:
|
| 226 |
self._log("模型未載入,返回 fallback SQL。", "ERROR")
|
| 227 |
return self._generate_fallback_sql(prompt)
|
| 228 |
-
|
| 229 |
try:
|
| 230 |
output = self.llm(
|
| 231 |
prompt,
|
|
@@ -236,9 +236,9 @@ class TextToSQLSystem:
|
|
| 236 |
# --- 將 stop 參數加回來 ---
|
| 237 |
stop=["```", ";", "\n\n", "</s>"],
|
| 238 |
)
|
| 239 |
-
|
| 240 |
self._log(f"🧠 模型原始輸出 (Raw Output): {output}", "DEBUG")
|
| 241 |
-
|
| 242 |
if output and "choices" in output and len(output["choices"]) > 0:
|
| 243 |
generated_text = output["choices"][0]["text"]
|
| 244 |
self._log(f"📝 提取出的生成文本: {generated_text.strip()}", "DEBUG")
|
|
@@ -246,13 +246,13 @@ class TextToSQLSystem:
|
|
| 246 |
else:
|
| 247 |
self._log("❌ 模型的原始輸出格式不正確或為空。", "ERROR")
|
| 248 |
return ""
|
| 249 |
-
|
| 250 |
except Exception as e:
|
| 251 |
self._log(f"❌ 模型生成過程中發生嚴重錯誤: {e}", "CRITICAL")
|
| 252 |
import traceback
|
| 253 |
self._log(traceback.format_exc(), "DEBUG")
|
| 254 |
return ""
|
| 255 |
-
|
| 256 |
def _load_gguf_model_fallback(self, model_path):
|
| 257 |
"""備用載入方式"""
|
| 258 |
try:
|
|
@@ -286,7 +286,7 @@ class TextToSQLSystem:
|
|
| 286 |
)
|
| 287 |
with open(schema_path, "r", encoding="utf-8") as f:
|
| 288 |
schema_data = json.load(f)
|
| 289 |
-
|
| 290 |
# 添加調試信息
|
| 291 |
self._log(f"📊 Schema 載入成功,包含 {len(schema_data)} 個表格:")
|
| 292 |
for table_name, columns in schema_data.items():
|
|
@@ -294,14 +294,14 @@ class TextToSQLSystem:
|
|
| 294 |
# 顯示前3個欄位作為範例
|
| 295 |
sample_cols = [col['name'] for col in columns[:3]]
|
| 296 |
self._log(f" 範例欄位: {', '.join(sample_cols)}")
|
| 297 |
-
|
| 298 |
self._log("✅ 數據庫結構載入完成")
|
| 299 |
return schema_data
|
| 300 |
-
|
| 301 |
except Exception as e:
|
| 302 |
self._log(f"❌ 載入 schema 失敗: {e}", "ERROR")
|
| 303 |
return {}
|
| 304 |
-
|
| 305 |
# 也可以添加一個方法來檢查生成的 SQL 是否使用了正確的表格和欄位
|
| 306 |
def _analyze_sql_correctness(self, sql: str) -> Dict:
|
| 307 |
"""分析 SQL 的正確性"""
|
|
@@ -312,15 +312,15 @@ class TextToSQLSystem:
|
|
| 312 |
'invalid_columns': [],
|
| 313 |
'suggestions': []
|
| 314 |
}
|
| 315 |
-
|
| 316 |
if not self.schema:
|
| 317 |
return analysis
|
| 318 |
-
|
| 319 |
# 提取 SQL 中的表格名稱
|
| 320 |
table_pattern = r'FROM\s+(\w+)|JOIN\s+(\w+)'
|
| 321 |
table_matches = re.findall(table_pattern, sql, re.IGNORECASE)
|
| 322 |
used_tables = [match[0] or match[1] for match in table_matches]
|
| 323 |
-
|
| 324 |
# 檢查表格是否存在
|
| 325 |
valid_tables = list(self.schema.keys())
|
| 326 |
for table in used_tables:
|
|
@@ -332,26 +332,26 @@ class TextToSQLSystem:
|
|
| 332 |
for valid_table in valid_tables:
|
| 333 |
if table.lower() in valid_table.lower() or valid_table.lower() in table.lower():
|
| 334 |
analysis['suggestions'].append(f"{table} -> {valid_table}")
|
| 335 |
-
|
| 336 |
# 提取欄位名稱(簡單版本)
|
| 337 |
column_pattern = r'SELECT\s+(.*?)\s+FROM|WHERE\s+(\w+)\s*[=<>]|GROUP BY\s+(\w+)|ORDER BY\s+(\w+)'
|
| 338 |
column_matches = re.findall(column_pattern, sql, re.IGNORECASE)
|
| 339 |
-
|
| 340 |
return analysis
|
| 341 |
|
| 342 |
def _encode_texts(self, texts):
|
| 343 |
"""編碼文本為嵌入向量"""
|
| 344 |
if isinstance(texts, str):
|
| 345 |
texts = [texts]
|
| 346 |
-
|
| 347 |
-
inputs = self.embed_tokenizer(texts, padding=True, truncation=True,
|
| 348 |
return_tensors="pt", max_length=512)
|
| 349 |
if DEVICE == "cuda":
|
| 350 |
inputs = {k: v.cuda() for k, v in inputs.items()}
|
| 351 |
-
|
| 352 |
with torch.no_grad():
|
| 353 |
outputs = self.embed_model(**inputs)
|
| 354 |
-
|
| 355 |
# 使用平均池化
|
| 356 |
embeddings = outputs.last_hidden_state.mean(dim=1)
|
| 357 |
return embeddings.cpu()
|
|
@@ -360,28 +360,53 @@ class TextToSQLSystem:
|
|
| 360 |
"""載入數據集並建立 FAISS 索引"""
|
| 361 |
try:
|
| 362 |
dataset = load_dataset(DATASET_REPO_ID, data_files="training_data.jsonl", split="train")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 363 |
corpus = [item['messages'][0]['content'] for item in dataset]
|
| 364 |
self._log(f"正在編碼 {len(corpus)} 個問題...")
|
| 365 |
-
|
| 366 |
# 批量編碼
|
| 367 |
embeddings_list = []
|
| 368 |
batch_size = 32
|
| 369 |
-
|
| 370 |
for i in range(0, len(corpus), batch_size):
|
| 371 |
batch_texts = corpus[i:i+batch_size]
|
| 372 |
batch_embeddings = self._encode_texts(batch_texts)
|
| 373 |
embeddings_list.append(batch_embeddings)
|
| 374 |
self._log(f"已編碼 {min(i+batch_size, len(corpus))}/{len(corpus)}")
|
| 375 |
-
|
| 376 |
all_embeddings = torch.cat(embeddings_list, dim=0).numpy()
|
| 377 |
-
|
| 378 |
# 建立 FAISS 索引
|
| 379 |
index = faiss.IndexFlatIP(all_embeddings.shape[1])
|
| 380 |
index.add(all_embeddings.astype('float32'))
|
| 381 |
-
|
| 382 |
self._log("✅ 向量索引建立完成")
|
| 383 |
return dataset, index
|
| 384 |
-
|
| 385 |
except Exception as e:
|
| 386 |
self._log(f"❌ 載入數據失敗: {e}", "ERROR")
|
| 387 |
return None, None
|
|
@@ -390,7 +415,7 @@ class TextToSQLSystem:
|
|
| 390 |
"""根據實際 Schema 識別相關表格"""
|
| 391 |
question_lower = question.lower()
|
| 392 |
relevant_tables = []
|
| 393 |
-
|
| 394 |
# 根據實際表格的關鍵詞映射
|
| 395 |
keyword_to_table = {
|
| 396 |
'TSR53SampleDescription': ['客戶', '買方', '申請', '發票對象', 'customer', 'invoice', 'sample'],
|
|
@@ -400,18 +425,18 @@ class TextToSQLSystem:
|
|
| 400 |
'JobEventsLog': ['事件', '操作', '用戶', 'event', 'log', 'user'],
|
| 401 |
'calendar_days': ['工作日', '假期', 'workday', 'holiday', 'calendar']
|
| 402 |
}
|
| 403 |
-
|
| 404 |
for table, keywords in keyword_to_table.items():
|
| 405 |
if any(keyword in question_lower for keyword in keywords):
|
| 406 |
relevant_tables.append(table)
|
| 407 |
-
|
| 408 |
# 預設重要表格
|
| 409 |
if not relevant_tables:
|
| 410 |
if any(word in question_lower for word in ['客戶', '買家', '申請', '工作單', '數量']):
|
| 411 |
return ['TSR53SampleDescription', 'JobsInProgress']
|
| 412 |
else:
|
| 413 |
return ['JobTimeline', 'TSR53SampleDescription']
|
| 414 |
-
|
| 415 |
return relevant_tables[:3] # 最多返回3個相關表格
|
| 416 |
|
| 417 |
# 請將這整個函數複製到您的 TextToSQLSystem class 內部
|
|
@@ -422,7 +447,7 @@ class TextToSQLSystem:
|
|
| 422 |
"""
|
| 423 |
if not self.schema:
|
| 424 |
return "No schema available.\n"
|
| 425 |
-
|
| 426 |
actual_table_names_map = {name.lower(): name for name in self.schema.keys()}
|
| 427 |
real_table_names = []
|
| 428 |
for table in table_names:
|
|
@@ -453,7 +478,7 @@ class TextToSQLSystem:
|
|
| 453 |
else:
|
| 454 |
cols_str.append(f"{col_name} ({col_type})")
|
| 455 |
formatted += f"Columns: {', '.join(cols_str)}\n\n"
|
| 456 |
-
|
| 457 |
return formatted.strip()
|
| 458 |
|
| 459 |
|
|
@@ -470,14 +495,14 @@ class TextToSQLSystem:
|
|
| 470 |
返回一個元組 (SQL字符串或None, 狀態消息)。
|
| 471 |
"""
|
| 472 |
q_lower = question.lower()
|
| 473 |
-
|
| 474 |
# ==============================================================================
|
| 475 |
# 第一層:高價值意圖識別與模板覆寫 (Intent Recognition & Templating)
|
| 476 |
# ==============================================================================
|
| 477 |
-
|
| 478 |
# --- 預先檢測所有可能的意圖和實體 ---
|
| 479 |
job_no_match = re.search(r"(?:工單|jobno)\s*'\"?([A-Z]{2,3}\d+)'\"?", question, re.IGNORECASE)
|
| 480 |
-
|
| 481 |
entity_match_data = None
|
| 482 |
ENTITY_TO_COLUMN_MAP = {
|
| 483 |
'申請廠商': 'sd.ApplicantName', '申請方': 'sd.ApplicantName', 'applicant': 'sd.ApplicantName',
|
|
@@ -491,7 +516,7 @@ class TextToSQLSystem:
|
|
| 491 |
if match:
|
| 492 |
entity_match_data = {"type": keyword, "name": match.group(1).strip(), "column": column}
|
| 493 |
break
|
| 494 |
-
|
| 495 |
lab_group_match_data = None
|
| 496 |
LAB_GROUP_MAP = {'A':'TA','B':'TB','C':'TC','D':'TD','E':'TE','Y':'TY','TA':'TA','TB':'TB','TC':'TC','TD':'TD','TE':'TE','TY':'TY','WC':'WC','EO':'EO','GCI':'GCI','GCO':'GCO','MI':'MI'}
|
| 497 |
lab_group_match = re.findall(r"([A-Z]+)\s*組", question, re.IGNORECASE)
|
|
@@ -528,7 +553,7 @@ class TextToSQLSystem:
|
|
| 528 |
self._log(f"🔄 檢測到查詢【{entity_type} '{entity_name}' 在 {year} 年的總業績】意圖,啟用模板。", "INFO")
|
| 529 |
template_sql = f"WITH JobTotalAmount AS (SELECT JobNo, SUM(LocalAmount) AS TotalAmount FROM (SELECT DISTINCT JobNo, InvoiceCreditNoteNo, LocalAmount FROM TSR53Invoice) GROUP BY JobNo) SELECT SUM(jta.TotalAmount) AS total_revenue FROM TSR53SampleDescription AS sd JOIN JobTotalAmount AS jta ON sd.JobNo = jta.JobNo WHERE {column_name} LIKE '%{entity_name}%' AND strftime('%Y', sd.FirstReportAuthorizedDate) = '{year}';"
|
| 530 |
return self._finalize_sql(template_sql, f"模板覆寫: 查詢 {entity_type}='{entity_name}' ({year}年) 的總業績")
|
| 531 |
-
|
| 532 |
if not entity_match_data and any(kw in q_lower for kw in ['業績', '營收', '金額', 'sales', 'revenue']):
|
| 533 |
year_match, month_match = re.search(r'(\d{4})\s*年?', question), re.search(r'(\d{1,2})\s*月', question)
|
| 534 |
time_condition, time_log = "", "總"
|
|
@@ -571,17 +596,17 @@ class TextToSQLSystem:
|
|
| 571 |
# 第二層:常規修正流程 (Fallback Corrections)
|
| 572 |
# ==============================================================================
|
| 573 |
self._log("未觸發任何模板,嘗試解析並修正 AI 輸出...", "INFO")
|
| 574 |
-
|
| 575 |
parsed_sql = parse_sql_from_response(raw_response)
|
| 576 |
if not parsed_sql:
|
| 577 |
self._log(f"❌ 未能從模型回應中解析出任何 SQL。原始回應: {raw_response}", "ERROR")
|
| 578 |
return None, f"無法解析SQL。原始回應:\n{raw_response}"
|
| 579 |
|
| 580 |
self._log(f"📊 解析出的原始 SQL: {parsed_sql}", "DEBUG")
|
| 581 |
-
|
| 582 |
fixed_sql = " " + parsed_sql.strip() + " "
|
| 583 |
fixes_applied_fallback = []
|
| 584 |
-
|
| 585 |
dialect_corrections = {
|
| 586 |
r'YEAR\s*\(([^)]+)\)': r"strftime('%Y', \1)",
|
| 587 |
r"(strftime\('%Y',\s*[^)]+\))\s*=\s*(\d{4})": r"\1 = '\2'",
|
|
@@ -632,47 +657,52 @@ class TextToSQLSystem:
|
|
| 632 |
"""使用 FAISS 快速檢索相似問題"""
|
| 633 |
if self.faiss_index is None or self.dataset is None:
|
| 634 |
return []
|
| 635 |
-
|
| 636 |
try:
|
| 637 |
# 編碼問題
|
| 638 |
q_embedding = self._encode_texts([question]).numpy().astype('float32')
|
| 639 |
-
|
| 640 |
# FAISS 搜索
|
| 641 |
distances, indices = self.faiss_index.search(q_embedding, min(top_k + 2, len(self.dataset)))
|
| 642 |
-
|
| 643 |
results = []
|
| 644 |
seen_questions = set()
|
| 645 |
-
|
| 646 |
for i, idx in enumerate(indices[0]):
|
| 647 |
if len(results) >= top_k:
|
| 648 |
break
|
| 649 |
-
|
| 650 |
# 修復:將 numpy.int64 轉換為 Python int
|
| 651 |
idx = int(idx) # ← 添加這行轉換
|
| 652 |
-
|
| 653 |
if idx >= len(self.dataset): # 確保索引有效
|
| 654 |
continue
|
| 655 |
-
|
| 656 |
item = self.dataset[idx]
|
| 657 |
-
|
| 658 |
-
|
| 659 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 660 |
# 提取純淨問題
|
| 661 |
clean_q = re.sub(r"以下是一個SQL查詢任務:\s*指令:\s*", "", q_content).strip()
|
| 662 |
if clean_q in seen_questions:
|
| 663 |
continue
|
| 664 |
-
|
| 665 |
seen_questions.add(clean_q)
|
| 666 |
sql = parse_sql_from_response(a_content) or "無法解析範例SQL"
|
| 667 |
-
|
| 668 |
results.append({
|
| 669 |
"similarity": float(distances[0][i]),
|
| 670 |
"question": clean_q,
|
| 671 |
"sql": sql
|
| 672 |
})
|
| 673 |
-
|
| 674 |
return results
|
| 675 |
-
|
| 676 |
except Exception as e:
|
| 677 |
self._log(f"❌ 檢索失敗: {e}", "ERROR")
|
| 678 |
return []
|
|
@@ -684,7 +714,7 @@ class TextToSQLSystem:
|
|
| 684 |
建立一個高度結構化、以任務為導向的提示詞,使用清晰的標題分隔符。
|
| 685 |
"""
|
| 686 |
relevant_tables = self._identify_relevant_tables(user_q)
|
| 687 |
-
|
| 688 |
# 使用我們新的、更簡單的 schema 格式化函數
|
| 689 |
schema_str = self._format_relevant_schema(relevant_tables)
|
| 690 |
|
|
@@ -721,7 +751,7 @@ SQL:
|
|
| 721 |
def _generate_fallback_sql(self, prompt: str) -> str:
|
| 722 |
"""當模型不可用時的備用 SQL 生成"""
|
| 723 |
prompt_lower = prompt.lower()
|
| 724 |
-
|
| 725 |
# 簡單的關鍵詞匹配生成基本 SQL
|
| 726 |
if "統計" in prompt or "數量" in prompt or "多少" in prompt:
|
| 727 |
if "月" in prompt:
|
|
@@ -730,13 +760,13 @@ SQL:
|
|
| 730 |
return "SELECT applicant, COUNT(*) as count FROM tsr53sampledescription GROUP BY applicant ORDER BY count DESC;"
|
| 731 |
else:
|
| 732 |
return "SELECT COUNT(*) as total_count FROM jobtimeline WHERE completed_time IS NOT NULL;"
|
| 733 |
-
|
| 734 |
elif "金額" in prompt or "總額" in prompt:
|
| 735 |
return "SELECT SUM(amount) as total_amount FROM tsr53invoice;"
|
| 736 |
-
|
| 737 |
elif "評級" in prompt or "pass" in prompt_lower or "fail" in prompt_lower:
|
| 738 |
return "SELECT rating, COUNT(*) as count FROM tsr53sampledescription GROUP BY rating;"
|
| 739 |
-
|
| 740 |
else:
|
| 741 |
return "SELECT * FROM jobtimeline LIMIT 10;"
|
| 742 |
|
|
@@ -745,22 +775,22 @@ SQL:
|
|
| 745 |
try:
|
| 746 |
if not os.path.exists(model_path):
|
| 747 |
return False
|
| 748 |
-
|
| 749 |
# 檢查檔案大小(至少應該有幾MB)
|
| 750 |
file_size = os.path.getsize(model_path)
|
| 751 |
if file_size < 10 * 1024 * 1024: # 小於 10MB 可能有問題
|
| 752 |
return False
|
| 753 |
-
|
| 754 |
# 檢查 GGUF 檔案頭部
|
| 755 |
with open(model_path, 'rb') as f:
|
| 756 |
header = f.read(8)
|
| 757 |
if not header.startswith(b'GGUF'):
|
| 758 |
return False
|
| 759 |
-
|
| 760 |
return True
|
| 761 |
except Exception:
|
| 762 |
return False
|
| 763 |
-
|
| 764 |
# in class TextToSQLSystem:
|
| 765 |
|
| 766 |
def process_question(self, question: str) -> Tuple[str, str]:
|
|
@@ -769,7 +799,7 @@ SQL:
|
|
| 769 |
if question in self.query_cache:
|
| 770 |
self._log("⚡ 使用緩存結果")
|
| 771 |
return self.query_cache[question]
|
| 772 |
-
|
| 773 |
self.log_history = []
|
| 774 |
self._log(f"⏰ 處理問題: {question}")
|
| 775 |
|
|
@@ -788,12 +818,12 @@ SQL:
|
|
| 788 |
|
| 789 |
# 4. **新的核心步驟**: 呼叫決策引擎來生成最終 SQL
|
| 790 |
final_sql, status_message = self._validate_and_fix_sql(question, response)
|
| 791 |
-
|
| 792 |
if final_sql:
|
| 793 |
result = (final_sql, status_message)
|
| 794 |
else:
|
| 795 |
result = (status_message, "生成失敗")
|
| 796 |
-
|
| 797 |
# 緩存結果
|
| 798 |
self.query_cache[question] = result
|
| 799 |
return result
|
|
@@ -804,10 +834,10 @@ text_to_sql_system = TextToSQLSystem()
|
|
| 804 |
def process_query(q: str):
|
| 805 |
if not q.strip():
|
| 806 |
return "", "等待輸入", "請輸入問題"
|
| 807 |
-
|
| 808 |
sql, status = text_to_sql_system.process_question(q)
|
| 809 |
logs = "\n".join(text_to_sql_system.log_history[-10:]) # 只顯示最後10條日誌
|
| 810 |
-
|
| 811 |
return sql, status, logs
|
| 812 |
|
| 813 |
# 範例問題
|
|
@@ -822,19 +852,19 @@ examples = [
|
|
| 822 |
with gr.Blocks(theme=gr.themes.Soft(), title="Text-to-SQL 智能助手") as demo:
|
| 823 |
gr.Markdown("# ⚡ Text-to-SQL 智能助手")
|
| 824 |
gr.Markdown("輸入自然語言問題,自動生成SQL查詢語句")
|
| 825 |
-
|
| 826 |
with gr.Row():
|
| 827 |
with gr.Column(scale=2):
|
| 828 |
inp = gr.Textbox(lines=3, label="💬 您的問題", placeholder="例如:2024年每月完成多少份報告?")
|
| 829 |
btn = gr.Button("🚀 生成 SQL", variant="primary")
|
| 830 |
status = gr.Textbox(label="狀態", interactive=False)
|
| 831 |
-
|
| 832 |
with gr.Column(scale=3):
|
| 833 |
sql_out = gr.Code(label="🤖 生成的 SQL", language="sql", lines=8)
|
| 834 |
-
|
| 835 |
with gr.Accordion("📋 處理日誌", open=False):
|
| 836 |
logs = gr.Textbox(lines=8, label="日誌", interactive=False)
|
| 837 |
-
|
| 838 |
# 範例區
|
| 839 |
gr.Examples(
|
| 840 |
examples=examples,
|
|
|
|
| 46 |
"""從模型輸出提取 SQL,增強版"""
|
| 47 |
if not response_text:
|
| 48 |
return None
|
| 49 |
+
|
| 50 |
# 清理回應文本
|
| 51 |
response_text = response_text.strip()
|
| 52 |
+
|
| 53 |
# 1. 先找 ```sql ... ```
|
| 54 |
match = re.search(r"```sql\s*\n(.*?)\n```", response_text, re.DOTALL | re.IGNORECASE)
|
| 55 |
if match:
|
| 56 |
return match.group(1).strip()
|
| 57 |
+
|
| 58 |
# 2. 找任何 ``` 包圍的內容
|
| 59 |
match = re.search(r"```\s*\n?(.*?)\n?```", response_text, re.DOTALL)
|
| 60 |
if match:
|
| 61 |
sql_candidate = match.group(1).strip()
|
| 62 |
if sql_candidate.upper().startswith('SELECT'):
|
| 63 |
return sql_candidate
|
| 64 |
+
|
| 65 |
# 3. 找 SQL 語句(更寬鬆的匹配)
|
| 66 |
match = re.search(r"(SELECT\s+.*?;)", response_text, re.DOTALL | re.IGNORECASE)
|
| 67 |
if match:
|
| 68 |
return match.group(1).strip()
|
| 69 |
+
|
| 70 |
# 4. 找沒有分號的 SQL
|
| 71 |
match = re.search(r"(SELECT\s+.*?)(?=\n\n|\n```|$|\n[^,\s])", response_text, re.DOTALL | re.IGNORECASE)
|
| 72 |
if match:
|
|
|
|
| 74 |
if not sql.endswith(';'):
|
| 75 |
sql += ';'
|
| 76 |
return sql
|
| 77 |
+
|
| 78 |
# 5. 如果包含 SELECT,嘗試提取整行
|
| 79 |
if 'SELECT' in response_text.upper():
|
| 80 |
lines = response_text.split('\n')
|
|
|
|
| 84 |
if not line.endswith(';'):
|
| 85 |
line += ';'
|
| 86 |
return line
|
| 87 |
+
|
| 88 |
return None
|
| 89 |
|
| 90 |
# ==================== Text-to-SQL 核心類 ====================
|
|
|
|
| 113 |
self._log("✅ 系統初始化完成")
|
| 114 |
# 載入數據庫結構
|
| 115 |
self.schema = self._load_schema()
|
| 116 |
+
|
| 117 |
# 暫時添加:打印 schema 信息
|
| 118 |
if self.schema:
|
| 119 |
print("=" * 50)
|
|
|
|
| 125 |
for col in columns[:5]: # 只顯示前5個
|
| 126 |
print(f" - {col['name']} ({col['type']})")
|
| 127 |
print("=" * 50)
|
| 128 |
+
|
| 129 |
# in class TextToSQLSystem:
|
| 130 |
|
| 131 |
def _load_gguf_model(self):
|
|
|
|
| 137 |
filename=GGUF_FILENAME,
|
| 138 |
repo_type="dataset"
|
| 139 |
)
|
| 140 |
+
|
| 141 |
# 使用一組更基礎、更穩定的參數來載入模型
|
| 142 |
self.llm = Llama(
|
| 143 |
model_path=model_path,
|
|
|
|
| 147 |
verbose=False, # 設為 False 避免 llama.cpp 本身的日誌干擾
|
| 148 |
n_gpu_layers=0 # 確認在 CPU 上運行
|
| 149 |
)
|
| 150 |
+
|
| 151 |
# 簡單測試模型是否能回應
|
| 152 |
self.llm("你好", max_tokens=3)
|
| 153 |
self._log("✅ GGUF 模型載入成功")
|
| 154 |
+
|
| 155 |
except Exception as e:
|
| 156 |
self._log(f"❌ GGUF 載入失敗: {e}", "ERROR")
|
| 157 |
self._log("系統將無法生成 SQL。請檢查模型檔案或 llama-cpp-python 安裝。", "CRITICAL")
|
| 158 |
self.llm = None
|
| 159 |
+
|
| 160 |
def _try_gguf_loading(self):
|
| 161 |
"""嘗試載入 GGUF"""
|
| 162 |
try:
|
|
|
|
| 165 |
filename=GGUF_FILENAME,
|
| 166 |
repo_type="dataset"
|
| 167 |
)
|
| 168 |
+
|
| 169 |
self.llm = Llama(
|
| 170 |
model_path=model_path,
|
| 171 |
n_ctx=512,
|
|
|
|
| 173 |
verbose=False,
|
| 174 |
n_gpu_layers=0
|
| 175 |
)
|
| 176 |
+
|
| 177 |
# 測試生成
|
| 178 |
test_result = self.llm("SELECT", max_tokens=5)
|
| 179 |
self._log("✅ GGUF 模型載入成功")
|
| 180 |
return True
|
| 181 |
+
|
| 182 |
except Exception as e:
|
| 183 |
self._log(f"GGUF 載入失敗: {e}", "WARNING")
|
| 184 |
return False
|
| 185 |
+
|
| 186 |
def _load_transformers_model(self):
|
| 187 |
"""使用 Transformers 載入你的微調模型"""
|
| 188 |
try:
|
| 189 |
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
| 190 |
import torch
|
| 191 |
+
|
| 192 |
self._log(f"載入 Transformers 模型: {FINETUNED_MODEL_PATH}")
|
| 193 |
+
|
| 194 |
# 載入你的微調模型
|
| 195 |
self.transformers_tokenizer = AutoTokenizer.from_pretrained(FINETUNED_MODEL_PATH)
|
| 196 |
self.transformers_model = AutoModelForCausalLM.from_pretrained(
|
|
|
|
| 199 |
device_map="cpu", # 強制使用 CPU
|
| 200 |
trust_remote_code=True # Qwen 模型可能需要
|
| 201 |
)
|
| 202 |
+
|
| 203 |
# 創建生成管道
|
| 204 |
self.generation_pipeline = pipeline(
|
| 205 |
"text-generation",
|
|
|
|
| 212 |
top_p=0.9,
|
| 213 |
pad_token_id=self.transformers_tokenizer.eos_token_id
|
| 214 |
)
|
| 215 |
+
|
| 216 |
self.llm = "transformers" # 標記使用 transformers
|
| 217 |
self._log("✅ Transformers 模型��入成功")
|
| 218 |
+
|
| 219 |
except Exception as e:
|
| 220 |
self._log(f"❌ Transformers 載入也失敗: {e}", "ERROR")
|
| 221 |
self.llm = None
|
| 222 |
+
|
| 223 |
def huggingface_api_call(self, prompt: str) -> str:
|
| 224 |
"""調用 GGUF 模型,並加入詳細的原始輸出日誌"""
|
| 225 |
if self.llm is None:
|
| 226 |
self._log("模型未載入,返回 fallback SQL。", "ERROR")
|
| 227 |
return self._generate_fallback_sql(prompt)
|
| 228 |
+
|
| 229 |
try:
|
| 230 |
output = self.llm(
|
| 231 |
prompt,
|
|
|
|
| 236 |
# --- 將 stop 參數加回來 ---
|
| 237 |
stop=["```", ";", "\n\n", "</s>"],
|
| 238 |
)
|
| 239 |
+
|
| 240 |
self._log(f"🧠 模型原始輸出 (Raw Output): {output}", "DEBUG")
|
| 241 |
+
|
| 242 |
if output and "choices" in output and len(output["choices"]) > 0:
|
| 243 |
generated_text = output["choices"][0]["text"]
|
| 244 |
self._log(f"📝 提取出的生成文本: {generated_text.strip()}", "DEBUG")
|
|
|
|
| 246 |
else:
|
| 247 |
self._log("❌ 模型的原始輸出格式不正確或為空。", "ERROR")
|
| 248 |
return ""
|
| 249 |
+
|
| 250 |
except Exception as e:
|
| 251 |
self._log(f"❌ 模型生成過程中發生嚴重錯誤: {e}", "CRITICAL")
|
| 252 |
import traceback
|
| 253 |
self._log(traceback.format_exc(), "DEBUG")
|
| 254 |
return ""
|
| 255 |
+
|
| 256 |
def _load_gguf_model_fallback(self, model_path):
|
| 257 |
"""備用載入方式"""
|
| 258 |
try:
|
|
|
|
| 286 |
)
|
| 287 |
with open(schema_path, "r", encoding="utf-8") as f:
|
| 288 |
schema_data = json.load(f)
|
| 289 |
+
|
| 290 |
# 添加調試信息
|
| 291 |
self._log(f"📊 Schema 載入成功,包含 {len(schema_data)} 個表格:")
|
| 292 |
for table_name, columns in schema_data.items():
|
|
|
|
| 294 |
# 顯示前3個欄位作為範例
|
| 295 |
sample_cols = [col['name'] for col in columns[:3]]
|
| 296 |
self._log(f" 範例欄位: {', '.join(sample_cols)}")
|
| 297 |
+
|
| 298 |
self._log("✅ 數據庫結構載入完成")
|
| 299 |
return schema_data
|
| 300 |
+
|
| 301 |
except Exception as e:
|
| 302 |
self._log(f"❌ 載入 schema 失敗: {e}", "ERROR")
|
| 303 |
return {}
|
| 304 |
+
|
| 305 |
# 也可以添加一個方法來檢查生成的 SQL 是否使用了正確的表格和欄位
|
| 306 |
def _analyze_sql_correctness(self, sql: str) -> Dict:
|
| 307 |
"""分析 SQL 的正確性"""
|
|
|
|
| 312 |
'invalid_columns': [],
|
| 313 |
'suggestions': []
|
| 314 |
}
|
| 315 |
+
|
| 316 |
if not self.schema:
|
| 317 |
return analysis
|
| 318 |
+
|
| 319 |
# 提取 SQL 中的表格名稱
|
| 320 |
table_pattern = r'FROM\s+(\w+)|JOIN\s+(\w+)'
|
| 321 |
table_matches = re.findall(table_pattern, sql, re.IGNORECASE)
|
| 322 |
used_tables = [match[0] or match[1] for match in table_matches]
|
| 323 |
+
|
| 324 |
# 檢查表格是否存在
|
| 325 |
valid_tables = list(self.schema.keys())
|
| 326 |
for table in used_tables:
|
|
|
|
| 332 |
for valid_table in valid_tables:
|
| 333 |
if table.lower() in valid_table.lower() or valid_table.lower() in table.lower():
|
| 334 |
analysis['suggestions'].append(f"{table} -> {valid_table}")
|
| 335 |
+
|
| 336 |
# 提取欄位名稱(簡單版本)
|
| 337 |
column_pattern = r'SELECT\s+(.*?)\s+FROM|WHERE\s+(\w+)\s*[=<>]|GROUP BY\s+(\w+)|ORDER BY\s+(\w+)'
|
| 338 |
column_matches = re.findall(column_pattern, sql, re.IGNORECASE)
|
| 339 |
+
|
| 340 |
return analysis
|
| 341 |
|
| 342 |
def _encode_texts(self, texts):
|
| 343 |
"""編碼文本為嵌入向量"""
|
| 344 |
if isinstance(texts, str):
|
| 345 |
texts = [texts]
|
| 346 |
+
|
| 347 |
+
inputs = self.embed_tokenizer(texts, padding=True, truncation=True,
|
| 348 |
return_tensors="pt", max_length=512)
|
| 349 |
if DEVICE == "cuda":
|
| 350 |
inputs = {k: v.cuda() for k, v in inputs.items()}
|
| 351 |
+
|
| 352 |
with torch.no_grad():
|
| 353 |
outputs = self.embed_model(**inputs)
|
| 354 |
+
|
| 355 |
# 使用平均池化
|
| 356 |
embeddings = outputs.last_hidden_state.mean(dim=1)
|
| 357 |
return embeddings.cpu()
|
|
|
|
| 360 |
"""載入數據集並建立 FAISS 索引"""
|
| 361 |
try:
|
| 362 |
dataset = load_dataset(DATASET_REPO_ID, data_files="training_data.jsonl", split="train")
|
| 363 |
+
|
| 364 |
+
# 先過濾不完整樣本,避免 messages 長度不足導致索引或檢索報錯
|
| 365 |
+
try:
|
| 366 |
+
original_count = len(dataset)
|
| 367 |
+
except Exception:
|
| 368 |
+
original_count = None
|
| 369 |
+
|
| 370 |
+
dataset = dataset.filter(
|
| 371 |
+
lambda ex: isinstance(ex.get("messages"), list)
|
| 372 |
+
and len(ex["messages"]) >= 2
|
| 373 |
+
and all(
|
| 374 |
+
isinstance(m.get("content"), str) and m.get("content") and m["content"].strip()
|
| 375 |
+
for m in ex["messages"][:2]
|
| 376 |
+
)
|
| 377 |
+
)
|
| 378 |
+
|
| 379 |
+
if original_count is not None:
|
| 380 |
+
self._log(
|
| 381 |
+
f"資料集清理: 原始 {original_count} 筆, 過濾後 {len(dataset)} 筆, 移除 {original_count - len(dataset)} 筆"
|
| 382 |
+
)
|
| 383 |
+
|
| 384 |
+
if len(dataset) == 0:
|
| 385 |
+
self._log("清理後資料集為空,無法建立索引。", "ERROR")
|
| 386 |
+
return None, None
|
| 387 |
+
|
| 388 |
corpus = [item['messages'][0]['content'] for item in dataset]
|
| 389 |
self._log(f"正在編碼 {len(corpus)} 個問題...")
|
| 390 |
+
|
| 391 |
# 批量編碼
|
| 392 |
embeddings_list = []
|
| 393 |
batch_size = 32
|
| 394 |
+
|
| 395 |
for i in range(0, len(corpus), batch_size):
|
| 396 |
batch_texts = corpus[i:i+batch_size]
|
| 397 |
batch_embeddings = self._encode_texts(batch_texts)
|
| 398 |
embeddings_list.append(batch_embeddings)
|
| 399 |
self._log(f"已編碼 {min(i+batch_size, len(corpus))}/{len(corpus)}")
|
| 400 |
+
|
| 401 |
all_embeddings = torch.cat(embeddings_list, dim=0).numpy()
|
| 402 |
+
|
| 403 |
# 建立 FAISS 索引
|
| 404 |
index = faiss.IndexFlatIP(all_embeddings.shape[1])
|
| 405 |
index.add(all_embeddings.astype('float32'))
|
| 406 |
+
|
| 407 |
self._log("✅ 向量索引建立完成")
|
| 408 |
return dataset, index
|
| 409 |
+
|
| 410 |
except Exception as e:
|
| 411 |
self._log(f"❌ 載入數據失敗: {e}", "ERROR")
|
| 412 |
return None, None
|
|
|
|
| 415 |
"""根據實際 Schema 識別相關表格"""
|
| 416 |
question_lower = question.lower()
|
| 417 |
relevant_tables = []
|
| 418 |
+
|
| 419 |
# 根據實際表格的關鍵詞映射
|
| 420 |
keyword_to_table = {
|
| 421 |
'TSR53SampleDescription': ['客戶', '買方', '申請', '發票對象', 'customer', 'invoice', 'sample'],
|
|
|
|
| 425 |
'JobEventsLog': ['事件', '操作', '用戶', 'event', 'log', 'user'],
|
| 426 |
'calendar_days': ['工作日', '假期', 'workday', 'holiday', 'calendar']
|
| 427 |
}
|
| 428 |
+
|
| 429 |
for table, keywords in keyword_to_table.items():
|
| 430 |
if any(keyword in question_lower for keyword in keywords):
|
| 431 |
relevant_tables.append(table)
|
| 432 |
+
|
| 433 |
# 預設重要表格
|
| 434 |
if not relevant_tables:
|
| 435 |
if any(word in question_lower for word in ['客戶', '買家', '申請', '工作單', '數量']):
|
| 436 |
return ['TSR53SampleDescription', 'JobsInProgress']
|
| 437 |
else:
|
| 438 |
return ['JobTimeline', 'TSR53SampleDescription']
|
| 439 |
+
|
| 440 |
return relevant_tables[:3] # 最多返回3個相關表格
|
| 441 |
|
| 442 |
# 請將這整個函數複製到您的 TextToSQLSystem class 內部
|
|
|
|
| 447 |
"""
|
| 448 |
if not self.schema:
|
| 449 |
return "No schema available.\n"
|
| 450 |
+
|
| 451 |
actual_table_names_map = {name.lower(): name for name in self.schema.keys()}
|
| 452 |
real_table_names = []
|
| 453 |
for table in table_names:
|
|
|
|
| 478 |
else:
|
| 479 |
cols_str.append(f"{col_name} ({col_type})")
|
| 480 |
formatted += f"Columns: {', '.join(cols_str)}\n\n"
|
| 481 |
+
|
| 482 |
return formatted.strip()
|
| 483 |
|
| 484 |
|
|
|
|
| 495 |
返回一個元組 (SQL字符串或None, 狀態消息)。
|
| 496 |
"""
|
| 497 |
q_lower = question.lower()
|
| 498 |
+
|
| 499 |
# ==============================================================================
|
| 500 |
# 第一層:高價值意圖識別與模板覆寫 (Intent Recognition & Templating)
|
| 501 |
# ==============================================================================
|
| 502 |
+
|
| 503 |
# --- 預先檢測所有可能的意圖和實體 ---
|
| 504 |
job_no_match = re.search(r"(?:工單|jobno)\s*'\"?([A-Z]{2,3}\d+)'\"?", question, re.IGNORECASE)
|
| 505 |
+
|
| 506 |
entity_match_data = None
|
| 507 |
ENTITY_TO_COLUMN_MAP = {
|
| 508 |
'申請廠商': 'sd.ApplicantName', '申請方': 'sd.ApplicantName', 'applicant': 'sd.ApplicantName',
|
|
|
|
| 516 |
if match:
|
| 517 |
entity_match_data = {"type": keyword, "name": match.group(1).strip(), "column": column}
|
| 518 |
break
|
| 519 |
+
|
| 520 |
lab_group_match_data = None
|
| 521 |
LAB_GROUP_MAP = {'A':'TA','B':'TB','C':'TC','D':'TD','E':'TE','Y':'TY','TA':'TA','TB':'TB','TC':'TC','TD':'TD','TE':'TE','TY':'TY','WC':'WC','EO':'EO','GCI':'GCI','GCO':'GCO','MI':'MI'}
|
| 522 |
lab_group_match = re.findall(r"([A-Z]+)\s*組", question, re.IGNORECASE)
|
|
|
|
| 553 |
self._log(f"🔄 檢測到查詢【{entity_type} '{entity_name}' 在 {year} 年的總業績】意圖,啟用模板。", "INFO")
|
| 554 |
template_sql = f"WITH JobTotalAmount AS (SELECT JobNo, SUM(LocalAmount) AS TotalAmount FROM (SELECT DISTINCT JobNo, InvoiceCreditNoteNo, LocalAmount FROM TSR53Invoice) GROUP BY JobNo) SELECT SUM(jta.TotalAmount) AS total_revenue FROM TSR53SampleDescription AS sd JOIN JobTotalAmount AS jta ON sd.JobNo = jta.JobNo WHERE {column_name} LIKE '%{entity_name}%' AND strftime('%Y', sd.FirstReportAuthorizedDate) = '{year}';"
|
| 555 |
return self._finalize_sql(template_sql, f"模板覆寫: 查詢 {entity_type}='{entity_name}' ({year}年) 的總業績")
|
| 556 |
+
|
| 557 |
if not entity_match_data and any(kw in q_lower for kw in ['業績', '營收', '金額', 'sales', 'revenue']):
|
| 558 |
year_match, month_match = re.search(r'(\d{4})\s*年?', question), re.search(r'(\d{1,2})\s*月', question)
|
| 559 |
time_condition, time_log = "", "總"
|
|
|
|
| 596 |
# 第二層:常規修正流程 (Fallback Corrections)
|
| 597 |
# ==============================================================================
|
| 598 |
self._log("未觸發任何模板,嘗試解析並修正 AI 輸出...", "INFO")
|
| 599 |
+
|
| 600 |
parsed_sql = parse_sql_from_response(raw_response)
|
| 601 |
if not parsed_sql:
|
| 602 |
self._log(f"❌ 未能從模型回應中解析出任何 SQL。原始回應: {raw_response}", "ERROR")
|
| 603 |
return None, f"無法解析SQL。原始回應:\n{raw_response}"
|
| 604 |
|
| 605 |
self._log(f"📊 解析出的原始 SQL: {parsed_sql}", "DEBUG")
|
| 606 |
+
|
| 607 |
fixed_sql = " " + parsed_sql.strip() + " "
|
| 608 |
fixes_applied_fallback = []
|
| 609 |
+
|
| 610 |
dialect_corrections = {
|
| 611 |
r'YEAR\s*\(([^)]+)\)': r"strftime('%Y', \1)",
|
| 612 |
r"(strftime\('%Y',\s*[^)]+\))\s*=\s*(\d{4})": r"\1 = '\2'",
|
|
|
|
| 657 |
"""使用 FAISS 快速檢索相似問題"""
|
| 658 |
if self.faiss_index is None or self.dataset is None:
|
| 659 |
return []
|
| 660 |
+
|
| 661 |
try:
|
| 662 |
# 編碼問題
|
| 663 |
q_embedding = self._encode_texts([question]).numpy().astype('float32')
|
| 664 |
+
|
| 665 |
# FAISS 搜索
|
| 666 |
distances, indices = self.faiss_index.search(q_embedding, min(top_k + 2, len(self.dataset)))
|
| 667 |
+
|
| 668 |
results = []
|
| 669 |
seen_questions = set()
|
| 670 |
+
|
| 671 |
for i, idx in enumerate(indices[0]):
|
| 672 |
if len(results) >= top_k:
|
| 673 |
break
|
| 674 |
+
|
| 675 |
# 修復:將 numpy.int64 轉換為 Python int
|
| 676 |
idx = int(idx) # ← 添加這行轉換
|
| 677 |
+
|
| 678 |
if idx >= len(self.dataset): # 確保索引有效
|
| 679 |
continue
|
| 680 |
+
|
| 681 |
item = self.dataset[idx]
|
| 682 |
+
# 防呆:若樣本不完整則跳過
|
| 683 |
+
if not isinstance(item.get('messages'), list) or len(item['messages']) < 2:
|
| 684 |
+
continue
|
| 685 |
+
q_content = (item['messages'][0].get('content') or '').strip()
|
| 686 |
+
a_content = (item['messages'][1].get('content') or '').strip()
|
| 687 |
+
if not q_content or not a_content:
|
| 688 |
+
continue
|
| 689 |
+
|
| 690 |
# 提取純淨問題
|
| 691 |
clean_q = re.sub(r"以下是一個SQL查詢任務:\s*指令:\s*", "", q_content).strip()
|
| 692 |
if clean_q in seen_questions:
|
| 693 |
continue
|
| 694 |
+
|
| 695 |
seen_questions.add(clean_q)
|
| 696 |
sql = parse_sql_from_response(a_content) or "無法解析範例SQL"
|
| 697 |
+
|
| 698 |
results.append({
|
| 699 |
"similarity": float(distances[0][i]),
|
| 700 |
"question": clean_q,
|
| 701 |
"sql": sql
|
| 702 |
})
|
| 703 |
+
|
| 704 |
return results
|
| 705 |
+
|
| 706 |
except Exception as e:
|
| 707 |
self._log(f"❌ 檢索失敗: {e}", "ERROR")
|
| 708 |
return []
|
|
|
|
| 714 |
建立一個高度結構化、以任務為導向的提示詞,使用清晰的標題分隔符。
|
| 715 |
"""
|
| 716 |
relevant_tables = self._identify_relevant_tables(user_q)
|
| 717 |
+
|
| 718 |
# 使用我們新的、更簡單的 schema 格式化函數
|
| 719 |
schema_str = self._format_relevant_schema(relevant_tables)
|
| 720 |
|
|
|
|
| 751 |
def _generate_fallback_sql(self, prompt: str) -> str:
|
| 752 |
"""當模型不可用時的備用 SQL 生成"""
|
| 753 |
prompt_lower = prompt.lower()
|
| 754 |
+
|
| 755 |
# 簡單的關鍵詞匹配生成基本 SQL
|
| 756 |
if "統計" in prompt or "數量" in prompt or "多少" in prompt:
|
| 757 |
if "月" in prompt:
|
|
|
|
| 760 |
return "SELECT applicant, COUNT(*) as count FROM tsr53sampledescription GROUP BY applicant ORDER BY count DESC;"
|
| 761 |
else:
|
| 762 |
return "SELECT COUNT(*) as total_count FROM jobtimeline WHERE completed_time IS NOT NULL;"
|
| 763 |
+
|
| 764 |
elif "金額" in prompt or "總額" in prompt:
|
| 765 |
return "SELECT SUM(amount) as total_amount FROM tsr53invoice;"
|
| 766 |
+
|
| 767 |
elif "評級" in prompt or "pass" in prompt_lower or "fail" in prompt_lower:
|
| 768 |
return "SELECT rating, COUNT(*) as count FROM tsr53sampledescription GROUP BY rating;"
|
| 769 |
+
|
| 770 |
else:
|
| 771 |
return "SELECT * FROM jobtimeline LIMIT 10;"
|
| 772 |
|
|
|
|
| 775 |
try:
|
| 776 |
if not os.path.exists(model_path):
|
| 777 |
return False
|
| 778 |
+
|
| 779 |
# 檢查檔案大小(至少應該有幾MB)
|
| 780 |
file_size = os.path.getsize(model_path)
|
| 781 |
if file_size < 10 * 1024 * 1024: # 小於 10MB 可能有問題
|
| 782 |
return False
|
| 783 |
+
|
| 784 |
# 檢查 GGUF 檔案頭部
|
| 785 |
with open(model_path, 'rb') as f:
|
| 786 |
header = f.read(8)
|
| 787 |
if not header.startswith(b'GGUF'):
|
| 788 |
return False
|
| 789 |
+
|
| 790 |
return True
|
| 791 |
except Exception:
|
| 792 |
return False
|
| 793 |
+
|
| 794 |
# in class TextToSQLSystem:
|
| 795 |
|
| 796 |
def process_question(self, question: str) -> Tuple[str, str]:
|
|
|
|
| 799 |
if question in self.query_cache:
|
| 800 |
self._log("⚡ 使用緩存結果")
|
| 801 |
return self.query_cache[question]
|
| 802 |
+
|
| 803 |
self.log_history = []
|
| 804 |
self._log(f"⏰ 處理問題: {question}")
|
| 805 |
|
|
|
|
| 818 |
|
| 819 |
# 4. **新的核心步驟**: 呼叫決策引擎來生成最終 SQL
|
| 820 |
final_sql, status_message = self._validate_and_fix_sql(question, response)
|
| 821 |
+
|
| 822 |
if final_sql:
|
| 823 |
result = (final_sql, status_message)
|
| 824 |
else:
|
| 825 |
result = (status_message, "生成失敗")
|
| 826 |
+
|
| 827 |
# 緩存結果
|
| 828 |
self.query_cache[question] = result
|
| 829 |
return result
|
|
|
|
| 834 |
def process_query(q: str):
    """UI callback: turn a user question into (sql, status, recent logs).

    Returns a 3-tuple consumed by the Gradio outputs:
    generated SQL text, a status message, and the tail of the log history.
    """
    # Guard clause: blank or whitespace-only input gets placeholder messages.
    if not q.strip():
        return "", "等待輸入", "請輸入問題"

    generated_sql, status_msg = text_to_sql_system.process_question(q)
    # Keep the log panel readable by showing only the 10 most recent entries.
    recent_logs = "\n".join(text_to_sql_system.log_history[-10:])
    return generated_sql, status_msg, recent_logs
|
| 842 |
|
| 843 |
# 範例問題
|
|
|
|
| 852 |
with gr.Blocks(theme=gr.themes.Soft(), title="Text-to-SQL 智能助手") as demo:
|
| 853 |
gr.Markdown("# ⚡ Text-to-SQL 智能助手")
|
| 854 |
gr.Markdown("輸入自然語言問題,自動生成SQL查詢語句")
|
| 855 |
+
|
| 856 |
with gr.Row():
|
| 857 |
with gr.Column(scale=2):
|
| 858 |
inp = gr.Textbox(lines=3, label="💬 您的問題", placeholder="例如:2024年每月完成多少份報告?")
|
| 859 |
btn = gr.Button("🚀 生成 SQL", variant="primary")
|
| 860 |
status = gr.Textbox(label="狀態", interactive=False)
|
| 861 |
+
|
| 862 |
with gr.Column(scale=3):
|
| 863 |
sql_out = gr.Code(label="🤖 生成的 SQL", language="sql", lines=8)
|
| 864 |
+
|
| 865 |
with gr.Accordion("📋 處理日誌", open=False):
|
| 866 |
logs = gr.Textbox(lines=8, label="日誌", interactive=False)
|
| 867 |
+
|
| 868 |
# 範例區
|
| 869 |
gr.Examples(
|
| 870 |
examples=examples,
|