Paul720810 commited on
Commit
2251faa
·
verified ·
1 Parent(s): 230e2a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +164 -649
app.py CHANGED
@@ -1,39 +1,44 @@
 
 
 
 
1
  import gradio as gr
2
  import os
3
  import re
4
  import json
5
  import torch
6
  import numpy as np
 
 
7
  from datetime import datetime
8
  from datasets import load_dataset
9
  from huggingface_hub import hf_hub_download
10
  from llama_cpp import Llama
11
  from typing import List, Dict, Tuple, Optional
12
  import faiss
13
- from functools import lru_cache
14
 
15
- # 使用 transformers 替代 sentence-transformers
16
  from transformers import AutoModel, AutoTokenizer
17
  import torch.nn.functional as F
18
 
19
- # ==================== 配置區 ====================
20
- DATASET_REPO_ID = "Paul720810/Text-to-SQL-Softline"
21
- GGUF_REPO_ID = "Paul720810/gguf-models"
22
- #GGUF_FILENAME = "qwen2.5-coder-1.5b-sql-finetuned.q4_k_m.gguf"
23
  GGUF_FILENAME = "qwen2.5-coder-1.5b-sql-finetuned.q8_0.gguf"
 
24
 
25
- # 添加這一行:你的原始微調模型路徑
26
- FINETUNED_MODEL_PATH = "Paul720810/qwen2.5-coder-1.5b-sql-finetuned" # ← 新增這行
27
-
28
  FEW_SHOT_EXAMPLES_COUNT = 1
29
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
30
  EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
31
 
 
 
 
32
  print("=" * 60)
33
- print("🤖 Text-to-SQL 系統啟動中...")
34
- print(f"📊 數據集: {DATASET_REPO_ID}")
35
- print(f"🤖 嵌入模型: {EMBED_MODEL_NAME}")
36
- print(f"💻 設備: {DEVICE}")
37
  print("=" * 60)
38
 
39
  # ==================== 工具函數 ====================
@@ -41,51 +46,32 @@ def get_current_time():
41
  return datetime.now().strftime('%Y-%m-%d %H:%M:%S')
42
 
43
  def format_log(message: str, level: str = "INFO") -> str:
44
- return f"[{get_current_time()}] [{level.upper()}] {message}"
 
 
45
 
46
  def parse_sql_from_response(response_text: str) -> Optional[str]:
47
- """從模型輸出提取 SQL,增強版"""
48
- if not response_text:
49
- return None
50
-
51
- # 清理回應文本
52
  response_text = response_text.strip()
53
-
54
- # 1. 先找 ```sql ... ```
55
  match = re.search(r"```sql\s*\n(.*?)\n```", response_text, re.DOTALL | re.IGNORECASE)
56
- if match:
57
- return match.group(1).strip()
58
-
59
- # 2. 找任何 ``` 包圍的內容
60
  match = re.search(r"```\s*\n?(.*?)\n?```", response_text, re.DOTALL)
61
  if match:
62
  sql_candidate = match.group(1).strip()
63
- if sql_candidate.upper().startswith('SELECT'):
64
- return sql_candidate
65
-
66
- # 3. 找 SQL 語句(更寬鬆的匹配)
67
  match = re.search(r"(SELECT\s+.*?;)", response_text, re.DOTALL | re.IGNORECASE)
68
- if match:
69
- return match.group(1).strip()
70
-
71
- # 4. 找沒有分號的 SQL
72
  match = re.search(r"(SELECT\s+.*?)(?=\n\n|\n```|$|\n[^,\s])", response_text, re.DOTALL | re.IGNORECASE)
73
  if match:
74
  sql = match.group(1).strip()
75
- if not sql.endswith(';'):
76
- sql += ';'
77
  return sql
78
-
79
- # 5. 如果包含 SELECT,嘗試提取整行
80
  if 'SELECT' in response_text.upper():
81
- lines = response_text.split('\n')
82
- for line in lines:
83
  line = line.strip()
84
  if line.upper().startswith('SELECT'):
85
- if not line.endswith(';'):
86
- line += ';'
87
  return line
88
-
89
  return None
90
 
91
  # ==================== Text-to-SQL 核心類 ====================
@@ -94,445 +80,179 @@ class TextToSQLSystem:
94
  self.log_history = []
95
  self._log("初始化系統...")
96
  self.query_cache = {}
97
-
98
- # 1. 載入嵌入模型
99
- self._log(f"載入嵌入模型: {embed_model_name}")
100
- self.embed_tokenizer = AutoTokenizer.from_pretrained(embed_model_name)
101
- self.embed_model = AutoModel.from_pretrained(embed_model_name)
102
- if DEVICE == "cuda":
103
- self.embed_model = self.embed_model.cuda()
104
-
105
- # 2. 載入數據庫結構
106
- self.schema = self._load_schema()
107
-
108
- # 3. 載入數據集並建立索引
109
- self.dataset, self.faiss_index = self._load_and_index_dataset()
110
-
111
- # 4. 載入 GGUF 模型(添加錯誤處理)
112
- self._load_gguf_model()
113
-
114
- self._log("✅ 系統初始化完成")
115
- # 載入數據庫結構
116
- self.schema = self._load_schema()
117
-
118
- # 暫時添加:打印 schema 信息
119
- if self.schema:
120
- print("=" * 50)
121
- print("數據庫 Schema 信息:")
122
- for table_name, columns in self.schema.items():
123
- print(f"\n表格: {table_name}")
124
- print(f"欄位數: {len(columns)}")
125
- print("欄位列表:")
126
- for col in columns[:5]: # 只顯示前5個
127
- print(f" - {col['name']} ({col['type']})")
128
- print("=" * 50)
129
-
130
- # in class TextToSQLSystem:
131
-
132
- def _load_gguf_model(self):
133
- """載入 GGUF 模型,使用更穩定、簡潔的參數"""
134
  try:
135
- self._log("載入 GGUF 模型 (使用穩定性參數)...")
136
- model_path = hf_hub_download(
137
- repo_id=GGUF_REPO_ID,
138
- filename=GGUF_FILENAME,
139
- repo_type="dataset"
140
- )
141
-
142
- # 使用一組更基礎、更穩定的參數來載入模型
143
- self.llm = Llama(
144
- model_path=model_path,
145
- n_ctx=2048, # 將上下文增加到 2048 以確保 Prompt 不會超長
146
- n_threads=4, # 保持 4 線程
147
- n_batch=512, # 建議值
148
- verbose=False, # 設為 False 避免 llama.cpp 本身的日誌干擾
149
- n_gpu_layers=0 # 確認在 CPU 上運行
150
- )
151
-
152
- # 簡單測試模型是否能回應
153
- self.llm("你好", max_tokens=3)
154
- self._log("✅ GGUF 模型載入成功")
155
-
156
  except Exception as e:
157
- self._log(f"❌ GGUF 載入失敗: {e}", "ERROR")
158
- self._log("系統將無法生成 SQL。請檢查模型檔案或 llama-cpp-python 安裝。", "CRITICAL")
159
  self.llm = None
160
 
161
- def _try_gguf_loading(self):
162
- """嘗試載入 GGUF"""
163
- try:
164
- model_path = hf_hub_download(
165
- repo_id=GGUF_REPO_ID,
166
- filename=GGUF_FILENAME,
167
- repo_type="dataset"
168
- )
169
-
170
- self.llm = Llama(
171
- model_path=model_path,
172
- n_ctx=512,
173
- n_threads=4,
174
- verbose=False,
175
- n_gpu_layers=0
176
- )
177
-
178
- # 測試生成
179
- test_result = self.llm("SELECT", max_tokens=5)
180
- self._log("✅ GGUF 模型載入成功")
181
- return True
182
-
183
- except Exception as e:
184
- self._log(f"GGUF 載入失敗: {e}", "WARNING")
185
- return False
186
 
187
- def _load_transformers_model(self):
188
- """使用 Transformers 載入你的微調模型"""
189
  try:
190
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
191
- import torch
192
-
193
- self._log(f"載入 Transformers 模型: {FINETUNED_MODEL_PATH}")
194
-
195
- # 載入你的微調模型
196
- self.transformers_tokenizer = AutoTokenizer.from_pretrained(FINETUNED_MODEL_PATH)
197
- self.transformers_model = AutoModelForCausalLM.from_pretrained(
198
- FINETUNED_MODEL_PATH,
199
- torch_dtype=torch.float32, # CPU 使用 float32
200
- device_map="cpu", # 強制使用 CPU
201
- trust_remote_code=True # Qwen 模型可能需要
202
- )
203
-
204
- # 創建生成管道
205
- self.generation_pipeline = pipeline(
206
- "text-generation",
207
- model=self.transformers_model,
208
- tokenizer=self.transformers_tokenizer,
209
- device=-1, # CPU
210
- max_length=512,
211
- do_sample=True,
212
- temperature=0.1,
213
- top_p=0.9,
214
- pad_token_id=self.transformers_tokenizer.eos_token_id
215
- )
216
-
217
- self.llm = "transformers" # 標記使用 transformers
218
- self._log("✅ Transformers 模型載入成功")
219
-
220
  except Exception as e:
221
- self._log(f"❌ Transformers 載入也失敗: {e}", "ERROR")
222
  self.llm = None
223
 
224
  def huggingface_api_call(self, prompt: str) -> str:
225
- """調用 GGUF 模型,並加入詳細的原始輸出日誌"""
226
- if self.llm is None:
227
- self._log("模型未載入,返回 fallback SQL。", "ERROR")
228
- return self._generate_fallback_sql(prompt)
229
-
230
  try:
231
- output = self.llm(
232
- prompt,
233
- max_tokens=150,
234
- temperature=0.1,
235
- top_p=0.9,
236
- echo=False,
237
- # --- 將 stop 參數加回來 ---
238
- stop=["```", ";", "\n\n", "</s>"],
239
- )
240
-
241
- self._log(f"🧠 模型原始輸出 (Raw Output): {output}", "DEBUG")
242
-
243
- if output and "choices" in output and len(output["choices"]) > 0:
244
- generated_text = output["choices"][0]["text"]
245
- self._log(f"📝 提取出的生成文本: {generated_text.strip()}", "DEBUG")
246
- return generated_text.strip()
247
- else:
248
- self._log("❌ 模型的原始輸出格式不正確或為空。", "ERROR")
249
- return ""
250
-
251
  except Exception as e:
252
- self._log(f"❌ 模型生成過程中發生嚴重錯誤: {e}", "CRITICAL")
253
- import traceback
254
- self._log(traceback.format_exc(), "DEBUG")
255
  return ""
256
-
257
- def _load_gguf_model_fallback(self, model_path):
258
- """備用載入方式"""
259
- try:
260
- # 嘗試不同的參數組合
261
- self.llm = Llama(
262
- model_path=model_path,
263
- n_ctx=512, # 更小的上下文
264
- n_threads=4,
265
- n_batch=128,
266
- vocab_only=False,
267
- use_mmap=True,
268
- use_mlock=False,
269
- verbose=True
270
- )
271
- self._log("✅ 備用方式載入成功")
272
- except Exception as e:
273
- self._log(f"❌ 備用方式也失敗: {e}", "ERROR")
274
- self.llm = None
275
-
276
- def _log(self, message: str, level: str = "INFO"):
277
- self.log_history.append(format_log(message, level))
278
- print(format_log(message, level))
279
-
280
  def _load_schema(self) -> Dict:
281
- """載入數據庫結構"""
282
  try:
283
- schema_path = hf_hub_download(
284
- repo_id=DATASET_REPO_ID,
285
- filename="sqlite_schema_FULL.json",
286
- repo_type="dataset"
287
- )
288
  with open(schema_path, "r", encoding="utf-8") as f:
289
  schema_data = json.load(f)
290
-
291
- # 添加調試信息
292
- self._log(f"📊 Schema 載入成功,包含 {len(schema_data)} 個表格:")
293
- for table_name, columns in schema_data.items():
294
- self._log(f" - {table_name}: {len(columns)} 個欄位")
295
- # 顯示前3個欄位作為範例
296
- sample_cols = [col['name'] for col in columns[:3]]
297
- self._log(f" 範例欄位: {', '.join(sample_cols)}")
298
-
299
- self._log("✅ 數據庫結構載入完成")
300
- return schema_data
301
-
302
  except Exception as e:
303
  self._log(f"❌ 載入 schema 失敗: {e}", "ERROR")
304
  return {}
305
 
306
- # 也可以添加一個方法來檢查生成的 SQL 是否使用了正確的表格和欄位
307
- def _analyze_sql_correctness(self, sql: str) -> Dict:
308
- """分析 SQL 的正確性"""
309
- analysis = {
310
- 'valid_tables': [],
311
- 'invalid_tables': [],
312
- 'valid_columns': [],
313
- 'invalid_columns': [],
314
- 'suggestions': []
315
- }
316
-
317
- if not self.schema:
318
- return analysis
319
-
320
- # 提取 SQL 中的表格名稱
321
- table_pattern = r'FROM\s+(\w+)|JOIN\s+(\w+)'
322
- table_matches = re.findall(table_pattern, sql, re.IGNORECASE)
323
- used_tables = [match[0] or match[1] for match in table_matches]
324
-
325
- # 檢查表格是否存在
326
- valid_tables = list(self.schema.keys())
327
- for table in used_tables:
328
- if table in valid_tables:
329
- analysis['valid_tables'].append(table)
330
- else:
331
- analysis['invalid_tables'].append(table)
332
- # 尋找相似的表格名稱
333
- for valid_table in valid_tables:
334
- if table.lower() in valid_table.lower() or valid_table.lower() in table.lower():
335
- analysis['suggestions'].append(f"{table} -> {valid_table}")
336
-
337
- # 提取欄位名稱(簡單版本)
338
- column_pattern = r'SELECT\s+(.*?)\s+FROM|WHERE\s+(\w+)\s*[=<>]|GROUP BY\s+(\w+)|ORDER BY\s+(\w+)'
339
- column_matches = re.findall(column_pattern, sql, re.IGNORECASE)
340
-
341
- return analysis
342
-
343
  def _encode_texts(self, texts):
344
- """編碼文本為嵌入向量"""
345
- if isinstance(texts, str):
346
- texts = [texts]
347
-
348
- inputs = self.embed_tokenizer(texts, padding=True, truncation=True,
349
- return_tensors="pt", max_length=512)
350
- if DEVICE == "cuda":
351
- inputs = {k: v.cuda() for k, v in inputs.items()}
352
-
353
  with torch.no_grad():
354
  outputs = self.embed_model(**inputs)
355
-
356
- # 使用平均池化
357
  embeddings = outputs.last_hidden_state.mean(dim=1)
358
  return embeddings.cpu()
359
 
360
  def _load_and_index_dataset(self):
361
- """載入數據集並建立 FAISS 索引"""
362
  try:
363
  dataset = load_dataset(DATASET_REPO_ID, data_files="training_data.jsonl", split="train")
364
-
365
- # 先過濾不完整樣本,避免 messages 長度不足導致索引或檢索報錯
366
- try:
367
- original_count = len(dataset)
368
- except Exception:
369
- original_count = None
370
-
371
- dataset = dataset.filter(
372
- lambda ex: isinstance(ex.get("messages"), list)
373
- and len(ex["messages"]) >= 2
374
- and all(
375
- isinstance(m.get("content"), str) and m.get("content") and m["content"].strip()
376
- for m in ex["messages"][:2]
377
- )
378
- )
379
-
380
- if original_count is not None:
381
- self._log(
382
- f"資料集清理: 原始 {original_count} 筆, 過濾後 {len(dataset)} 筆, 移除 {original_count - len(dataset)} 筆"
383
- )
384
-
385
- if len(dataset) == 0:
386
- self._log("清理後資料集為空,無法建立索引。", "ERROR")
387
- return None, None
388
-
389
- corpus = [item['messages'][0]['content'] for item in dataset]
390
  self._log(f"正在編碼 {len(corpus)} 個問題...")
391
-
392
- # 批量編碼
393
- embeddings_list = []
394
- batch_size = 32
395
-
396
- for i in range(0, len(corpus), batch_size):
397
- batch_texts = corpus[i:i+batch_size]
398
- batch_embeddings = self._encode_texts(batch_texts)
399
- embeddings_list.append(batch_embeddings)
400
- self._log(f"已編碼 {min(i+batch_size, len(corpus))}/{len(corpus)}")
401
-
402
- all_embeddings = torch.cat(embeddings_list, dim=0).numpy()
403
-
404
- # 建立 FAISS 索引
405
  index = faiss.IndexFlatIP(all_embeddings.shape[1])
406
  index.add(all_embeddings.astype('float32'))
407
-
408
  self._log("✅ 向量索引建立完成")
409
  return dataset, index
410
-
411
  except Exception as e:
412
  self._log(f"❌ 載入數據失敗: {e}", "ERROR")
 
413
  return None, None
414
-
415
  def _identify_relevant_tables(self, question: str) -> List[str]:
416
- """根據實際 Schema 識別相關表格"""
417
  question_lower = question.lower()
418
  relevant_tables = []
419
-
420
- # 根據實際表格的關鍵詞映射
421
- keyword_to_table = {
422
- 'TSR53SampleDescription': ['客戶', '買方', '申請', '發票對象', 'customer', 'invoice', 'sample'],
423
- 'JobsInProgress': ['進行中', '買家', '申請方', 'buyer', 'applicant', 'progress', '工作狀態'],
424
- 'JobTimeline': ['時間', '完成', '創建', '實驗室', 'timeline', 'creation', 'lab'],
425
- 'TSR53Invoice': ['發票', '金額', '費用', 'invoice', 'credit', 'amount'],
426
- 'JobEventsLog': ['事件', '操作', '用戶', 'event', 'log', 'user'],
427
- 'calendar_days': ['工作日', '假期', 'workday', 'holiday', 'calendar']
428
- }
429
-
430
  for table, keywords in keyword_to_table.items():
431
- if any(keyword in question_lower for keyword in keywords):
432
- relevant_tables.append(table)
433
-
434
- # 預設重要表格
435
- if not relevant_tables:
436
- if any(word in question_lower for word in ['客戶', '買家', '申請', '工作單', '數量']):
437
- return ['TSR53SampleDescription', 'JobsInProgress']
438
- else:
439
- return ['JobTimeline', 'TSR53SampleDescription']
440
-
441
- return relevant_tables[:3] # 最多返回3個相關表格
442
-
443
- # 請將這整個函數複製到您的 TextToSQLSystem class 內部
444
 
445
  def _format_relevant_schema(self, table_names: List[str]) -> str:
446
- """
447
- 生成一個簡化的、不易被模型錯誤模仿的 Schema 字符串。
448
- """
449
- if not self.schema:
450
- return "No schema available.\n"
451
-
452
- actual_table_names_map = {name.lower(): name for name in self.schema.keys()}
453
- real_table_names = []
454
- for table in table_names:
455
- actual_name = actual_table_names_map.get(table.lower())
456
- if actual_name:
457
- real_table_names.append(actual_name)
458
- elif table in self.schema:
459
- real_table_names.append(table)
460
-
461
- if not real_table_names:
462
- self._log("未識別到相關表格,使用預設核心表格。", "WARNING")
463
- real_table_names = ['TSR53SampleDescription', 'JobTimeline', 'JobsInProgress']
464
-
465
  formatted = ""
466
- for table in real_table_names:
467
  if table in self.schema:
468
- # 使用簡單的 "Table: ..." 和 "Columns: ..." 格式
469
  formatted += f"Table: {table}\n"
470
  cols_str = []
471
- # 只顯示前 10 個關鍵欄位
472
  for col in self.schema[table][:10]:
473
- col_name = col['name']
474
- col_type = col['type']
475
- col_desc = col.get('description', '').replace('\n', ' ')
476
- # 將描述信息放在括號裡
477
- if col_desc:
478
- cols_str.append(f"{col_name} ({col_type}, {col_desc})")
479
- else:
480
- cols_str.append(f"{col_name} ({col_type})")
481
  formatted += f"Columns: {', '.join(cols_str)}\n\n"
482
-
483
  return formatted.strip()
484
 
485
- # in class TextToSQLSystem:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
486
 
487
  def _validate_and_fix_sql(self, question: str, raw_response: str) -> Tuple[Optional[str], str]:
488
- """
489
- (V23 / 统一实体识别版)
490
- 一個全面、多層次的 SQL 驗證與生成引擎。
491
- 引入了全新的、统一的实体识别引擎,能够准确解析 "买家 Gap", "c0761n",
492
- "买家ID c0761n" 等多种复杂的实体提问模式。
493
- """
494
  q_lower = question.lower()
495
-
496
- # ==============================================================================
497
- # 第一層:高價值意圖識別與模板覆寫 (Intent Recognition & Templating)
498
- # ==============================================================================
499
-
500
- # --- **全新的统一实体识别引擎** ---
501
  entity_match_data = None
502
-
503
- # 定义多种识别模式,【优先级从高到低】
504
  entity_patterns = [
505
- # 模式1: 匹配 "类型 + ID" (e.g., "买家ID C0761N") - 最高优先级
506
  {'pattern': r"(买家|buyer)\s*(?:id|代號|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.BuyerID', 'type': '买家ID'},
507
  {'pattern': r"(申请方|申请厂商|applicant)\s*(?:id|代號|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.ApplicantID', 'type': '申请方ID'},
508
  {'pattern': r"(付款方|付款厂商|invoiceto)\s*(?:id|代號|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.InvoiceToID', 'type': '付款方ID'},
509
  {'pattern': r"(代理商|agent)\s*(?:id|代號|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.AgentID', 'type': '代理商ID'},
510
-
511
- # 模式2: 匹配 "类型 + 名称" (e.g., "买家 Gap")
512
  {'pattern': r"(买家|buyer|客戶)\s*'\"?([a-zA-Z0-9&.\s-]+?)(?:\s*的|\s+|$|有)", 'column': 'sd.BuyerName', 'type': '买家'},
513
  {'pattern': r"(申请方|申请厂商|applicant)\s*'\"?([a-zA-Z0-9&.\s-]+?)(?:\s*的|\s+|$|有)", 'column': 'sd.ApplicantName', 'type': '申请方'},
514
  {'pattern': r"(付款方|付款厂商|invoiceto)\s*'\"?([a-zA-Z0-9&.\s-]+?)(?:\s*的|\s+|$|有)", 'column': 'sd.InvoiceToName', 'type': '付款方'},
515
  {'pattern': r"(代理商|agent)\s*'\"?([a-zA-Z0-9&.\s-]+?)(?:\s*的|\s+|$|有)", 'column': 'sd.AgentName', 'type': '代理商'},
516
-
517
- # 模式3: 单独匹配一个 ID (e.g., "c0761n") - 较低优先级
518
  {'pattern': r"\b([A-Z]\d{4}[A-Z])\b", 'column': 'sd.ApplicantID', 'type': 'ID'}
519
  ]
520
-
521
  for p in entity_patterns:
522
  match = re.search(p['pattern'], question, re.IGNORECASE)
523
  if match:
524
  entity_value = match.group(2) if len(match.groups()) > 1 else match.group(1)
525
- entity_match_data = {
526
- "type": p['type'],
527
- "name": entity_value.strip().upper(),
528
- "column": p['column']
529
- }
530
  break
531
 
532
- # --- 预先检测其他意图 ---
533
- job_no_match = re.search(r"(?:工單|jobno)\s*'\"?([A-Z]{2,3}\d+)'\"?", question, re.IGNORECASE)
534
-
535
- # --- 判断逻辑: 依优先级进入对应的模板 ---
536
  if any(kw in q_lower for kw in ['報告號碼', '報告清單', '列出報告', 'report number', 'list of reports']):
537
  year_match = re.search(r'(\d{4})\s*年?', question)
538
  month_match = re.search(r'(\d{1,2})\s*月', question)
@@ -540,259 +260,71 @@ class TextToSQLSystem:
540
  select_clause = "SELECT jt.JobNo, jt.ReportAuthorization"
541
  where_conditions = ["jt.ReportAuthorization IS NOT NULL"]
542
  log_parts = []
543
-
544
- if year_match: year = year_match.group(1); where_conditions.append(f"strftime('%Y', jt.ReportAuthorization) = '{year}'"); log_parts.append(f"{year}")
545
- if month_match: month = month_match.group(1).zfill(2); where_conditions.append(f"strftime('%m', jt.ReportAuthorization) = '{month}'"); log_parts.append(f"{month}月")
546
-
547
  if 'fail' in q_lower or '失敗' in q_lower:
548
- if "JOIN TSR53SampleDescription" not in from_clause: from_clause = "FROM JobTimeline AS jt JOIN TSR53SampleDescription AS sd ON jt.JobNo = sd.JobNo"
549
  where_conditions.append("sd.OverallRating = 'Fail'"); log_parts.append("Fail")
550
  elif 'pass' in q_lower or '通過' in q_lower:
551
- if "JOIN TSR53SampleDescription" not in from_clause: from_clause = "FROM JobTimeline AS jt JOIN TSR53SampleDescription AS sd ON jt.JobNo = sd.JobNo"
552
  where_conditions.append("sd.OverallRating = 'Pass'"); log_parts.append("Pass")
553
-
554
  if entity_match_data:
555
  entity_name, column_name = entity_match_data["name"], entity_match_data["column"]
556
- if "JOIN TSR53SampleDescription" not in from_clause: from_clause = "FROM JobTimeline AS jt JOIN TSR53SampleDescription AS sd ON jt.JobNo = sd.JobNo"
557
  match_operator = "=" if column_name.endswith("ID") else "LIKE"
558
  entity_value = f"'{entity_name}'" if match_operator == "=" else f"'%{entity_name}%'"
559
  where_conditions.append(f"{column_name} {match_operator} {entity_value}")
560
  log_parts.append(entity_name)
561
  select_clause = "SELECT jt.JobNo, sd.BuyerName, jt.ReportAuthorization"
562
-
563
- final_where_clause = "WHERE " + " AND ".join(where_conditions)
564
  time_log = " ".join(log_parts) if log_parts else "全部"
565
  self._log(f"🔄 檢測到查詢【{time_log} 報告列表】意圖,啟用智能模板。", "INFO")
566
  template_sql = f"{select_clause} {from_clause} {final_where_clause} ORDER BY jt.ReportAuthorization DESC;"
567
  return self._finalize_sql(template_sql, f"模板覆寫: {time_log} 報告列表查詢")
568
 
569
- # ... (此处可以继续添加 V17 版本中的其他所有 if/elif 模板)
570
- elif '報告' in q_lower and any(kw in q_lower for kw in ['幾份', '多少', '數量', '總數']) and not entity_match_data:
571
  year_match = re.search(r'(\d{4})\s*年?', question)
572
  time_condition, time_log = "", "總"
573
  if year_match:
574
- year = year_match.group(1)
575
- time_condition = f"WHERE ReportAuthorization IS NOT NULL AND strftime('%Y', ReportAuthorization) = '{year}'"
576
- time_log = f"{year}年"
577
  else:
578
  time_condition = "WHERE ReportAuthorization IS NOT NULL"
579
  self._log(f"🔄 檢測到查詢【{time_log}全局報告總數】意圖,啟用模板。", "INFO")
580
  template_sql = f"SELECT COUNT(DISTINCT JobNo) AS report_count FROM JobTimeline {time_condition};"
581
  return self._finalize_sql(template_sql, f"模板覆寫: {time_log}全局報告總數查詢")
582
 
583
- # ==============================================================================
584
- # 第二层:常规修正流程 (Fallback Corrections)
585
- # ==============================================================================
586
  self._log("未觸發任何模板,嘗試解析並修正 AI 輸出...", "INFO")
587
-
588
  parsed_sql = parse_sql_from_response(raw_response)
589
  if not parsed_sql:
590
- self._log(f"❌ 未能從模型回應中解析出任何 SQL。原始回應: {raw_response}", "ERROR")
591
  return None, f"無法解析SQL。原始回應:\n{raw_response}"
592
-
593
- self._log(f"📊 解析出的原始 SQL: {parsed_sql}", "DEBUG")
594
-
595
  fixed_sql = " " + parsed_sql.strip() + " "
596
  fixes_applied_fallback = []
597
-
598
  dialect_corrections = {r'YEAR\s*\(([^)]+)\)': r"strftime('%Y', \1)"}
599
- for pattern, replacement in dialect_corrections.items():
600
- if re.search(pattern, fixed_sql, re.IGNORECASE):
601
- fixed_sql = re.sub(pattern, replacement, fixed_sql, flags=re.IGNORECASE)
602
- fixes_applied_fallback.append(f"修正方言: {pattern}")
603
-
604
- schema_corrections = {'TSR53Report':'TSR53SampleDescription', 'TSR53InvoiceReportNo':'JobNo', 'TSR53ReportNo':'JobNo', 'TSR53InvoiceNo':'JobNo', 'TSR53InvoiceCreditNoteNo':'InvoiceCreditNoteNo', 'TSR53InvoiceLocalAmount':'LocalAmount', 'Status':'OverallRating', 'ReportStatus':'OverallRating'}
605
- for wrong, correct in schema_corrections.items():
606
- pattern = r'\b' + re.escape(wrong) + r'\b'
607
  if re.search(pattern, fixed_sql, re.IGNORECASE):
608
- fixed_sql = re.sub(pattern, correct, fixed_sql, flags=re.IGNORECASE)
609
- fixes_applied_fallback.append(f"映射 Schema: '{wrong}' -> '{correct}'")
610
-
611
  log_msg = "AI 生成並成功修正" if fixes_applied_fallback else "AI 生成且無需修正"
612
  return self._finalize_sql(fixed_sql, log_msg)
613
 
614
- def _finalize_sql(self, sql: str, log_message: str) -> Tuple[str, str]:
615
- """一個輔助函數,用於清理最終的SQL並記錄成功日誌。"""
616
- final_sql = sql.strip()
617
- if not final_sql.endswith(';'):
618
- final_sql += ';'
619
- final_sql = re.sub(r'\s+', ' ', final_sql).strip()
620
- self._log(f"✅ SQL 已生成 ({log_message})", "INFO")
621
- self._log(f" - 最終 SQL: {final_sql}", "DEBUG")
622
- return final_sql, "生成成功"
623
-
624
- def find_most_similar(self, question: str, top_k: int) -> List[Dict]:
625
- """使用 FAISS 快速檢索相似問題"""
626
- if self.faiss_index is None or self.dataset is None:
627
- return []
628
-
629
- try:
630
- # 編碼問題
631
- q_embedding = self._encode_texts([question]).numpy().astype('float32')
632
-
633
- # FAISS 搜索
634
- distances, indices = self.faiss_index.search(q_embedding, min(top_k + 2, len(self.dataset)))
635
-
636
- results = []
637
- seen_questions = set()
638
-
639
- for i, idx in enumerate(indices[0]):
640
- if len(results) >= top_k:
641
- break
642
-
643
- # 修復:將 numpy.int64 轉換為 Python int
644
- idx = int(idx) # ← 添加這行轉換
645
-
646
- if idx >= len(self.dataset): # 確保索引有效
647
- continue
648
-
649
- item = self.dataset[idx]
650
- # 防呆:若樣本不完整則跳過
651
- if not isinstance(item.get('messages'), list) or len(item['messages']) < 2:
652
- continue
653
- q_content = (item['messages'][0].get('content') or '').strip()
654
- a_content = (item['messages'][1].get('content') or '').strip()
655
- if not q_content or not a_content:
656
- continue
657
-
658
- # 提取純淨問題
659
- clean_q = re.sub(r"以下是一個SQL查詢任務:\s*指令:\s*", "", q_content).strip()
660
- if clean_q in seen_questions:
661
- continue
662
-
663
- seen_questions.add(clean_q)
664
- sql = parse_sql_from_response(a_content) or "無法解析範例SQL"
665
-
666
- results.append({
667
- "similarity": float(distances[0][i]),
668
- "question": clean_q,
669
- "sql": sql
670
- })
671
-
672
- return results
673
-
674
- except Exception as e:
675
- self._log(f"❌ 檢索失敗: {e}", "ERROR")
676
- return []
677
-
678
- # in class TextToSQLSystem:
679
-
680
- def _build_prompt(self, user_q: str, examples: List[Dict]) -> str:
681
- """
682
- 建立一個高度結構化、以任務為導向的提示詞,使用清晰的標題分隔符。
683
- """
684
- relevant_tables = self._identify_relevant_tables(user_q)
685
-
686
- # 使用我們新的、更簡單的 schema 格式化函數
687
- schema_str = self._format_relevant_schema(relevant_tables)
688
-
689
- example_str = "No example available."
690
- if examples:
691
- best_example = examples[0]
692
- example_str = f"Question: {best_example['question']}\nSQL:\n```sql\n{best_example['sql']}\n```"
693
-
694
- # 使用強分隔符和清晰的標題來構建 prompt
695
- prompt = f"""### INSTRUCTIONS ###
696
- You are a SQLite expert. Your only job is to generate a single, valid SQLite query based on the provided schema and question.
697
- - ONLY use the tables and columns from the schema below.
698
- - ALWAYS use SQLite syntax (e.g., `strftime('%Y', date_column)` for years).
699
- - The report completion date is the `ReportAuthorization` column in the `JobTimeline` table.
700
- - Your output MUST be ONLY the SQL query inside a ```sql code block.
701
-
702
- ### SCHEMA ###
703
- {schema_str}
704
-
705
- ### EXAMPLE ###
706
- {example_str}
707
-
708
- ### TASK ###
709
- Generate a SQLite query for the following question.
710
- Question: {user_q}
711
- SQL:
712
- ```sql
713
- """
714
- self._log(f"📏 Prompt 長度: {len(prompt)} 字符")
715
- # 不再需要複雜的長度截斷邏輯,因為 schema 已經被簡化
716
- return prompt
717
-
718
-
719
- def _generate_fallback_sql(self, prompt: str) -> str:
720
- """當模型不可用時的備用 SQL 生成"""
721
- prompt_lower = prompt.lower()
722
-
723
- # 簡單的關鍵詞匹配生成基本 SQL
724
- if "統計" in prompt or "數量" in prompt or "多少" in prompt:
725
- if "月" in prompt:
726
- return "SELECT strftime('%Y-%m', completed_time) as month, COUNT(*) as count FROM jobtimeline GROUP BY month ORDER BY month;"
727
- elif "客戶" in prompt:
728
- return "SELECT applicant, COUNT(*) as count FROM tsr53sampledescription GROUP BY applicant ORDER BY count DESC;"
729
- else:
730
- return "SELECT COUNT(*) as total_count FROM jobtimeline WHERE completed_time IS NOT NULL;"
731
-
732
- elif "金額" in prompt or "總額" in prompt:
733
- return "SELECT SUM(amount) as total_amount FROM tsr53invoice;"
734
-
735
- elif "評級" in prompt or "pass" in prompt_lower or "fail" in prompt_lower:
736
- return "SELECT rating, COUNT(*) as count FROM tsr53sampledescription GROUP BY rating;"
737
-
738
- else:
739
- return "SELECT * FROM jobtimeline LIMIT 10;"
740
-
741
- def _validate_model_file(self, model_path):
742
- """驗證模型檔案完整性"""
743
- try:
744
- if not os.path.exists(model_path):
745
- return False
746
-
747
- # 檢查檔案大小(至少應該有幾MB)
748
- file_size = os.path.getsize(model_path)
749
- if file_size < 10 * 1024 * 1024: # 小於 10MB 可能有問題
750
- return False
751
-
752
- # 檢查 GGUF 檔案頭部
753
- with open(model_path, 'rb') as f:
754
- header = f.read(8)
755
- if not header.startswith(b'GGUF'):
756
- return False
757
-
758
- return True
759
- except Exception:
760
- return False
761
-
762
- # in class TextToSQLSystem:
763
-
764
  def process_question(self, question: str) -> Tuple[str, str]:
765
- """處理使用者問題 (V2 / 最終版)"""
766
- # 檢查緩存
767
- if question in self.query_cache:
768
- self._log("⚡ 使用緩存結果")
769
- return self.query_cache[question]
770
-
771
  self.log_history = []
772
  self._log(f"⏰ 處理問題: {question}")
773
-
774
- # 1. 檢索相似範例
775
- self._log("🔍 尋找相似範例...")
776
  examples = self.find_most_similar(question, FEW_SHOT_EXAMPLES_COUNT)
777
  if examples: self._log(f"✅ 找到 {len(examples)} 個相似範例")
778
-
779
- # 2. 建立提示詞
780
- self._log("📝 建立 Prompt...")
781
  prompt = self._build_prompt(question, examples)
782
-
783
- # 3. 生成 AI 回應
784
  self._log("🧠 開始生成 AI 回應...")
785
  response = self.huggingface_api_call(prompt)
786
-
787
- # 4. **新的核心步驟**: 呼叫決策引擎來生成最終 SQL
788
  final_sql, status_message = self._validate_and_fix_sql(question, response)
789
-
790
- if final_sql:
791
- result = (final_sql, status_message)
792
- else:
793
- result = (status_message, "生成失敗")
794
-
795
- # 緩存結果
796
  self.query_cache[question] = result
797
  return result
798
 
@@ -800,53 +332,36 @@ SQL:
800
  text_to_sql_system = TextToSQLSystem()
801
 
802
  def process_query(q: str):
803
- if not q.strip():
804
- return "", "等待輸入", "請輸入問題"
805
-
806
  sql, status = text_to_sql_system.process_question(q)
807
- logs = "\n".join(text_to_sql_system.log_history[-10:]) # 只顯示最後10條日誌
808
-
809
  return sql, status, logs
810
 
811
- # 範例問題
812
  examples = [
813
- "2024年每月完成多少份報告?",
814
- "統計各種評級(Pass/Fail)的分布情況",
815
- "找出總金額最高的10個工作單",
816
- "哪些客戶的工作單數量最多?",
817
- "A組昨天完成了多少個測試項目?"
 
818
  ]
819
-
820
  with gr.Blocks(theme=gr.themes.Soft(), title="Text-to-SQL 智能助手") as demo:
821
- gr.Markdown("# ⚡ Text-to-SQL 智能助手")
822
- gr.Markdown("輸入自然語言問題,自動生成SQL查詢語句")
823
-
824
  with gr.Row():
825
  with gr.Column(scale=2):
826
  inp = gr.Textbox(lines=3, label="💬 您的問題", placeholder="例如:2024年每月完成多少份報告?")
827
  btn = gr.Button("🚀 生成 SQL", variant="primary")
828
  status = gr.Textbox(label="狀態", interactive=False)
829
-
830
  with gr.Column(scale=3):
831
  sql_out = gr.Code(label="🤖 生成的 SQL", language="sql", lines=8)
832
-
833
  with gr.Accordion("📋 處理日誌", open=False):
834
- logs = gr.Textbox(lines=8, label="日誌", interactive=False)
835
-
836
- # 範例區
837
- gr.Examples(
838
- examples=examples,
839
- inputs=inp,
840
- label="💡 點擊試用範例問題"
841
- )
842
-
843
- # 綁定事件
844
  btn.click(process_query, inputs=[inp], outputs=[sql_out, status, logs])
845
  inp.submit(process_query, inputs=[inp], outputs=[sql_out, status, logs])
846
 
847
  if __name__ == "__main__":
848
- demo.launch(
849
- server_name="0.0.0.0",
850
- server_port=7860,
851
- share=False
852
- )
 
1
+ # ==============================================================================
2
+ # Text-to-SQL 智能助手 - Hugging Face CPU 最终版 v6
3
+ # (融合模板引擎 + 强化 Prompt + 修复所有 Bug)
4
+ # ==============================================================================
5
  import gradio as gr
6
  import os
7
  import re
8
  import json
9
  import torch
10
  import numpy as np
11
+ import gc
12
+ import tempfile
13
  from datetime import datetime
14
  from datasets import load_dataset
15
  from huggingface_hub import hf_hub_download
16
  from llama_cpp import Llama
17
  from typing import List, Dict, Tuple, Optional
18
  import faiss
19
+ import traceback
20
 
 
21
  from transformers import AutoModel, AutoTokenizer
22
  import torch.nn.functional as F
23
 
24
+ # ==================== 配置參數 ====================
25
+ # --- Hugging Face CPU 部署配置 ---
 
 
26
  GGUF_FILENAME = "qwen2.5-coder-1.5b-sql-finetuned.q8_0.gguf"
27
+ N_GPU_LAYERS = 0 # 在 Hugging Face CPU 环境下设置为 0
28
 
29
+ DATASET_REPO_ID = "Paul720810/Text-to-SQL-Softline"
30
+ GGUF_REPO_ID = "Paul720810/gguf-models"
 
31
  FEW_SHOT_EXAMPLES_COUNT = 1
32
+ DEVICE = "cuda" if torch.cuda.is_available() and N_GPU_LAYERS != 0 else "cpu"
33
  EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
34
 
35
+ TEMP_DIR = tempfile.gettempdir()
36
+ os.makedirs(os.path.join(TEMP_DIR, 'text_to_sql_cache'), exist_ok=True)
37
+
38
  print("=" * 60)
39
+ print("🤖 Text-to-SQL 智能助手 v6.0 (Hugging Face CPU 版)...")
40
+ print(f"🚀 模型: {GGUF_FILENAME}")
41
+ print(f"💻 設備: {DEVICE} (GPU Layers: {N_GPU_LAYERS})")
 
42
  print("=" * 60)
43
 
44
  # ==================== 工具函數 ====================
 
46
  return datetime.now().strftime('%Y-%m-%d %H:%M:%S')
47
 
48
  def format_log(message: str, level: str = "INFO") -> str:
49
+ log_entry = f"[{get_current_time()}] [{level.upper()}] {message}"
50
+ print(log_entry)
51
+ return log_entry
52
 
53
  def parse_sql_from_response(response_text: str) -> Optional[str]:
54
+ if not response_text: return None
 
 
 
 
55
  response_text = response_text.strip()
 
 
56
  match = re.search(r"```sql\s*\n(.*?)\n```", response_text, re.DOTALL | re.IGNORECASE)
57
+ if match: return match.group(1).strip()
 
 
 
58
  match = re.search(r"```\s*\n?(.*?)\n?```", response_text, re.DOTALL)
59
  if match:
60
  sql_candidate = match.group(1).strip()
61
+ if sql_candidate.upper().startswith('SELECT'): return sql_candidate
 
 
 
62
  match = re.search(r"(SELECT\s+.*?;)", response_text, re.DOTALL | re.IGNORECASE)
63
+ if match: return match.group(1).strip()
 
 
 
64
  match = re.search(r"(SELECT\s+.*?)(?=\n\n|\n```|$|\n[^,\s])", response_text, re.DOTALL | re.IGNORECASE)
65
  if match:
66
  sql = match.group(1).strip()
67
+ if not sql.endswith(';'): sql += ';'
 
68
  return sql
 
 
69
  if 'SELECT' in response_text.upper():
70
+ for line in response_text.split('\n'):
 
71
  line = line.strip()
72
  if line.upper().startswith('SELECT'):
73
+ if not line.endswith(';'): line += ';'
 
74
  return line
 
75
  return None
76
 
77
  # ==================== Text-to-SQL 核心類 ====================
 
80
  self.log_history = []
81
  self._log("初始化系統...")
82
  self.query_cache = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  try:
84
+ self._log(f"載入嵌入模型: {embed_model_name}")
85
+ self.embed_tokenizer = AutoTokenizer.from_pretrained(embed_model_name)
86
+ self.embed_model = AutoModel.from_pretrained(embed_model_name)
87
+ if DEVICE == "cuda":
88
+ self.embed_model.to(DEVICE)
89
+
90
+ self.schema = self._load_schema()
91
+ self.dataset, self.faiss_index = self._load_and_index_dataset()
92
+ self._load_gguf_model()
93
+ self._log("✅ 系統初始化完成")
 
 
 
 
 
 
 
 
 
 
 
94
  except Exception as e:
95
+ self._log(f"❌ 系統初始化過程中發生嚴重錯誤: {e}", "CRITICAL")
96
+ self._log(traceback.format_exc(), "DEBUG")
97
  self.llm = None
98
 
99
+ def _log(self, message: str, level: str = "INFO"):
100
+ self.log_history.append(format_log(message, level))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
+ def _load_gguf_model(self):
 
103
  try:
104
+ model_path = hf_hub_download(repo_id=GGUF_REPO_ID, filename=GGUF_FILENAME, repo_type="dataset", cache_dir=TEMP_DIR)
105
+ self._log(f"模型路徑: {model_path}")
106
+ self._log(f"載入 GGUF 模型 (GPU Layers: {N_GPU_LAYERS})...")
107
+ self.llm = Llama(model_path=model_path, n_ctx=2048, n_threads=4, n_batch=512, verbose=False, n_gpu_layers=N_GPU_LAYERS)
108
+ self._log("✅ GGUF 模型成功載入")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  except Exception as e:
110
+ self._log(f"❌ GGUF 載入失敗: {e}", "CRITICAL")
111
  self.llm = None
112
 
113
  def huggingface_api_call(self, prompt: str) -> str:
114
+ if self.llm is None: return ""
 
 
 
 
115
  try:
116
+ output = self.llm(prompt, max_tokens=150, temperature=0.1, top_p=0.9, echo=False, stop=["```", ";", "\n\n", "</s>", "###", "Q:"], repeat_penalty=1.1)
117
+ generated_text = output["choices"][0]["text"] if output and "choices" in output and len(output["choices"]) > 0 else ""
118
+ self._log(f"🧠 模型原始輸出: {generated_text.strip()}", "DEBUG")
119
+ return generated_text.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  except Exception as e:
121
+ self._log(f"❌ 模型生成錯誤: {e}", "CRITICAL")
 
 
122
  return ""
123
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  def _load_schema(self) -> Dict:
 
125
  try:
126
+ schema_path = hf_hub_download(repo_id=DATASET_REPO_ID, filename="sqlite_schema_FULL.json", repo_type="dataset")
 
 
 
 
127
  with open(schema_path, "r", encoding="utf-8") as f:
128
  schema_data = json.load(f)
129
+ self._log(f"📊 Schema 載入成功,包含 {len(schema_data)} 個表格。")
130
+ return schema_data
 
 
 
 
 
 
 
 
 
 
131
  except Exception as e:
132
  self._log(f"❌ 載入 schema 失敗: {e}", "ERROR")
133
  return {}
134
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  def _encode_texts(self, texts):
136
+ if isinstance(texts, str): texts = [texts]
137
+ inputs = self.embed_tokenizer(texts, padding=True, truncation=True, return_tensors="pt", max_length=512).to(DEVICE)
 
 
 
 
 
 
 
138
  with torch.no_grad():
139
  outputs = self.embed_model(**inputs)
 
 
140
  embeddings = outputs.last_hidden_state.mean(dim=1)
141
  return embeddings.cpu()
142
 
143
  def _load_and_index_dataset(self):
 
144
  try:
145
  dataset = load_dataset(DATASET_REPO_ID, data_files="training_data.jsonl", split="train")
146
+ dataset = dataset.filter(lambda ex: isinstance(ex.get("messages"), list) and len(ex["messages"]) >= 2)
147
+ corpus = [item['messages']['content'] for item in dataset]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  self._log(f"正在編碼 {len(corpus)} 個問題...")
149
+ all_embeddings = torch.cat([self._encode_texts(corpus[i:i+32]) for i in range(0, len(corpus), 32)], dim=0).numpy()
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  index = faiss.IndexFlatIP(all_embeddings.shape[1])
151
  index.add(all_embeddings.astype('float32'))
 
152
  self._log("✅ 向量索引建立完成")
153
  return dataset, index
 
154
  except Exception as e:
155
  self._log(f"❌ 載入數據失敗: {e}", "ERROR")
156
+ self._log(traceback.format_exc(), "DEBUG")
157
  return None, None
158
+
159
  def _identify_relevant_tables(self, question: str) -> List[str]:
 
160
  question_lower = question.lower()
161
  relevant_tables = []
162
+ keyword_to_table = {'TSR53SampleDescription': ['客戶', '買方', '申請', '發票對象'], 'JobsInProgress': ['進行中', '買家', '申請方'], 'JobTimeline': ['時間', '完成', '創建', '實驗室'], 'TSR53Invoice': ['發票', '金額', '費用']}
 
 
 
 
 
 
 
 
 
 
163
  for table, keywords in keyword_to_table.items():
164
+ if any(keyword in question_lower for keyword in keywords): relevant_tables.append(table)
165
+ if not relevant_tables: return ['TSR53SampleDescription', 'JobsInProgress', 'JobTimeline']
166
+ return relevant_tables[:3]
 
 
 
 
 
 
 
 
 
 
167
 
168
  def _format_relevant_schema(self, table_names: List[str]) -> str:
169
+ if not self.schema: return "No schema available.\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  formatted = ""
171
+ for table in table_names:
172
  if table in self.schema:
 
173
  formatted += f"Table: {table}\n"
174
  cols_str = []
 
175
  for col in self.schema[table][:10]:
176
+ col_name, col_type, col_desc = col['name'], col['type'], col.get('description', '').replace('\n', ' ')
177
+ if col_desc: cols_str.append(f"{col_name} ({col_type}, {col_desc})")
178
+ else: cols_str.append(f"{col_name} ({col_type})")
 
 
 
 
 
179
  formatted += f"Columns: {', '.join(cols_str)}\n\n"
 
180
  return formatted.strip()
181
 
182
+ def find_most_similar(self, question: str, top_k: int) -> List[Dict]:
183
+ if self.faiss_index is None: return []
184
+ try:
185
+ q_embedding = self._encode_texts([question]).numpy().astype('float32')
186
+ distances, indices = self.faiss_index.search(q_embedding, min(top_k + 2, len(self.dataset)))
187
+ results, seen_questions = [], set()
188
+ for i, idx in enumerate(indices[0]):
189
+ if len(results) >= top_k: break
190
+ idx = int(idx)
191
+ if idx >= len(self.dataset): continue
192
+ item = self.dataset[idx]
193
+ if not (isinstance(item.get('messages'), list) and len(item['messages']) >= 2): continue
194
+ q_content = (item['messages']['content'] or '').strip()
195
+ a_content = (item['messages'].get('content') or '').strip()
196
+ if not q_content or not a_content: continue
197
+ clean_q = re.sub(r"以下是一個SQL查詢任務:\s*指令:\s*", "", q_content).strip()
198
+ if clean_q in seen_questions: continue
199
+ seen_questions.add(clean_q)
200
+ sql = parse_sql_from_response(a_content) or "無法解析範例SQL"
201
+ results.append({"similarity": float(distances[0][i]), "question": clean_q, "sql": sql})
202
+ return results
203
+ except Exception as e:
204
+ self._log(f"❌ 檢索失敗: {e}", "ERROR")
205
+ return []
206
+
207
+ def _build_prompt(self, user_q: str, examples: List[Dict]) -> str:
208
+ schema_str = self._format_relevant_schema(self._identify_relevant_tables(user_q))
209
+ example_str = ""
210
+ if examples:
211
+ example_prompts = [f"Q: {ex['question']}\nA: ```sql\n{ex['sql']}\n```" for ex in examples]
212
+ example_str = "\n---\n".join(example_prompts)
213
+ prompt = f"""You are an expert SQLite programmer. Your task is to generate a SQL query based on the database schema and a user's question.
214
+
215
+ ## Database Schema
216
+ {schema_str.strip()}
217
+
218
+ ## Examples
219
+ {example_str.strip()}
220
+
221
+ ## Task
222
+ Based on the schema and examples, generate the SQL query for the following question.
223
+ Q: {user_q}
224
+ A: ```sql
225
+ """
226
+ return prompt
227
+
228
+ def _finalize_sql(self, sql: str, log_message: str) -> Tuple[str, str]:
229
+ final_sql = re.sub(r'\s+', ' ', sql.strip())
230
+ if not final_sql.endswith(';'): final_sql += ';'
231
+ self._log(f"✅ SQL 已生成 ({log_message})", "INFO")
232
+ self._log(f" - 最終 SQL: {final_sql}", "DEBUG")
233
+ return final_sql, "生成成功"
234
 
235
  def _validate_and_fix_sql(self, question: str, raw_response: str) -> Tuple[Optional[str], str]:
 
 
 
 
 
 
236
  q_lower = question.lower()
 
 
 
 
 
 
237
  entity_match_data = None
 
 
238
  entity_patterns = [
 
239
  {'pattern': r"(买家|buyer)\s*(?:id|代號|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.BuyerID', 'type': '买家ID'},
240
  {'pattern': r"(申请方|申请厂商|applicant)\s*(?:id|代號|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.ApplicantID', 'type': '申请方ID'},
241
  {'pattern': r"(付款方|付款厂商|invoiceto)\s*(?:id|代號|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.InvoiceToID', 'type': '付款方ID'},
242
  {'pattern': r"(代理商|agent)\s*(?:id|代號|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.AgentID', 'type': '代理商ID'},
 
 
243
  {'pattern': r"(买家|buyer|客戶)\s*'\"?([a-zA-Z0-9&.\s-]+?)(?:\s*的|\s+|$|有)", 'column': 'sd.BuyerName', 'type': '买家'},
244
  {'pattern': r"(申请方|申请厂商|applicant)\s*'\"?([a-zA-Z0-9&.\s-]+?)(?:\s*的|\s+|$|有)", 'column': 'sd.ApplicantName', 'type': '申请方'},
245
  {'pattern': r"(付款方|付款厂商|invoiceto)\s*'\"?([a-zA-Z0-9&.\s-]+?)(?:\s*的|\s+|$|有)", 'column': 'sd.InvoiceToName', 'type': '付款方'},
246
  {'pattern': r"(代理商|agent)\s*'\"?([a-zA-Z0-9&.\s-]+?)(?:\s*的|\s+|$|有)", 'column': 'sd.AgentName', 'type': '代理商'},
 
 
247
  {'pattern': r"\b([A-Z]\d{4}[A-Z])\b", 'column': 'sd.ApplicantID', 'type': 'ID'}
248
  ]
 
249
  for p in entity_patterns:
250
  match = re.search(p['pattern'], question, re.IGNORECASE)
251
  if match:
252
  entity_value = match.group(2) if len(match.groups()) > 1 else match.group(1)
253
+ entity_match_data = {"type": p['type'], "name": entity_value.strip().upper(), "column": p['column']}
 
 
 
 
254
  break
255
 
 
 
 
 
256
  if any(kw in q_lower for kw in ['報告號碼', '報告清單', '列出報告', 'report number', 'list of reports']):
257
  year_match = re.search(r'(\d{4})\s*年?', question)
258
  month_match = re.search(r'(\d{1,2})\s*月', question)
 
260
  select_clause = "SELECT jt.JobNo, jt.ReportAuthorization"
261
  where_conditions = ["jt.ReportAuthorization IS NOT NULL"]
262
  log_parts = []
263
+ if year_match: where_conditions.append(f"strftime('%Y', jt.ReportAuthorization) = '{year_match.group(1)}'"); log_parts.append(f"{year_match.group(1)}年")
264
+ if month_match: where_conditions.append(f"strftime('%m', jt.ReportAuthorization) = '{month_match.group(1).zfill(2)}'"); log_parts.append(f"{month_match.group(1)}")
 
 
265
  if 'fail' in q_lower or '失敗' in q_lower:
266
+ if "JOIN TSR53SampleDescription" not in from_clause: from_clause += " JOIN TSR53SampleDescription AS sd ON jt.JobNo = sd.JobNo"
267
  where_conditions.append("sd.OverallRating = 'Fail'"); log_parts.append("Fail")
268
  elif 'pass' in q_lower or '通過' in q_lower:
269
+ if "JOIN TSR53SampleDescription" not in from_clause: from_clause += " JOIN TSR53SampleDescription AS sd ON jt.JobNo = sd.JobNo"
270
  where_conditions.append("sd.OverallRating = 'Pass'"); log_parts.append("Pass")
 
271
  if entity_match_data:
272
  entity_name, column_name = entity_match_data["name"], entity_match_data["column"]
273
+ if "JOIN TSR53SampleDescription" not in from_clause: from_clause += " JOIN TSR53SampleDescription AS sd ON jt.JobNo = sd.JobNo"
274
  match_operator = "=" if column_name.endswith("ID") else "LIKE"
275
  entity_value = f"'{entity_name}'" if match_operator == "=" else f"'%{entity_name}%'"
276
  where_conditions.append(f"{column_name} {match_operator} {entity_value}")
277
  log_parts.append(entity_name)
278
  select_clause = "SELECT jt.JobNo, sd.BuyerName, jt.ReportAuthorization"
279
+ final_where_clause = "WHERE " + " AND ".join(where_conditions) if where_conditions else ""
 
280
  time_log = " ".join(log_parts) if log_parts else "全部"
281
  self._log(f"🔄 檢測到查詢【{time_log} 報告列表】意圖,啟用智能模板。", "INFO")
282
  template_sql = f"{select_clause} {from_clause} {final_where_clause} ORDER BY jt.ReportAuthorization DESC;"
283
  return self._finalize_sql(template_sql, f"模板覆寫: {time_log} 報告列表查詢")
284
 
285
+ if '報告' in q_lower and any(kw in q_lower for kw in ['幾份', '多少', '數量', '總數']) and not entity_match_data:
 
286
  year_match = re.search(r'(\d{4})\s*年?', question)
287
  time_condition, time_log = "", "總"
288
  if year_match:
289
+ time_condition = f"WHERE ReportAuthorization IS NOT NULL AND strftime('%Y', ReportAuthorization) = '{year_match.group(1)}'"
290
+ time_log = f"{year_match.group(1)}"
 
291
  else:
292
  time_condition = "WHERE ReportAuthorization IS NOT NULL"
293
  self._log(f"🔄 檢測到查詢【{time_log}全局報告總數】意圖,啟用模板。", "INFO")
294
  template_sql = f"SELECT COUNT(DISTINCT JobNo) AS report_count FROM JobTimeline {time_condition};"
295
  return self._finalize_sql(template_sql, f"模板覆寫: {time_log}全局報告總數查詢")
296
 
 
 
 
297
  self._log("未觸發任何模板,嘗試解析並修正 AI 輸出...", "INFO")
 
298
  parsed_sql = parse_sql_from_response(raw_response)
299
  if not parsed_sql:
 
300
  return None, f"無法解析SQL。原始回應:\n{raw_response}"
 
 
 
301
  fixed_sql = " " + parsed_sql.strip() + " "
302
  fixes_applied_fallback = []
 
303
  dialect_corrections = {r'YEAR\s*\(([^)]+)\)': r"strftime('%Y', \1)"}
304
+ for p, r in dialect_corrections.items():
305
+ if re.search(p, fixed_sql, re.IGNORECASE):
306
+ fixed_sql = re.sub(p, r, fixed_sql, flags=re.IGNORECASE); fixes_applied_fallback.append(f"修正方言: {p}")
307
+ schema_corrections = {'TSR53Report':'TSR53SampleDescription', 'TSR53InvoiceReportNo':'JobNo', 'Status':'OverallRating'}
308
+ for w, c in schema_corrections.items():
309
+ pattern = r'\b' + re.escape(w) + r'\b'
 
 
310
  if re.search(pattern, fixed_sql, re.IGNORECASE):
311
+ fixed_sql = re.sub(pattern, c, fixed_sql, flags=re.IGNORECASE); fixes_applied_fallback.append(f"映射 Schema: '{w}' -> '{c}'")
 
 
312
  log_msg = "AI 生成並成功修正" if fixes_applied_fallback else "AI 生成且無需修正"
313
  return self._finalize_sql(fixed_sql, log_msg)
314
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
315
  def process_question(self, question: str) -> Tuple[str, str]:
316
+ if question in self.query_cache: self._log("⚡ 使用緩存結果"); return self.query_cache[question]
 
 
 
 
 
317
  self.log_history = []
318
  self._log(f"⏰ 處理問題: {question}")
 
 
 
319
  examples = self.find_most_similar(question, FEW_SHOT_EXAMPLES_COUNT)
320
  if examples: self._log(f"✅ 找到 {len(examples)} 個相似範例")
 
 
 
321
  prompt = self._build_prompt(question, examples)
322
+ self._log(f"📏 Prompt 長度: {len(prompt)} 字符")
 
323
  self._log("🧠 開始生成 AI 回應...")
324
  response = self.huggingface_api_call(prompt)
 
 
325
  final_sql, status_message = self._validate_and_fix_sql(question, response)
326
+ if not final_sql: result = (status_message, "生成失敗")
327
+ else: result = (final_sql, status_message)
 
 
 
 
 
328
  self.query_cache[question] = result
329
  return result
330
 
 
332
  text_to_sql_system = TextToSQLSystem()
333
 
334
  def process_query(q: str):
335
+ if not q.strip(): return "", "等待輸入", "請輸入問題"
336
+ if text_to_sql_system.llm is None:
337
+ return "模型未能成功載入,請檢查終端日誌。", "模型載入失敗", "\n".join(text_to_sql_system.log_history)
338
  sql, status = text_to_sql_system.process_question(q)
339
+ logs = "\n".join(text_to_sql_system.log_history[-15:])
 
340
  return sql, status, logs
341
 
 
342
  examples = [
343
+ "2024年7月買家 Gap 的 Fail 報告號碼",
344
+ "列出2023年所有失败的报告",
345
+ "找出总金额最高的10个工作单",
346
+ "哪些客户的工作单数量最多?",
347
+ "A組2024年完成了多少個測試項目?",
348
+ "2024年每月完成多少份報告?"
349
  ]
 
350
  with gr.Blocks(theme=gr.themes.Soft(), title="Text-to-SQL 智能助手") as demo:
351
+ gr.Markdown("# ⚡ Text-to-SQL 智能助手 (终极版)")
352
+ gr.Markdown("融合了模板引擎和 GGUF 模型的强大版本")
 
353
  with gr.Row():
354
  with gr.Column(scale=2):
355
  inp = gr.Textbox(lines=3, label="💬 您的問題", placeholder="例如:2024年每月完成多少份報告?")
356
  btn = gr.Button("🚀 生成 SQL", variant="primary")
357
  status = gr.Textbox(label="狀態", interactive=False)
 
358
  with gr.Column(scale=3):
359
  sql_out = gr.Code(label="🤖 生成的 SQL", language="sql", lines=8)
 
360
  with gr.Accordion("📋 處理日誌", open=False):
361
+ logs = gr.Textbox(lines=10, label="日誌", interactive=False)
362
+ gr.Examples(examples=examples, inputs=inp, label="💡 點擊試用範例問題")
 
 
 
 
 
 
 
 
363
  btn.click(process_query, inputs=[inp], outputs=[sql_out, status, logs])
364
  inp.submit(process_query, inputs=[inp], outputs=[sql_out, status, logs])
365
 
366
  if __name__ == "__main__":
367
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=True)