Paul720810 commited on
Commit
3d52e16
·
verified ·
1 Parent(s): 097112d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +545 -660
app.py CHANGED
@@ -4,6 +4,9 @@ import re
4
  import json
5
  import torch
6
  import numpy as np
 
 
 
7
  from datetime import datetime
8
  from datasets import load_dataset
9
  from huggingface_hub import hf_hub_download
@@ -11,30 +14,51 @@ from llama_cpp import Llama
11
  from typing import List, Dict, Tuple, Optional
12
  import faiss
13
  from functools import lru_cache
14
- import re
15
 
16
  # 使用 transformers 替代 sentence-transformers
17
  from transformers import AutoModel, AutoTokenizer
18
  import torch.nn.functional as F
19
 
20
- # ==================== 配置區 ====================
21
  DATASET_REPO_ID = "Paul720810/Text-to-SQL-Softline"
22
  GGUF_REPO_ID = "Paul720810/gguf-models"
23
- GGUF_FILENAME = "qwen2.5-coder-1.5b-sql-finetuned.q4_k_m.gguf"
24
- #GGUF_FILENAME = "qwen2.5-coder-1.5b-sql-finetuned.q8_0.gguf"
25
-
26
- # 添加這一行:你的原始微調模型路徑
27
- FINETUNED_MODEL_PATH = "Paul720810/qwen2.5-coder-1.5b-sql-finetuned" # ← 新增這行
28
 
29
- FEW_SHOT_EXAMPLES_COUNT = 2
30
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
31
  EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  print("=" * 60)
34
- print("🤖 Text-to-SQL 系統啟動中...")
35
- print(f"📊 數據集: {DATASET_REPO_ID}")
36
- print(f"🤖 嵌入模型: {EMBED_MODEL_NAME}")
37
- print(f"💻 設備: {DEVICE}")
 
 
38
  print("=" * 60)
39
 
40
  # ==================== 工具函數 ====================
@@ -44,57 +68,70 @@ def get_current_time():
44
  def format_log(message: str, level: str = "INFO") -> str:
45
  return f"[{get_current_time()}] [{level.upper()}] {message}"
46
 
47
- def parse_sql_from_response(response_text: str) -> Optional[str]:
48
- """更健壯的 SQL 擷取 (multi-line 安全版)"""
49
- if not response_text:
50
- return None
 
 
51
 
52
- text = response_text.strip()
 
 
 
 
53
 
54
- # 1) 取得所有 ```sql / ``` 區塊,優先使用
55
- code_blocks = re.findall(r"```(?:sql)?\s*\n([\s\S]*?)```", text, flags=re.IGNORECASE)
56
- candidates = []
57
- for block in code_blocks:
58
- b = block.strip()
59
- if 'select' in b.lower():
60
- candidates.append(b)
61
 
62
- # 2) 若無 code block,直接以正則抓第一個 SELECT...; 或到結尾
63
- if not candidates:
64
- m = re.search(r"SELECT\b[\s\S]*?(?:;|$)", text, flags=re.IGNORECASE)
65
- if m:
66
- candidates.append(m.group(0).strip())
67
 
68
- if not candidates:
 
 
69
  return None
70
 
71
- def clean(sql_raw: str) -> str:
72
- # 去除註解行與多餘空白
73
- lines = []
74
- for line in sql_raw.split('\n'):
75
- l = line.strip()
76
- if not l:
77
- continue
78
- if l.startswith('--') or l.startswith('#'):
79
- continue
80
- lines.append(l)
81
- sql_clean = ' '.join(lines)
82
- # 移除多個反引號殘留
83
- sql_clean = sql_clean.replace('```', '').strip()
84
- # 若有多個分號只保留第一個前面內容後加單一分號
85
- if ';' in sql_clean:
86
- first_part = sql_clean.split(';')[0].strip()
87
- sql_clean = first_part
88
- if not sql_clean.lower().startswith('select'):
89
- return ''
90
- if not sql_clean.endswith(';'):
91
- sql_clean += ';'
92
- return sql_clean
93
-
94
- for cand in candidates:
95
- cleaned = clean(cand)
96
- if cleaned:
97
- return cleaned
 
 
 
 
 
 
 
 
 
 
98
  return None
99
 
100
  # ==================== Text-to-SQL 核心類 ====================
@@ -103,15 +140,21 @@ class TextToSQLSystem:
103
  self.log_history = []
104
  self._log("初始化系統...")
105
  self.query_cache = {}
106
- self.backend = None # 'gguf' | 'transformers' | None
107
- self.gguf_llm = None # 實際 llama.cpp 物件
 
 
108
 
109
  # 1. 載入嵌入模型
110
  self._log(f"載入嵌入模型: {embed_model_name}")
111
  self.embed_tokenizer = AutoTokenizer.from_pretrained(embed_model_name)
112
  self.embed_model = AutoModel.from_pretrained(embed_model_name)
113
- if DEVICE == "cuda":
114
- self.embed_model = self.embed_model.cuda()
 
 
 
 
115
 
116
  # 2. 載入數據庫結構
117
  self.schema = self._load_schema()
@@ -119,220 +162,122 @@ class TextToSQLSystem:
119
  # 3. 載入數據集並建立索引
120
  self.dataset, self.faiss_index = self._load_and_index_dataset()
121
 
122
- # 4. 載入 GGUF 模型(添加錯誤處理)
123
  self._load_gguf_model()
124
 
125
- self._log("系統初始化完成")
126
- # 載入數據庫結構
127
- self.schema = self._load_schema()
128
-
129
- # 暫時添加:打印 schema 信息
130
- if self.schema:
131
- print("=" * 50)
132
- print("數據庫 Schema 信息:")
133
- for table_name, columns in self.schema.items():
134
- print(f"\n表格: {table_name}")
135
- print(f"欄位數: {len(columns)}")
136
- print("欄位列表:")
137
- for col in columns[:5]: # 只顯示前5個
138
- print(f" - {col['name']} ({col['type']})")
139
- print("=" * 50)
140
 
141
- # in class TextToSQLSystem:
 
 
142
 
143
  def _load_gguf_model(self):
144
- """載入 GGUF 模型,使用更穩定、簡潔的參數"""
145
  try:
146
- self._log("載入 GGUF 模型 (使用穩定性參數)...")
147
- model_path = hf_hub_download(
148
- repo_id=GGUF_REPO_ID,
149
- filename=GGUF_FILENAME,
150
- repo_type="dataset"
151
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
- # 使用一組更基礎、更穩定的參數來載入模型
154
- self.gguf_llm = Llama(
 
 
155
  model_path=model_path,
156
- n_ctx=2048, # 將上下文增加到 2048 以確保 Prompt 不會超長
157
- n_threads=4, # 保持 4 線程
158
- n_batch=512, # 建議值
159
- verbose=False, # 設為 False 避免 llama.cpp 本身的日誌干擾
160
- n_gpu_layers=0 # 確認在 CPU 上運行
 
 
 
161
  )
162
 
163
- # 簡單測試模型是否能回應
164
- self.gguf_llm("你好", max_tokens=3)
165
- self.backend = "gguf"
166
- self._log("✅ GGUF 模型載入成功")
 
 
167
 
168
  except Exception as e:
169
- self._log(f"GGUF 載入失敗: {e}", "ERROR")
170
- self._log("系統將無法生成 SQL。請檢查模型檔案或 llama-cpp-python 安裝。", "CRITICAL")
171
  self.llm = None
172
 
173
- def _try_gguf_loading(self):
174
- """嘗試載入 GGUF"""
175
  try:
176
- model_path = hf_hub_download(
177
- repo_id=GGUF_REPO_ID,
178
- filename=GGUF_FILENAME,
179
- repo_type="dataset"
180
- )
181
 
182
- self.gguf_llm = Llama(
183
- model_path=model_path,
184
- n_ctx=512,
185
- n_threads=4,
186
- verbose=False,
187
- n_gpu_layers=0
188
- )
189
 
190
- # 測試生成
191
- test_result = self.gguf_llm("SELECT", max_tokens=5)
192
- self._log("✅ GGUF 模型載入成功")
193
- return True
 
194
 
195
- except Exception as e:
196
- self._log(f"GGUF 載入失敗: {e}", "WARNING")
197
  return False
198
 
199
- def _load_transformers_model(self):
200
- """使用 Transformers 載入你的微調模型"""
 
 
 
 
201
  try:
202
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
203
- import torch
204
-
205
- self._log(f"載入 Transformers 模型: {FINETUNED_MODEL_PATH}")
206
-
207
- # 載入你的微調模型
208
- self.transformers_tokenizer = AutoTokenizer.from_pretrained(FINETUNED_MODEL_PATH)
209
- self.transformers_model = AutoModelForCausalLM.from_pretrained(
210
- FINETUNED_MODEL_PATH,
211
- torch_dtype=torch.float32, # CPU 使用 float32
212
- device_map="cpu", # 強制使用 CPU
213
- trust_remote_code=True # Qwen 模型可能需要
214
- )
215
 
216
- # 創建生成管道
217
- self.generation_pipeline = pipeline(
218
- "text-generation",
219
- model=self.transformers_model,
220
- tokenizer=self.transformers_tokenizer,
221
- device=-1, # CPU
222
- max_length=512,
223
- do_sample=True,
224
  temperature=0.1,
225
  top_p=0.9,
226
- pad_token_id=self.transformers_tokenizer.eos_token_id
 
227
  )
228
 
229
- # 標記目前後端為 transformers
230
- self.backend = "transformers"
231
- self._log("✅ Transformers 模型載入成功")
232
-
233
- except Exception as e:
234
- self._log(f"❌ Transformers 載入也失敗: {e}", "ERROR")
235
-
236
- def huggingface_api_call(self, prompt: str) -> str:
237
- """生成 SQL:優先使用 transformers,其次 gguf,最後 fallback"""
238
- # transformers 後端
239
- if self.backend == "transformers" and hasattr(self, "generation_pipeline"):
240
- try:
241
- gen = self.generation_pipeline(
242
- prompt,
243
- max_new_tokens=350,
244
- do_sample=True,
245
- temperature=0.05,
246
- top_p=0.9
247
- )
248
- # 盡量從 pipeline 結果提取文字
249
- generated_text = ""
250
- try:
251
- if isinstance(gen, list) and gen:
252
- first = gen[0]
253
- if isinstance(first, dict) and "generated_text" in first:
254
- generated_text = str(first["generated_text"]) # type: ignore[index]
255
- else:
256
- generated_text = str(first)
257
- else:
258
- generated_text = str(gen)
259
- except Exception:
260
- generated_text = str(gen)
261
- # 若包含 prompt,裁切前綴
262
- if isinstance(generated_text, str) and generated_text.startswith(prompt):
263
- generated_text = generated_text[len(prompt):]
264
- self._log(f"📝 提取出的生成文本: {generated_text.strip()}", "DEBUG")
265
-
266
- lines = generated_text.strip().split('\n')
267
- non_comment_lines = [line for line in lines if not line.strip().startswith('--')]
268
- cleaned_text = "\n".join(non_comment_lines).strip()
269
- if cleaned_text != generated_text.strip():
270
- self._log(f"🧹 清理掉註解後的文本: {cleaned_text}", "DEBUG")
271
- if cleaned_text and not re.match(r"^\s*select\b", cleaned_text, flags=re.IGNORECASE):
272
- self._log("⚙️ 補上缺失的 'SELECT ' 起手以形成完整查詢", "DEBUG")
273
- cleaned_text = "SELECT " + cleaned_text.lstrip()
274
- return cleaned_text
275
- except Exception as e:
276
- self._log(f"❌ Transformers 生成失敗: {e}", "ERROR")
277
- return ""
278
 
279
- # gguf 後端
280
- if self.backend == "gguf" and self.gguf_llm is not None and callable(getattr(self.gguf_llm, "__call__", None)):
281
- try:
282
- output = self.gguf_llm(
283
- prompt,
284
- max_tokens=350,
285
- temperature=0.05,
286
- top_p=0.9,
287
- echo=False,
288
- stop=["```"]
289
- )
290
- self._log(f"🧠 模型原始輸出 (Raw Output): {output}", "DEBUG")
291
- if output and "choices" in output and len(output["choices"]) > 0:
292
- generated_text = output["choices"][0]["text"]
293
- self._log(f"📝 提取出的生成文本: {generated_text.strip()}", "DEBUG")
294
- lines = str(generated_text).strip().split('\n')
295
- non_comment_lines = [line for line in lines if not line.strip().startswith('--')]
296
- cleaned_text = "\n".join(non_comment_lines).strip()
297
- if cleaned_text != str(generated_text).strip():
298
- self._log(f"🧹 清理掉註解後的文本: {cleaned_text}", "DEBUG")
299
- if cleaned_text and not re.match(r"^\s*select\b", cleaned_text, flags=re.IGNORECASE):
300
- self._log("⚙️ 補上缺失的 'SELECT ' 起手以形成完整查詢", "DEBUG")
301
- cleaned_text = "SELECT " + cleaned_text.lstrip()
302
- return cleaned_text
303
- else:
304
- self._log("❌ 模型的原始輸出格式不正確或為空。", "ERROR")
305
- return ""
306
- except Exception as e:
307
- self._log(f"❌ GGUF 生成失敗: {e}", "ERROR")
308
  return ""
309
 
310
- # 後備:都不可用時,回退
311
- self._log("模型未載入或不可用,返回 fallback SQL。", "ERROR")
312
- return self._generate_fallback_sql(prompt)
313
-
314
- def _load_gguf_model_fallback(self, model_path):
315
- """備用載入方式"""
316
- try:
317
- # 嘗試不同的參數組合
318
- self.gguf_llm = Llama(
319
- model_path=model_path,
320
- n_ctx=512, # 更小的上下文
321
- n_threads=4,
322
- n_batch=128,
323
- vocab_only=False,
324
- use_mmap=True,
325
- use_mlock=False,
326
- verbose=True
327
- )
328
- self._log("✅ 備用方式載入成功")
329
  except Exception as e:
330
- self._log(f" 備用方式也失敗: {e}", "ERROR")
331
- self.gguf_llm = None
332
-
333
- def _log(self, message: str, level: str = "INFO"):
334
- self.log_history.append(format_log(message, level))
335
- print(format_log(message, level))
336
 
337
  def _load_schema(self) -> Dict:
338
  """載入數據庫結構"""
@@ -340,91 +285,58 @@ class TextToSQLSystem:
340
  schema_path = hf_hub_download(
341
  repo_id=DATASET_REPO_ID,
342
  filename="sqlite_schema_FULL.json",
343
- repo_type="dataset"
 
344
  )
345
  with open(schema_path, "r", encoding="utf-8") as f:
346
  schema_data = json.load(f)
347
 
348
- # 添加調試信息
349
- self._log(f"📊 Schema 載入成功,包含 {len(schema_data)} 個表格:")
350
  for table_name, columns in schema_data.items():
351
  self._log(f" - {table_name}: {len(columns)} 個欄位")
352
- # 顯示前3個欄位作為範例
353
- sample_cols = [col['name'] for col in columns[:3]]
354
- self._log(f" 範例欄位: {', '.join(sample_cols)}")
355
 
356
- self._log("數據庫結構載入完成")
357
  return schema_data
358
 
359
  except Exception as e:
360
- self._log(f"載入 schema 失敗: {e}", "ERROR")
361
  return {}
362
 
363
- # 也可以添加一個方法來檢查生成的 SQL 是否使用了正確的表格和欄位
364
- def _analyze_sql_correctness(self, sql: str) -> Dict:
365
- """分析 SQL 的正確性"""
366
- analysis = {
367
- 'valid_tables': [],
368
- 'invalid_tables': [],
369
- 'valid_columns': [],
370
- 'invalid_columns': [],
371
- 'suggestions': []
372
- }
373
-
374
- if not self.schema:
375
- return analysis
376
-
377
- # 提取 SQL 中的表格名稱
378
- table_pattern = r'FROM\s+(\w+)|JOIN\s+(\w+)'
379
- table_matches = re.findall(table_pattern, sql, re.IGNORECASE)
380
- used_tables = [match[0] or match[1] for match in table_matches]
381
-
382
- # 檢查表格是否存在
383
- valid_tables = list(self.schema.keys())
384
- for table in used_tables:
385
- if table in valid_tables:
386
- analysis['valid_tables'].append(table)
387
- else:
388
- analysis['invalid_tables'].append(table)
389
- # 尋找相似的表格名稱
390
- for valid_table in valid_tables:
391
- if table.lower() in valid_table.lower() or valid_table.lower() in table.lower():
392
- analysis['suggestions'].append(f"{table} -> {valid_table}")
393
-
394
- # 提取欄位名稱(簡單版本)
395
- column_pattern = r'SELECT\s+(.*?)\s+FROM|WHERE\s+(\w+)\s*[=<>]|GROUP BY\s+(\w+)|ORDER BY\s+(\w+)'
396
- column_matches = re.findall(column_pattern, sql, re.IGNORECASE)
397
-
398
- return analysis
399
-
400
  def _encode_texts(self, texts):
401
  """編碼文本為嵌入向量"""
402
  if isinstance(texts, str):
403
  texts = [texts]
404
-
405
  inputs = self.embed_tokenizer(texts, padding=True, truncation=True,
406
- return_tensors="pt", max_length=512)
407
- if DEVICE == "cuda":
408
- inputs = {k: v.cuda() for k, v in inputs.items()}
 
 
 
409
 
410
  with torch.no_grad():
411
  outputs = self.embed_model(**inputs)
412
 
413
  # 使用平均池化
414
  embeddings = outputs.last_hidden_state.mean(dim=1)
415
- return embeddings.cpu()
416
 
417
  def _load_and_index_dataset(self):
418
  """載入數據集並建立 FAISS 索引"""
419
  try:
420
- dataset = load_dataset(DATASET_REPO_ID, data_files="training_data.jsonl", split="train")
 
 
421
 
422
- # 先過濾不完整樣本,避免 messages 長度不足導致索引或檢索報錯
423
- try:
424
- original_count = len(dataset)
425
- except Exception:
426
- original_count = None
 
427
 
 
 
428
  dataset = dataset.filter(
429
  lambda ex: isinstance(ex.get("messages"), list)
430
  and len(ex["messages"]) >= 2
@@ -434,10 +346,7 @@ class TextToSQLSystem:
434
  )
435
  )
436
 
437
- if original_count is not None:
438
- self._log(
439
- f"資料集清理: 原始 {original_count} 筆, 過濾後 {len(dataset)} 筆, 移除 {original_count - len(dataset)} 筆"
440
- )
441
 
442
  if len(dataset) == 0:
443
  self._log("清理後資料集為空,無法建立索引。", "ERROR")
@@ -446,14 +355,19 @@ class TextToSQLSystem:
446
  corpus = [item['messages'][0]['content'] for item in dataset]
447
  self._log(f"正在編碼 {len(corpus)} 個問題...")
448
 
449
- # 批量編碼
450
  embeddings_list = []
451
- batch_size = 32
452
 
453
  for i in range(0, len(corpus), batch_size):
454
  batch_texts = corpus[i:i+batch_size]
455
  batch_embeddings = self._encode_texts(batch_texts)
456
  embeddings_list.append(batch_embeddings)
 
 
 
 
 
457
  self._log(f"已編碼 {min(i+batch_size, len(corpus))}/{len(corpus)}")
458
 
459
  all_embeddings = torch.cat(embeddings_list, dim=0).numpy()
@@ -462,11 +376,15 @@ class TextToSQLSystem:
462
  index = faiss.IndexFlatIP(all_embeddings.shape[1])
463
  index.add(all_embeddings.astype('float32'))
464
 
465
- self._log("✅ 向量索引建立完成")
 
 
 
 
466
  return dataset, index
467
 
468
  except Exception as e:
469
- self._log(f"載入數據失敗: {e}", "ERROR")
470
  return None, None
471
 
472
  def _identify_relevant_tables(self, question: str) -> List[str]:
@@ -497,12 +415,8 @@ class TextToSQLSystem:
497
 
498
  return relevant_tables[:3] # 最多返回3個相關表格
499
 
500
- # 請將這整個函數複製到您的 TextToSQLSystem class 內部
501
-
502
  def _format_relevant_schema(self, table_names: List[str]) -> str:
503
- """
504
- 生成一個簡化的、不易被模型錯誤模仿的 Schema 字符串。
505
- """
506
  if not self.schema:
507
  return "No schema available.\n"
508
 
@@ -522,257 +436,17 @@ class TextToSQLSystem:
522
  formatted = ""
523
  for table in real_table_names:
524
  if table in self.schema:
525
- # 使用簡單的 "Table: ..." 和 "Columns: ..." 格式
526
  formatted += f"Table: {table}\n"
527
  cols_str = []
528
- # 只顯示前 10 個關鍵欄位
529
- for col in self.schema[table][:10]:
530
  col_name = col['name']
531
  col_type = col['type']
532
- col_desc = col.get('description', '').replace('\n', ' ')
533
- # 將描述信息放在括號裡
534
- if col_desc:
535
- cols_str.append(f"{col_name} ({col_type}, {col_desc})")
536
- else:
537
- cols_str.append(f"{col_name} ({col_type})")
538
  formatted += f"Columns: {', '.join(cols_str)}\n\n"
539
 
540
  return formatted.strip()
541
 
542
- # 在 class TextToSQLSystem 內
543
-
544
- def _validate_and_fix_sql(self, question: str, raw_response: str) -> Tuple[Optional[str], str]:
545
- """
546
- (V29 / 穩健正則 + 智能計數 最終版)
547
- 一個多層次的SQL生成引擎。它優先使用基於規則的動態模板生成器,
548
- 如果無法匹配,則回退到解析和修正AI模型的輸出。
549
- - 使用更簡潔、穩健的正則表達式來捕獲實體名稱。
550
- - 根據問題是關於「報告」還是「測試項目」來智能地決定計數目標。
551
- """
552
- q_lower = question.lower()
553
-
554
- # ==============================================================================
555
- # 第零層:統一實體識別引擎 (Unified Entity Recognition Engine)
556
- # ==============================================================================
557
- entity_match_data = None
558
- # 包含了繁簡體兼容和更穩健的模式
559
- entity_patterns = [
560
- # 模式1: 匹配 "类型 + ID" - (保持不變)
561
- {'pattern': r"(買家|买家|buyer)\s*(?:id|代號|代碼|代号|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.BuyerID', 'type': '買家ID'},
562
- {'pattern': r"(申請方|申请方|申請廠商|申请厂商|applicant)\s*(?:id|代號|代碼|代号|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.ApplicantID', 'type': '申請方ID'},
563
- {'pattern': r"(付款方|付款厂商|invoiceto)\s*(?:id|代號|代碼|代号|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.InvoiceToID', 'type': '付款方ID'},
564
- {'pattern': r"(代理商|agent)\s*(?:id|代號|代碼|代号|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.AgentID', 'type': '代理商ID'},
565
-
566
- # 模式2: 匹配 "類型 + 名稱" - (簡化了模式,使其更穩健)
567
- {'pattern': r"(買家|买家|buyer|客戶)\s+([a-zA-Z0-9&.-]+)", 'column': 'sd.BuyerName', 'type': '買家'},
568
- {'pattern': r"(申請方|申请方|申請廠商|申请厂商|applicant)\s+([a-zA-Z0-9&.-]+)", 'column': 'sd.ApplicantName', 'type': '申請方'},
569
- {'pattern': r"(付款方|付款厂商|invoiceto)\s+([a-zA-Z0-9&.-]+)", 'column': 'sd.InvoiceToName', 'type': '付款方'},
570
- {'pattern': r"(代理商|agent)\s+([a-zA-Z0-9&.-]+)", 'column': 'sd.AgentName', 'type': '代理商'},
571
-
572
- # 模式3: 单独匹配一个 ID - (保持不變)
573
- {'pattern': r"\b([A-Z]\d{4}[A-Z])\b", 'column': 'sd.ApplicantID', 'type': 'ID'}
574
- ]
575
-
576
- for p in entity_patterns:
577
- match = re.search(p['pattern'], question, re.IGNORECASE)
578
- if match:
579
- entity_value = match.group(2) if len(match.groups()) > 1 else match.group(1)
580
- entity_match_data = {
581
- "type": p['type'],
582
- "name": entity_value.strip().upper(),
583
- "column": p['column']
584
- }
585
- break
586
-
587
- # ==============================================================================
588
- # 第一層:模組化意圖偵測與動態SQL組合
589
- # ==============================================================================
590
-
591
- intents = {}
592
- sql_components = {
593
- 'select': [], 'from': "", 'joins': [], 'where': [],
594
- 'group_by': [], 'order_by': [], 'log_parts': []
595
- }
596
-
597
- # --- 運行一系列獨立的意圖偵測器 ---
598
-
599
- # 偵測器 2.1: 核心動作意圖
600
- if any(kw in q_lower for kw in ['幾份', '多少', '數量', '總數', 'how many', 'count']):
601
- intents['action'] = 'count'
602
- # 智能決定計數目標
603
- if "測試項目" in question or "test item" in q_lower:
604
- sql_components['select'].append("COUNT(jip.ItemCode) AS item_count")
605
- sql_components['log_parts'].append("測試項目總數")
606
- else: # 預設是計數報告
607
- sql_components['select'].append("COUNT(DISTINCT jt.JobNo) AS report_count")
608
- sql_components['log_parts'].append("報告總數")
609
- elif any(kw in q_lower for kw in ['報告號碼', '報告清單', '列出報告', 'report number', 'list of reports']):
610
- intents['action'] = 'list'
611
- sql_components['select'].append("jt.JobNo, jt.ReportAuthorization")
612
- sql_components['order_by'].append("jt.ReportAuthorization DESC")
613
- sql_components['log_parts'].append("報告列表")
614
-
615
- # 偵測器 2.2: 時間意圖
616
- year_match = re.search(r'(\d{4})\s*年?', question)
617
- month_match = re.search(r'(\d{1,2})\s*月', question)
618
- if year_match:
619
- year = year_match.group(1)
620
- sql_components['where'].append(f"strftime('%Y', jt.ReportAuthorization) = '{year}'")
621
- sql_components['log_parts'].append(f"{year}年")
622
- if month_match:
623
- month = month_match.group(1).zfill(2)
624
- sql_components['where'].append(f"strftime('%m', jt.ReportAuthorization) = '{month}'")
625
- sql_components['log_parts'].append(f"{month}月")
626
-
627
- # 偵測器 2.3: 實體意圖
628
- if entity_match_data:
629
- if "TSR53SampleDescription" not in " ".join(sql_components['joins']):
630
- sql_components['joins'].append("JOIN TSR53SampleDescription AS sd ON jt.JobNo = sd.JobNo")
631
- entity_name, column_name = entity_match_data["name"], entity_match_data["column"]
632
- match_operator = "=" if column_name.endswith("ID") else "LIKE"
633
- entity_value = f"'%{entity_name}%'" if match_operator == "LIKE" else f"'{entity_name}'"
634
- sql_components['where'].append(f"{column_name} {match_operator} {entity_value}")
635
- sql_components['log_parts'].append(entity_match_data["type"] + ":" + entity_name)
636
- if intents.get('action') == 'list':
637
- sql_components['select'].append("sd.BuyerName")
638
-
639
- # 偵測器 2.4: 評級意圖
640
- if 'fail' in q_lower or '失敗' in q_lower:
641
- if "TSR53SampleDescription" not in " ".join(sql_components['joins']):
642
- sql_components['joins'].append("JOIN TSR53SampleDescription AS sd ON jt.JobNo = sd.JobNo")
643
- sql_components['where'].append("sd.OverallRating = 'Fail'")
644
- sql_components['log_parts'].append("Fail")
645
- elif 'pass' in q_lower or '通過' in q_lower:
646
- if "TSR53SampleDescription" not in " ".join(sql_components['joins']):
647
- sql_components['joins'].append("JOIN TSR53SampleDescription AS sd ON jt.JobNo = sd.JobNo")
648
- sql_components['where'].append("sd.OverallRating = 'Pass'")
649
- sql_components['log_parts'].append("Pass")
650
-
651
- # 偵測器 2.5: 實驗組 (LabGroup) 意圖 (帶有別名映射)
652
- lab_group_mapping = {'A': 'TA', 'B': 'TB', 'C': 'TC', 'D': 'TD', 'E': 'TE', 'Y': 'TY'}
653
- lab_group_match = re.search(r'([A-Z]{1,2})組', question, re.IGNORECASE)
654
- if lab_group_match:
655
- user_input_group = lab_group_match.group(1).upper()
656
- db_lab_group = lab_group_mapping.get(user_input_group, user_input_group)
657
- sql_components['joins'].append("JOIN JobItemsInProgress AS jip ON jt.JobNo = jip.JobNo")
658
- sql_components['where'].append(f"jip.LabGroup = '{db_lab_group}'")
659
- sql_components['log_parts'].append(f"{user_input_group}組(->{db_lab_group})")
660
-
661
- # --- 2.6: 兩年份比較模板(優先級:高) ---
662
- # 偵測『比較/vs/對比/相較/相比』字樣,擷取兩個年份與(可選)買家名稱
663
- compare_hit = any(kw in q_lower for kw in ["比較", "對比", "相較", "相比", "vs", "versus"])
664
- years_found = re.findall(r"(20\d{2})", question)
665
- years_unique = []
666
- for y in years_found:
667
- if y not in years_unique:
668
- years_unique.append(y)
669
- if compare_hit and len(years_unique) >= 2:
670
- year_a, year_b = years_unique[0], years_unique[1]
671
- # 嘗試抓買家名稱(英文/數字/符號),若沒有則不加 buyer 條件
672
- buyer_name = None
673
- # 1) 優先解析明確條件:BuyerName LIKE '%...%'
674
- m_like = re.search(r"BuyerName\s+LIKE\s*'%([^']+)%'", question, re.IGNORECASE)
675
- if m_like:
676
- buyer_name = m_like.group(1).strip()
677
- else:
678
- # 2) 解析自然語言:避免 'BuyerName' 被誤判成 'buyer'
679
- buyer_match = re.search(r"(?:買家|买家|客戶|客户|\bbuyer\b(?!name))\s*[::]?\s*([A-Za-z0-9&.\- ]+)", question, re.IGNORECASE)
680
- if buyer_match:
681
- buyer_name = buyer_match.group(1).strip()
682
-
683
- # 判斷偏向金額或件數
684
- amount_intent = any(kw in q_lower for kw in ["金額", "金钱", "amount", "營收", "業績", "營業額", "銷售額", "revenue"])
685
-
686
- if amount_intent:
687
- # 金額版:需要發票表,依架構命名使用 TSR53Invoice 與 LocalAmount;與樣本描述以 JobNo 關聯
688
- sql = (
689
- "SELECT strftime('%Y', jt.ReportAuthorization) AS year, "
690
- "SUM(COALESCE(inv.LocalAmount, 0)) AS total_amount "
691
- "FROM JobTimeline AS jt "
692
- "JOIN TSR53SampleDescription AS sd ON sd.JobNo = jt.JobNo "
693
- "LEFT JOIN TSR53Invoice AS inv ON inv.JobNo = jt.JobNo "
694
- "WHERE jt.ReportAuthorization IS NOT NULL "
695
- f"AND strftime('%Y', jt.ReportAuthorization) IN ('{year_a}', '{year_b}') "
696
- )
697
- if buyer_name:
698
- sql += f"AND sd.BuyerName LIKE '%{buyer_name}%' "
699
- sql += "GROUP BY year ORDER BY year;"
700
- return self._finalize_sql(sql, f"模板覆寫: 兩年份金額比較 {year_a} vs {year_b}" )
701
- else:
702
- # 件數版:以報告數量為主,去重 JobNo
703
- sql = (
704
- "SELECT strftime('%Y', jt.ReportAuthorization) AS year, "
705
- "COUNT(DISTINCT jt.JobNo) AS report_count "
706
- "FROM JobTimeline AS jt "
707
- "JOIN TSR53SampleDescription AS sd ON sd.JobNo = jt.JobNo "
708
- "WHERE jt.ReportAuthorization IS NOT NULL "
709
- f"AND strftime('%Y', jt.ReportAuthorization) IN ('{year_a}', '{year_b}') "
710
- )
711
- if buyer_name:
712
- sql += f"AND sd.BuyerName LIKE '%{buyer_name}%' "
713
- sql += "GROUP BY year ORDER BY year;"
714
- return self._finalize_sql(sql, f"模板覆寫: 兩年份件數比較 {year_a} vs {year_b}" )
715
-
716
- # --- 3. 判斷是否觸發了模板,並動態組合 SQL ---
717
- if 'action' in intents:
718
- sql_components['from'] = "FROM JobTimeline AS jt"
719
- # 只要有任何篩選條件,就加上報告已授權的基礎限制
720
- if sql_components['where']:
721
- sql_components['where'].insert(0, "jt.ReportAuthorization IS NOT NULL")
722
-
723
- select_clause = "SELECT " + ", ".join(sorted(list(set(sql_components['select']))))
724
- from_clause = sql_components['from']
725
- joins_clause = " ".join(sql_components['joins'])
726
- where_clause = "WHERE " + " AND ".join(sql_components['where']) if sql_components['where'] else ""
727
- orderby_clause = "ORDER BY " + ", ".join(sql_components['order_by']) if sql_components['order_by'] else ""
728
-
729
- template_sql = f"{select_clause} {from_clause} {joins_clause} {where_clause} {orderby_clause};"
730
-
731
- query_log = " ".join(sql_components['log_parts'])
732
- self._log(f"🔄 偵測到組合意圖【{query_log}】,啟用動態模板。", "INFO")
733
- return self._finalize_sql(template_sql, f"模板覆寫: {query_log} 查詢")
734
-
735
- # ==============================================================================
736
- # 第二层:AI 生成修正流程 (Fallback)
737
- # ==============================================================================
738
- self._log("未觸發任何模板,嘗試解析並修正 AI 輸出...", "INFO")
739
-
740
- parsed_sql = parse_sql_from_response(raw_response)
741
- if not parsed_sql:
742
- self._log(f"❌ 未能從模型回應中解析出任何 SQL。原始回應: {raw_response}", "ERROR")
743
- return None, f"無法解析SQL。原始回應:\n{raw_response}"
744
-
745
- self._log(f"📊 解析出的原始 SQL: {parsed_sql}", "DEBUG")
746
-
747
- fixed_sql = " " + parsed_sql.strip() + " "
748
- fixes_applied_fallback = []
749
-
750
- dialect_corrections = {r'YEAR\s*\(([^)]+)\)': r"strftime('%Y', \1)"}
751
- for pattern, replacement in dialect_corrections.items():
752
- if re.search(pattern, fixed_sql, re.IGNORECASE):
753
- fixed_sql = re.sub(pattern, replacement, fixed_sql, flags=re.IGNORECASE)
754
- fixes_applied_fallback.append(f"修正方言: {pattern}")
755
-
756
- schema_corrections = {'TSR53Report':'TSR53SampleDescription', 'TSR53InvoiceReportNo':'JobNo', 'TSR53ReportNo':'JobNo', 'TSR53InvoiceNo':'JobNo', 'TSR53InvoiceCreditNoteNo':'InvoiceCreditNoteNo', 'TSR53InvoiceLocalAmount':'LocalAmount', 'Status':'OverallRating', 'ReportStatus':'OverallRating'}
757
- for wrong, correct in schema_corrections.items():
758
- pattern = r'\b' + re.escape(wrong) + r'\b'
759
- if re.search(pattern, fixed_sql, re.IGNORECASE):
760
- fixed_sql = re.sub(pattern, correct, fixed_sql, flags=re.IGNORECASE)
761
- fixes_applied_fallback.append(f"映射 Schema: '{wrong}' -> '{correct}'")
762
-
763
- log_msg = "AI 生成並成功修正" if fixes_applied_fallback else "AI 生成且無需修正"
764
- return self._finalize_sql(fixed_sql, log_msg)
765
-
766
- def _finalize_sql(self, sql: str, log_message: str) -> Tuple[str, str]:
767
- """一個輔助函數,用於清理最終的SQL並記錄成功日誌。"""
768
- final_sql = sql.strip()
769
- if not final_sql.endswith(';'):
770
- final_sql += ';'
771
- final_sql = re.sub(r'\s+', ' ', final_sql).strip()
772
- self._log(f"✅ SQL 已生成 ({log_message})", "INFO")
773
- self._log(f" - 最終 SQL: {final_sql}", "DEBUG")
774
- return final_sql, "生成成功"
775
-
776
  def find_most_similar(self, question: str, top_k: int) -> List[Dict]:
777
  """使用 FAISS 快速檢索相似問題"""
778
  if self.faiss_index is None or self.dataset is None:
@@ -792,16 +466,14 @@ class TextToSQLSystem:
792
  if len(results) >= top_k:
793
  break
794
 
795
- # 修復:將 numpy.int64 轉換為 Python int
796
- idx = int(idx) # ← 添加這行轉換
797
-
798
- if idx >= len(self.dataset): # 確保索引有效
799
  continue
800
 
801
  item = self.dataset[idx]
802
- # 防呆:若樣本不完整則跳過
803
  if not isinstance(item.get('messages'), list) or len(item['messages']) < 2:
804
  continue
 
805
  q_content = (item['messages'][0].get('content') or '').strip()
806
  a_content = (item['messages'][1].get('content') or '').strip()
807
  if not q_content or not a_content:
@@ -824,18 +496,12 @@ class TextToSQLSystem:
824
  return results
825
 
826
  except Exception as e:
827
- self._log(f"檢索失敗: {e}", "ERROR")
828
  return []
829
 
830
- # in class TextToSQLSystem:
831
-
832
  def _build_prompt(self, user_q: str, examples: List[Dict]) -> str:
833
- """
834
- 建立一個高度結構化、以任務為導向的提示詞,使用清晰的標題分隔符。
835
- """
836
  relevant_tables = self._identify_relevant_tables(user_q)
837
-
838
- # 使用我們新的、更簡單的 schema 格式化函數
839
  schema_str = self._format_relevant_schema(relevant_tables)
840
 
841
  example_str = "No example available."
@@ -843,8 +509,9 @@ class TextToSQLSystem:
843
  best_example = examples[0]
844
  example_str = f"Question: {best_example['question']}\nSQL:\n```sql\n{best_example['sql']}\n```"
845
 
846
- # 使用強分隔符和清晰的標題來構建 prompt
847
- prompt = f"""You are a silent SQL query generator. You are physically incapable of producing any text that is not a valid SQLite query. You will be penalized for any explanation or comment. Your entire existence is to translate a user's question into a single SQLite query.
 
848
 
849
  ### SCHEMA ###
850
  {schema_str}
@@ -852,22 +519,241 @@ class TextToSQLSystem:
852
  ### EXAMPLE ###
853
  {example_str}
854
 
855
- ### TASK ###
856
- User question: "{user_q}"
857
- Your single SQLite query response:
 
858
  ```sql
859
  SELECT
860
  """
861
- self._log(f"📏 Prompt 長度: {len(prompt)} 字符")
862
- # 不再需要複雜的長度截斷邏輯,因為 schema 已經被簡化
863
  return prompt
864
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
865
 
866
  def _generate_fallback_sql(self, prompt: str) -> str:
867
  """當模型不可用時的備用 SQL 生成"""
868
  prompt_lower = prompt.lower()
869
 
870
- # 簡單的關鍵詞匹配生成基本 SQL
871
  if "統計" in prompt or "數量" in prompt or "多少" in prompt:
872
  if "月" in prompt:
873
  return "SELECT strftime('%Y-%m', completed_time) as month, COUNT(*) as count FROM jobtimeline GROUP BY month ORDER BY month;"
@@ -875,100 +761,96 @@ SELECT
875
  return "SELECT applicant, COUNT(*) as count FROM tsr53sampledescription GROUP BY applicant ORDER BY count DESC;"
876
  else:
877
  return "SELECT COUNT(*) as total_count FROM jobtimeline WHERE completed_time IS NOT NULL;"
878
-
879
  elif "金額" in prompt or "總額" in prompt:
880
  return "SELECT SUM(amount) as total_amount FROM tsr53invoice;"
881
-
882
  elif "評級" in prompt or "pass" in prompt_lower or "fail" in prompt_lower:
883
  return "SELECT rating, COUNT(*) as count FROM tsr53sampledescription GROUP BY rating;"
884
-
885
  else:
886
  return "SELECT * FROM jobtimeline LIMIT 10;"
887
 
888
- def _validate_model_file(self, model_path):
889
- """驗證模型檔案完整性"""
890
- try:
891
- if not os.path.exists(model_path):
892
- return False
893
-
894
- # 檢查檔案大小(至少應該有幾MB)
895
- file_size = os.path.getsize(model_path)
896
- if file_size < 10 * 1024 * 1024: # 小於 10MB 可能有問題
897
- return False
898
-
899
- # 檢查 GGUF 檔案頭部
900
- with open(model_path, 'rb') as f:
901
- header = f.read(8)
902
- if not header.startswith(b'GGUF'):
903
- return False
904
-
905
- return True
906
- except Exception:
907
- return False
908
-
909
- # in class TextToSQLSystem:
910
-
911
  def process_question(self, question: str) -> Tuple[str, str]:
912
- """處理使用者問題 (V2 / 最終版)"""
913
  # 檢查緩存
914
  if question in self.query_cache:
915
- self._log("使用緩存結果")
916
  return self.query_cache[question]
917
 
918
  self.log_history = []
919
- self._log(f"處理問題: {question}")
920
-
921
-
922
- for attempt in range(2): # --- 新增:最多嘗試 2 次 ---
923
- self._log(f"🚀 開始第 {attempt + 1} 次嘗試...")
924
-
925
- # 1. 檢索相似範例 (第二次嘗試時不再重複)
926
- if attempt == 0:
927
- self._log("🔍 尋找相似範例...")
928
- examples = self.find_most_similar(question, FEW_SHOT_EXAMPLES_COUNT)
929
- if examples: self._log(f"✅ 找到 {len(examples)} 個相似範例")
930
-
931
- # 2. 建立提示詞
932
- self._log("📝 建立 Prompt...")
933
- prompt = self._build_prompt(question, examples)
 
 
 
934
 
935
- # --- 新增:如果是第二次嘗試,加入修正指令 ---
936
- if attempt > 0:
937
- correction_prompt = "\nYour previous attempt failed because you did not provide a valid SQL query. REMEMBER: ONLY output the SQL code inside a ```sql block. DO NOT write comments or explanations.\nSQL:\n```sql\nSELECT "
938
- # 將原本 prompt 的結尾替換成我們的修正指令
939
- prompt = prompt.rsplit("SQL:\n```sql", 1)[0] + correction_prompt
940
 
 
 
 
941
 
942
- # 3. 生成 AI 回應
943
- self._log("🧠 開始生成 AI 回應...")
944
- response = self.huggingface_api_call(prompt)
 
 
945
 
946
- # 4. 驗證與生成
947
- final_sql, status_message = self._validate_and_fix_sql(question, response)
948
 
949
- if final_sql:
950
- self._log(f"✅ 在第 {attempt + 1} 次嘗試成功!", "INFO")
951
- result = (final_sql, status_message)
952
- self.query_cache[question] = result # 緩存成功結果
953
- return result
954
 
955
- self._log(f"❌ 第 {attempt + 1} 次嘗試失敗。原因: {status_message}", "WARNING")
 
956
 
957
- # --- 如果兩次都失敗 ---
958
- self._log("❌ 所有嘗試均失敗,返回錯誤訊息。", "ERROR")
959
- final_fallback_message = "模型多次嘗試後仍無法生成有效的SQL。"
960
- return (final_fallback_message, "生成失敗")
961
 
962
- # ==================== Gradio 介面 ====================
 
963
  text_to_sql_system = TextToSQLSystem()
964
 
965
- def process_query(q: str):
966
- if not q.strip():
967
- return "", "等待輸入", "請輸入問題"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
968
 
969
  sql, status = text_to_sql_system.process_question(q)
970
- logs = "\n".join(text_to_sql_system.log_history[-10:]) # 只顯示最後10條日誌
971
-
972
  return sql, status, logs
973
 
974
  # 範例問題
@@ -980,36 +862,39 @@ examples = [
980
  "A組昨天完成了多少個測試項目?"
981
  ]
982
 
983
- with gr.Blocks(theme=gr.themes.Soft(), title="Text-to-SQL 智能助手") as demo:
984
- gr.Markdown("# Text-to-SQL 智能助手")
985
- gr.Markdown("輸入自然語言問題,自動生成SQL查詢語句")
986
 
987
  with gr.Row():
988
  with gr.Column(scale=2):
989
- inp = gr.Textbox(lines=3, label="💬 您的問題", placeholder="例如:2024年每月完成多少份報告?")
990
- btn = gr.Button("🚀 生成 SQL", variant="primary")
991
  status = gr.Textbox(label="狀態", interactive=False)
 
 
992
 
993
  with gr.Column(scale=3):
994
- sql_out = gr.Code(label="🤖 生成的 SQL", language="sql", lines=8)
995
 
996
- with gr.Accordion("📋 處理日誌", open=False):
997
- logs = gr.Textbox(lines=8, label="日誌", interactive=False)
998
 
999
  # 範例區
1000
  gr.Examples(
1001
  examples=examples,
1002
  inputs=inp,
1003
- label="💡 點擊試用範例問題"
1004
  )
1005
 
1006
  # 綁定事件
1007
- btn.click(process_query, inputs=[inp], outputs=[sql_out, status, logs])
1008
- inp.submit(process_query, inputs=[inp], outputs=[sql_out, status, logs])
1009
 
1010
  if __name__ == "__main__":
1011
  demo.launch(
1012
  server_name="0.0.0.0",
1013
  server_port=7860,
1014
- share=False
 
1015
  )
 
4
  import json
5
  import torch
6
  import numpy as np
7
+ import psutil
8
+ import gc
9
+ import tempfile
10
  from datetime import datetime
11
  from datasets import load_dataset
12
  from huggingface_hub import hf_hub_download
 
14
  from typing import List, Dict, Tuple, Optional
15
  import faiss
16
  from functools import lru_cache
 
17
 
18
  # 使用 transformers 替代 sentence-transformers
19
  from transformers import AutoModel, AutoTokenizer
20
  import torch.nn.functional as F
21
 
22
+ # ==================== 配置參數 ====================
23
  DATASET_REPO_ID = "Paul720810/Text-to-SQL-Softline"
24
  GGUF_REPO_ID = "Paul720810/gguf-models"
25
+ GGUF_FILENAME = "qwen2-7b-instruct-sql-finetuned-stable.q4_k_m.gguf"
 
 
 
 
26
 
 
 
27
  EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
28
 
29
+ # 可配置 GPU(HF 免費方案通常只有 CPU)
30
+ USE_GPU = str(os.getenv("USE_GPU", "0")).lower() in {"1", "true", "yes", "y"}
31
+ try:
32
+ N_GPU_LAYERS = int(os.getenv("N_GPU_LAYERS", "0"))
33
+ except Exception:
34
+ N_GPU_LAYERS = 0
35
+ DEVICE = "cuda" if (USE_GPU and torch.cuda.is_available()) else "cpu"
36
+
37
+ # CPU 專用優化(可由環境變數覆蓋)
38
+ def _int_env(name: str, default_val: int) -> int:
39
+ try:
40
+ return int(os.getenv(name, str(default_val)))
41
+ except Exception:
42
+ return default_val
43
+
44
+ THREADS = _int_env("THREADS", min(4, os.cpu_count() or 2)) # llama.cpp 執行緒數
45
+ CTX = _int_env("CTX", 768 if DEVICE == "cpu" else 1024) # 上下文長度
46
+ MAX_TOKENS = _int_env("MAX_TOKENS", 60) # 生成 token 上限
47
+ FEW_SHOT_EXAMPLES_COUNT = _int_env("FEW_SHOT", 0 if DEVICE == "cpu" else 1)
48
+ ENABLE_INDEX = str(os.getenv("ENABLE_INDEX", "0" if DEVICE == "cpu" else "1")).lower() in {"1", "true", "yes", "y"}
49
+ EMBED_BATCH = _int_env("EMBED_BATCH", 8 if DEVICE == "cpu" else 16)
50
+
51
+ # 使用 /tmp 作為暫存目錄
52
+ TEMP_DIR = "/tmp/text_to_sql_cache"
53
+ os.makedirs(TEMP_DIR, exist_ok=True)
54
+
55
  print("=" * 60)
56
+ print("Text-to-SQL 系統啟動中 (HF 版本)...")
57
+ print(f"數據集: {DATASET_REPO_ID}")
58
+ print(f"嵌入模型: {EMBED_MODEL_NAME}")
59
+ print(f"設備: {DEVICE} (USE_GPU={USE_GPU}, N_GPU_LAYERS={N_GPU_LAYERS})")
60
+ print(f"THREADS={THREADS}, CTX={CTX}, MAX_TOKENS={MAX_TOKENS}, FEW_SHOT={FEW_SHOT_EXAMPLES_COUNT}, ENABLE_INDEX={ENABLE_INDEX}, EMBED_BATCH={EMBED_BATCH}")
61
+ print(f"暫存目錄: {TEMP_DIR}")
62
  print("=" * 60)
63
 
64
  # ==================== 工具函數 ====================
 
68
  def format_log(message: str, level: str = "INFO") -> str:
69
  return f"[{get_current_time()}] [{level.upper()}] {message}"
70
 
71
def check_memory_usage():
    """Return a human-readable summary of system memory usage.

    Reads /proc/meminfo directly (Linux only) so no psutil dependency is
    required. Falls back to a generic message on any failure, e.g. on
    non-Linux hosts where /proc/meminfo does not exist.

    Returns:
        str: e.g. "內存使用率: 42.0% (可用: 3.5GB/8.0GB)", or a fallback
        message when the information cannot be obtained.
    """
    try:
        with open('/proc/meminfo', 'r') as f:
            lines = f.readlines()

        mem_info = {}
        for line in lines:
            if line.startswith(('MemTotal:', 'MemFree:', 'MemAvailable:')):
                # Lines look like "MemTotal:  8057564 kB"; keep the kB number.
                # split(':', 1) guards against any extra ':' in the value part.
                key, value = line.split(':', 1)
                mem_info[key.strip()] = int(value.strip().split()[0])

        # /proc/meminfo reports kB, so kB / 1024**2 gives GB.
        total_gb = mem_info.get('MemTotal', 0) / (1024**2)
        available_gb = mem_info.get('MemAvailable', mem_info.get('MemFree', 0)) / (1024**2)
        used_percent = ((total_gb - available_gb) / total_gb * 100) if total_gb > 0 else 0

        return f"內存使用率: {used_percent:.1f}% (可用: {available_gb:.1f}GB/{total_gb:.1f}GB)"
    except (OSError, ValueError, IndexError):
        # A bare "except:" would also swallow KeyboardInterrupt/SystemExit;
        # only expected I/O and parse failures are handled here.
        return "內存信息: 無法獲取詳細信息"
 
92
 
93
def parse_sql_from_response(response_text: str) -> Optional[str]:
    """Extract an SQL statement from raw model output.

    Tries progressively looser strategies: a fenced ```sql block, any fenced
    block containing a SELECT, a semicolon-terminated SELECT, an
    unterminated SELECT (semicolon appended), and finally any single line
    beginning with SELECT. Returns None when nothing usable is found.
    """
    if not response_text:
        return None

    text = response_text.strip()

    # Strategy 1: explicit ```sql fenced block.
    fenced_sql = re.search(r"```sql\s*\n(.*?)\n```", text, re.DOTALL | re.IGNORECASE)
    if fenced_sql:
        return fenced_sql.group(1).strip()

    # Strategy 2: any fenced block whose content looks like a SELECT.
    fenced_any = re.search(r"```\s*\n?(.*?)\n?```", text, re.DOTALL)
    if fenced_any:
        candidate = fenced_any.group(1).strip()
        if candidate.upper().startswith('SELECT'):
            return candidate

    # Strategy 3: a semicolon-terminated SELECT statement anywhere in the text.
    terminated = re.search(r"(SELECT\s+.*?;)", text, re.DOTALL | re.IGNORECASE)
    if terminated:
        return terminated.group(1).strip()

    # Strategy 4: an unterminated SELECT; append the missing semicolon.
    unterminated = re.search(
        r"(SELECT\s+.*?)(?=\n\n|\n```|$|\n[^,\s])", text, re.DOTALL | re.IGNORECASE
    )
    if unterminated:
        candidate = unterminated.group(1).strip()
        return candidate if candidate.endswith(';') else candidate + ';'

    # Strategy 5: fall back to the first line that begins with SELECT.
    if 'SELECT' in text.upper():
        for raw_line in text.split('\n'):
            stripped = raw_line.strip()
            if stripped.upper().startswith('SELECT'):
                return stripped if stripped.endswith(';') else stripped + ';'

    return None
136
 
137
  # ==================== Text-to-SQL 核心類 ====================
 
140
  self.log_history = []
141
  self._log("初始化系統...")
142
  self.query_cache = {}
143
+ self.embed_device = DEVICE
144
+
145
+ # 檢查內存狀況
146
+ self._log(check_memory_usage())
147
 
148
  # 1. 載入嵌入模型
149
  self._log(f"載入嵌入模型: {embed_model_name}")
150
  self.embed_tokenizer = AutoTokenizer.from_pretrained(embed_model_name)
151
  self.embed_model = AutoModel.from_pretrained(embed_model_name)
152
+ try:
153
+ self.embed_model.to(self.embed_device)
154
+ self._log(f"嵌入模型設備: {self.embed_device}")
155
+ except Exception as e:
156
+ self._log(f"將嵌入模型移動到設備失敗: {e}", "WARNING")
157
+ self.embed_device = "cpu"
158
 
159
  # 2. 載入數據庫結構
160
  self.schema = self._load_schema()
 
162
  # 3. 載入數據集並建立索引
163
  self.dataset, self.faiss_index = self._load_and_index_dataset()
164
 
165
+ # 4. 載入 GGUF 模型(新增錯誤處理)
166
  self._load_gguf_model()
167
 
168
+ self._log("系統初始化完成")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
+ def _log(self, message: str, level: str = "INFO"):
171
+ self.log_history.append(format_log(message, level))
172
+ print(format_log(message, level))
173
 
174
  def _load_gguf_model(self):
175
+ """載入 GGUF 模型,針對 Paperspace 環境優化"""
176
  try:
177
+ self._log("開始下載 GGUF 模���到 /tmp...")
178
+
179
+ # 檢查模型是否已存在於 /tmp
180
+ model_cache_path = os.path.join(TEMP_DIR, GGUF_FILENAME)
181
+
182
+ if os.path.exists(model_cache_path) and self._validate_model_file(model_cache_path):
183
+ self._log(f"發現快取模型: {model_cache_path}")
184
+ model_path = model_cache_path
185
+ else:
186
+ self._log("下載新模型...")
187
+ model_path = hf_hub_download(
188
+ repo_id=GGUF_REPO_ID,
189
+ filename=GGUF_FILENAME,
190
+ repo_type="dataset",
191
+ cache_dir=TEMP_DIR,
192
+ resume_download=True
193
+ )
194
+ self._log(f"模型下載完成: {model_path}")
195
+
196
+ # 檢查內存情況
197
+ self._log(check_memory_usage())
198
 
199
+ # 使用 CPU 友好的參數載入模型(可選 GPU layers)
200
+ ngl = N_GPU_LAYERS if (DEVICE == "cuda" and N_GPU_LAYERS > 0) else 0
201
+ self._log(f"載入 GGUF 模型 (n_gpu_layers={ngl}, n_threads={THREADS}, n_ctx={CTX})...")
202
+ self.llm = Llama(
203
  model_path=model_path,
204
+ n_ctx=CTX, # 上下文長度(CPU 默認更小)
205
+ n_threads=THREADS, # 使用多執行緒
206
+ n_batch=256, # 批處理大小
207
+ verbose=False,
208
+ n_gpu_layers=ngl, # 可選 GPU 加速
209
+ use_mmap=True, # 使用內存映射減少內存占用
210
+ use_mlock=False, # 不鎖定內存
211
+ low_vram=True # 啟用低內存模式
212
  )
213
 
214
+ # 簡單測試模型
215
+ test_result = self.llm("SELECT", max_tokens=3)
216
+ self._log("GGUF 模型載入成功")
217
+
218
+ # 再次檢查內存
219
+ self._log(check_memory_usage())
220
 
221
  except Exception as e:
222
+ self._log(f"GGUF 載入失敗: {e}", "ERROR")
223
+ self._log("系統將無法生成 SQL。請檢查模型檔案或內存情況。", "CRITICAL")
224
  self.llm = None
225
 
226
+ def _validate_model_file(self, model_path):
227
+ """驗證模型檔案完整性"""
228
  try:
229
+ if not os.path.exists(model_path):
230
+ return False
 
 
 
231
 
232
+ # 檢查檔案大小(至少應該有幾百MB)
233
+ file_size = os.path.getsize(model_path)
234
+ if file_size < 50 * 1024 * 1024: # 小於 50MB 可能有問題
235
+ return False
 
 
 
236
 
237
+ # 檢查 GGUF 檔案頭部
238
+ with open(model_path, 'rb') as f:
239
+ header = f.read(8)
240
+ if not header.startswith(b'GGUF'):
241
+ return False
242
 
243
+ return True
244
+ except Exception:
245
  return False
246
 
247
+ def huggingface_api_call(self, prompt: str) -> str:
248
+ """調用 GGUF 模型,並加入詳細的原始輸出日誌"""
249
+ if self.llm is None:
250
+ self._log("模型未載入,返回 fallback SQL。", "ERROR")
251
+ return self._generate_fallback_sql(prompt)
252
+
253
  try:
254
+ # 清理垃圾收集
255
+ gc.collect()
 
 
 
 
 
 
 
 
 
 
 
256
 
257
+ output = self.llm(
258
+ prompt,
259
+ max_tokens=MAX_TOKENS, # 生成長度可配置
 
 
 
 
 
260
  temperature=0.1,
261
  top_p=0.9,
262
+ echo=False,
263
+ stop=["```", ";", "\n\n", "</s>"],
264
  )
265
 
266
+ self._log(f"模型原始輸出: {str(output)[:200]}...", "DEBUG")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
 
268
+ if output and "choices" in output and len(output["choices"]) > 0:
269
+ generated_text = output["choices"][0]["text"]
270
+ self._log(f"提取出的生成文本: {generated_text.strip()}", "DEBUG")
271
+ return generated_text.strip()
272
+ else:
273
+ self._log("模型的原始輸出格式不正確或為空。", "ERROR")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
  return ""
275
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
  except Exception as e:
277
+ self._log(f"模型生成過程中發生嚴重錯誤: {e}", "CRITICAL")
278
+ import traceback
279
+ self._log(traceback.format_exc(), "DEBUG")
280
+ return ""
 
 
281
 
282
  def _load_schema(self) -> Dict:
283
  """載入數據庫結構"""
 
285
  schema_path = hf_hub_download(
286
  repo_id=DATASET_REPO_ID,
287
  filename="sqlite_schema_FULL.json",
288
+ repo_type="dataset",
289
+ cache_dir=TEMP_DIR
290
  )
291
  with open(schema_path, "r", encoding="utf-8") as f:
292
  schema_data = json.load(f)
293
 
294
+ self._log(f"Schema 載入成功,包含 {len(schema_data)} 個表格:")
 
295
  for table_name, columns in schema_data.items():
296
  self._log(f" - {table_name}: {len(columns)} 個欄位")
 
 
 
297
 
298
+ self._log("數據庫結構載入完成")
299
  return schema_data
300
 
301
  except Exception as e:
302
+ self._log(f"載入 schema 失敗: {e}", "ERROR")
303
  return {}
304
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
305
  def _encode_texts(self, texts):
306
  """編碼文本為嵌入向量"""
307
  if isinstance(texts, str):
308
  texts = [texts]
 
309
  inputs = self.embed_tokenizer(texts, padding=True, truncation=True,
310
+ return_tensors="pt", max_length=512)
311
+ # 移動到對應設備
312
+ try:
313
+ inputs = {k: v.to(self.embed_device) for k, v in inputs.items()}
314
+ except Exception:
315
+ pass
316
 
317
  with torch.no_grad():
318
  outputs = self.embed_model(**inputs)
319
 
320
  # 使用平均池化
321
  embeddings = outputs.last_hidden_state.mean(dim=1)
322
+ return embeddings.detach().cpu()
323
 
324
  def _load_and_index_dataset(self):
325
  """載入數據集並建立 FAISS 索引"""
326
  try:
327
+ if not ENABLE_INDEX:
328
+ self._log("已禁用相似範例索引(ENABLE_INDEX=0)。啟動更快,將不使用 few-shot。")
329
+ return None, None
330
 
331
+ dataset = load_dataset(
332
+ DATASET_REPO_ID,
333
+ data_files="training_data.jsonl",
334
+ split="train",
335
+ cache_dir=TEMP_DIR
336
+ )
337
 
338
+ # 過濾不完整樣本
339
+ original_count = len(dataset)
340
  dataset = dataset.filter(
341
  lambda ex: isinstance(ex.get("messages"), list)
342
  and len(ex["messages"]) >= 2
 
346
  )
347
  )
348
 
349
+ self._log(f"資料集清理: 原始 {original_count} 筆, 過濾後 {len(dataset)} 筆")
 
 
 
350
 
351
  if len(dataset) == 0:
352
  self._log("清理後資料集為空,無法建立索引。", "ERROR")
 
355
  corpus = [item['messages'][0]['content'] for item in dataset]
356
  self._log(f"正在編碼 {len(corpus)} 個問題...")
357
 
358
+ # 批量編碼以節省內存
359
  embeddings_list = []
360
+ batch_size = EMBED_BATCH # 可配置的批次大小(CPU 預設更小)
361
 
362
  for i in range(0, len(corpus), batch_size):
363
  batch_texts = corpus[i:i+batch_size]
364
  batch_embeddings = self._encode_texts(batch_texts)
365
  embeddings_list.append(batch_embeddings)
366
+
367
+ # 清理內存
368
+ if i % (batch_size * 4) == 0:
369
+ gc.collect()
370
+
371
  self._log(f"已編碼 {min(i+batch_size, len(corpus))}/{len(corpus)}")
372
 
373
  all_embeddings = torch.cat(embeddings_list, dim=0).numpy()
 
376
  index = faiss.IndexFlatIP(all_embeddings.shape[1])
377
  index.add(all_embeddings.astype('float32'))
378
 
379
+ # 清理內存
380
+ del embeddings_list, all_embeddings
381
+ gc.collect()
382
+
383
+ self._log("向量索引建立完成")
384
  return dataset, index
385
 
386
  except Exception as e:
387
+ self._log(f"載入數據失敗: {e}", "ERROR")
388
  return None, None
389
 
390
  def _identify_relevant_tables(self, question: str) -> List[str]:
 
415
 
416
  return relevant_tables[:3] # 最多返回3個相關表格
417
 
 
 
418
  def _format_relevant_schema(self, table_names: List[str]) -> str:
419
+ """生成一個簡化的 Schema 字符串"""
 
 
420
  if not self.schema:
421
  return "No schema available.\n"
422
 
 
436
  formatted = ""
437
  for table in real_table_names:
438
  if table in self.schema:
 
439
  formatted += f"Table: {table}\n"
440
  cols_str = []
441
+ # 只顯示前 8 個關鍵欄位以節省內存
442
+ for col in self.schema[table][:8]:
443
  col_name = col['name']
444
  col_type = col['type']
445
+ cols_str.append(f"{col_name} ({col_type})")
 
 
 
 
 
446
  formatted += f"Columns: {', '.join(cols_str)}\n\n"
447
 
448
  return formatted.strip()
449
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
450
  def find_most_similar(self, question: str, top_k: int) -> List[Dict]:
451
  """使用 FAISS 快速檢索相似問題"""
452
  if self.faiss_index is None or self.dataset is None:
 
466
  if len(results) >= top_k:
467
  break
468
 
469
+ idx = int(idx)
470
+ if idx >= len(self.dataset):
 
 
471
  continue
472
 
473
  item = self.dataset[idx]
 
474
  if not isinstance(item.get('messages'), list) or len(item['messages']) < 2:
475
  continue
476
+
477
  q_content = (item['messages'][0].get('content') or '').strip()
478
  a_content = (item['messages'][1].get('content') or '').strip()
479
  if not q_content or not a_content:
 
496
  return results
497
 
498
  except Exception as e:
499
+ self._log(f"檢索失敗: {e}", "ERROR")
500
  return []
501
 
 
 
502
  def _build_prompt(self, user_q: str, examples: List[Dict]) -> str:
503
+ """建立簡化的提示詞"""
 
 
504
  relevant_tables = self._identify_relevant_tables(user_q)
 
 
505
  schema_str = self._format_relevant_schema(relevant_tables)
506
 
507
  example_str = "No example available."
 
509
  best_example = examples[0]
510
  example_str = f"Question: {best_example['question']}\nSQL:\n```sql\n{best_example['sql']}\n```"
511
 
512
+ # 簡化的 prompt,減少 token 使用
513
+ prompt = f"""### TASK ###
514
+ Generate SQLite query for the question below.
515
 
516
  ### SCHEMA ###
517
  {schema_str}
 
519
  ### EXAMPLE ###
520
  {example_str}
521
 
522
+ ### QUESTION ###
523
+ {user_q}
524
+
525
+ SQL:
526
  ```sql
527
  SELECT
528
  """
 
 
529
  return prompt
530
 
531
+ def _rule_based_sql(self, question: str) -> Optional[str]:
532
+ """規則先行:對常見查詢用模板直接生成 SQL,繞過 LLM。"""
533
+ q = (question or "").strip()
534
+ q_lower = q.lower()
535
+
536
+ # 兩年比較(完成數量、每月)
537
+ m = re.search(r"(20\d{2}).{0,6}(?:與|和|跟)\s*(20\d{2}).{0,10}(比較|對比).{0,10}(完成|報告|數量|件|工單)", q)
538
+ if m:
539
+ y1, y2 = m.group(1), m.group(2)
540
+ return (
541
+ "SELECT strftime('%Y-%m', completed_time) AS month, "
542
+ f"SUM(CASE WHEN strftime('%Y', completed_time)='{y1}' THEN 1 ELSE 0 END) AS count_{y1}, "
543
+ f"SUM(CASE WHEN strftime('%Y', completed_time)='{y2}' THEN 1 ELSE 0 END) AS count_{y2} "
544
+ "FROM jobtimeline "
545
+ f"WHERE strftime('%Y', completed_time) IN ('{y1}','{y2}') "
546
+ "GROUP BY month ORDER BY month;"
547
+ )
548
+
549
+ # 指定年份每月完成數量
550
+ m = re.search(r"(20\d{2})年.*每月.*(完成|報告|數量|件|工單)", q)
551
+ if m:
552
+ year = m.group(1)
553
+ return (
554
+ "SELECT strftime('%Y-%m', completed_time) AS month, COUNT(*) AS count "
555
+ "FROM jobtimeline "
556
+ f"WHERE strftime('%Y', completed_time)='{year}' "
557
+ "GROUP BY month ORDER BY month;"
558
+ )
559
+
560
+ # 評級分布(Pass/Fail)
561
+ if ("評級" in q) or ("pass" in q_lower) or ("fail" in q_lower):
562
+ return "SELECT rating, COUNT(*) AS count FROM tsr53sampledescription GROUP BY rating;"
563
+
564
+ # 金額最高 Top N(預設 10)
565
+ m = re.search(r"金額.*?(?:最高|前|top)\s*(\d+)?", q_lower)
566
+ if m:
567
+ n = m.group(1) or "10"
568
+ return f"SELECT * FROM tsr53invoice ORDER BY amount DESC LIMIT {n};"
569
+
570
+ # 客戶工作單數量最多 Top N
571
+ m = re.search(r"客戶.*?(?:最多|top|前)\s*(\d+)?", q_lower)
572
+ if m:
573
+ n = m.group(1) or "10"
574
+ return f"SELECT applicant, COUNT(*) AS count FROM tsr53sampledescription GROUP BY applicant ORDER BY count DESC LIMIT {n};"
575
+
576
+ # 昨天完成多少
577
+ if "昨天" in q:
578
+ return (
579
+ "SELECT COUNT(*) AS count FROM jobtimeline "
580
+ "WHERE date(completed_time)=date('now','-1 day');"
581
+ )
582
+
583
+ return None
584
+
585
+ def _finalize_sql(self, sql_text: str, status: str) -> Tuple[str, str]:
586
+ """最終整理 SQL:補分號、去除多餘空白並回傳 (sql, 狀態)。"""
587
+ try:
588
+ sql_clean = (sql_text or "").strip()
589
+ if sql_clean and not sql_clean.endswith(";"):
590
+ sql_clean += ";"
591
+ return sql_clean, status
592
+ except Exception as e:
593
+ self._log(f"最終整理 SQL 失敗: {e}", "ERROR")
594
+ return (sql_text or ""), status
595
+
596
+ def _validate_and_fix_sql(self, question: str, raw_response: str) -> Tuple[Optional[str], str]:
597
+ """
598
+ (V29 / 穩健正則 + 智能計數) 多層次 SQL 生成:
599
+ 1) 嘗試規則/模板動態組合
600
+ 2) 失敗則解析 AI 輸出並做方言/Schema 修正
601
+ 回傳: (sql 或 None, 狀態描述)
602
+ """
603
+ q = question or ""
604
+ q_lower = q.lower()
605
+
606
+ # 先嘗試內建的規則先行器
607
+ rb = self._rule_based_sql(q)
608
+ if rb:
609
+ self._log("_validate_and_fix_sql 命中規則模板")
610
+ return self._finalize_sql(rb, "規則生成")
611
+
612
+ # 統一實體識別(簡化版)
613
+ entity_match_data = None
614
+ entity_patterns = [
615
+ {'pattern': r"(買家|买家|buyer)\s*(?:id|代號|代碼|代号|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.BuyerID', 'type': '買家ID'},
616
+ {'pattern': r"(申請方|申请方|申請廠商|申请厂商|applicant)\s*(?:id|代號|代碼|代号|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.ApplicantID', 'type': '申請方ID'},
617
+ {'pattern': r"(付款方|付款厂商|invoiceto)\s*(?:id|代號|代碼|代号|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.InvoiceToID', 'type': '付款方ID'},
618
+ {'pattern': r"(代理商|agent)\s*(?:id|代號|代碼|代号|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.AgentID', 'type': '代理商ID'},
619
+ {'pattern': r"(買家|买家|buyer|客戶)\s+([a-zA-Z0-9&.-]+)", 'column': 'sd.BuyerName', 'type': '買家'},
620
+ {'pattern': r"(申請方|申请方|申請廠商|申请厂商|applicant)\s+([a-zA-Z0-9&.-]+)", 'column': 'sd.ApplicantName', 'type': '申請方'},
621
+ {'pattern': r"(付款方|付款厂商|invoiceto)\s+([a-zA-Z0-9&.-]+)", 'column': 'sd.InvoiceToName', 'type': '付款方'},
622
+ {'pattern': r"(代理商|agent)\s+([a-zA-Z0-9&.-]+)", 'column': 'sd.AgentName', 'type': '代理商'},
623
+ {'pattern': r"\b([A-Z]\d{4}[A-Z])\b", 'column': 'sd.ApplicantID', 'type': 'ID'}
624
+ ]
625
+ for p in entity_patterns:
626
+ m = re.search(p['pattern'], q, re.IGNORECASE)
627
+ if m:
628
+ entity_value = m.group(2) if len(m.groups()) > 1 else m.group(1)
629
+ entity_match_data = {"type": p['type'], "name": entity_value.strip().upper(), "column": p['column']}
630
+ break
631
+
632
+ # 模組化意圖偵測與動態 SQL 組合
633
+ intents: Dict[str, str] = {}
634
+ sql = {
635
+ 'select': [], 'from': '', 'joins': [], 'where': [],
636
+ 'group_by': [], 'order_by': [], 'log_parts': []
637
+ }
638
+
639
+ # 動作意圖:count / list
640
+ if any(kw in q_lower for kw in ['幾份', '多少', '數量', '總數', 'how many', 'count']):
641
+ intents['action'] = 'count'
642
+ if ("測試項目" in q) or ("test item" in q_lower):
643
+ sql['select'].append("COUNT(jip.ItemCode) AS item_count")
644
+ sql['log_parts'].append("測試項目總數")
645
+ else:
646
+ sql['select'].append("COUNT(DISTINCT jt.JobNo) AS report_count")
647
+ sql['log_parts'].append("報告總數")
648
+ elif any(kw in q_lower for kw in ['報告號碼', '報告清單', '列出報告', 'report number', 'list of reports']):
649
+ intents['action'] = 'list'
650
+ sql['select'].append("jt.JobNo, jt.ReportAuthorization")
651
+ sql['order_by'].append("jt.ReportAuthorization DESC")
652
+ sql['log_parts'].append("報告列表")
653
+
654
+ # 時間意圖:年/月
655
+ ym = re.search(r'(\d{4})\s*年?', q)
656
+ mm = re.search(r'(\d{1,2})\s*月', q)
657
+ if ym:
658
+ year = ym.group(1)
659
+ sql['where'].append(f"strftime('%Y', jt.ReportAuthorization) = '{year}'")
660
+ sql['log_parts'].append(f"{year}年")
661
+ if mm:
662
+ month = mm.group(1).zfill(2)
663
+ sql['where'].append(f"strftime('%m', jt.ReportAuthorization) = '{month}'")
664
+ sql['log_parts'].append(f"{month}月")
665
+
666
+ # 實體意圖
667
+ if entity_match_data:
668
+ if "TSR53SampleDescription" not in " ".join(sql['joins']):
669
+ sql['joins'].append("JOIN TSR53SampleDescription AS sd ON jt.JobNo = sd.JobNo")
670
+ entity_name, column_name = entity_match_data['name'], entity_match_data['column']
671
+ match_op = '=' if column_name.endswith('ID') else 'LIKE'
672
+ entity_val = f"'%{entity_name}%'" if match_op == 'LIKE' else f"'{entity_name}'"
673
+ sql['where'].append(f"{column_name} {match_op} {entity_val}")
674
+ sql['log_parts'].append(entity_match_data['type'] + ":" + entity_name)
675
+ if intents.get('action') == 'list':
676
+ sql['select'].append("sd.BuyerName")
677
+
678
+ # 評級意圖
679
+ if ('fail' in q_lower) or ('失敗' in q_lower):
680
+ if "TSR53SampleDescription" not in " ".join(sql['joins']):
681
+ sql['joins'].append("JOIN TSR53SampleDescription AS sd ON jt.JobNo = sd.JobNo")
682
+ sql['where'].append("sd.OverallRating = 'Fail'")
683
+ sql['log_parts'].append("Fail")
684
+ elif ('pass' in q_lower) or ('通過' in q_lower):
685
+ if "TSR53SampleDescription" not in " ".join(sql['joins']):
686
+ sql['joins'].append("JOIN TSR53SampleDescription AS sd ON jt.JobNo = sd.JobNo")
687
+ sql['where'].append("sd.OverallRating = 'Pass'")
688
+ sql['log_parts'].append("Pass")
689
+
690
+ # 實驗組 (LabGroup)
691
+ lab_group_mapping = {'A': 'TA', 'B': 'TB', 'C': 'TC', 'D': 'TD', 'E': 'TE', 'Y': 'TY'}
692
+ lgm = re.search(r'([A-Z]{1,2})組', q, re.IGNORECASE)
693
+ if lgm:
694
+ user_group = lgm.group(1).upper()
695
+ db_group = lab_group_mapping.get(user_group, user_group)
696
+ sql['joins'].append("JOIN JobItemsInProgress AS jip ON jt.JobNo = jip.JobNo")
697
+ sql['where'].append(f"jip.LabGroup = '{db_group}'")
698
+ sql['log_parts'].append(f"{user_group}組(->{db_group})")
699
+
700
+ # 若動作已決定,組裝模板 SQL
701
+ if 'action' in intents:
702
+ sql['from'] = "FROM JobTimeline AS jt"
703
+ if sql['where']:
704
+ sql['where'].insert(0, "jt.ReportAuthorization IS NOT NULL")
705
+ select_clause = "SELECT " + ", ".join(sorted(list(set(sql['select'])))) if sql['select'] else "SELECT *"
706
+ from_clause = sql['from']
707
+ joins_clause = " ".join(sql['joins'])
708
+ where_clause = ("WHERE " + " AND ".join(sql['where'])) if sql['where'] else ""
709
+ orderby_clause = ("ORDER BY " + ", ".join(sql['order_by'])) if sql['order_by'] else ""
710
+ template_sql = f"{select_clause} {from_clause} {joins_clause} {where_clause} {orderby_clause};"
711
+ query_log = " ".join(sql['log_parts'])
712
+ self._log(f"🔄 偵測到組合意圖【{query_log}】,啟用動態模板。")
713
+ return self._finalize_sql(template_sql, f"模板覆寫: {query_log} 查詢")
714
+
715
+ # 第二層:解析 AI 輸出並修正
716
+ self._log("未觸發任何模板,嘗試解析並修正 AI 輸出…")
717
+ parsed_sql = parse_sql_from_response(raw_response)
718
+ if not parsed_sql:
719
+ self._log(f"❌ 未能從模型回應中解析出任何 SQL。原始回應: {raw_response}", "ERROR")
720
+ return None, f"無法解析SQL。原始回應:\n{raw_response}"
721
+
722
+ self._log(f"📊 解析出的原始 SQL: {parsed_sql}", "DEBUG")
723
+ fixed_sql = " " + parsed_sql.strip() + " "
724
+ fixes_applied = []
725
+
726
+ # 方言修正
727
+ dialect_corrections = {r'YEAR\s*\(([^)]+)\)': r"strftime('%Y', \1)"}
728
+ for pat, rep in dialect_corrections.items():
729
+ if re.search(pat, fixed_sql, re.IGNORECASE):
730
+ fixed_sql = re.sub(pat, rep, fixed_sql, flags=re.IGNORECASE)
731
+ fixes_applied.append(f"修正方言: {pat}")
732
+
733
+ # Schema 名稱修正(常見別名 => 真實欄位)
734
+ schema_map = {
735
+ 'TSR53Report':'TSR53SampleDescription',
736
+ 'TSR53InvoiceReportNo':'JobNo',
737
+ 'TSR53ReportNo':'JobNo',
738
+ 'TSR53InvoiceNo':'JobNo',
739
+ 'TSR53InvoiceCreditNoteNo':'InvoiceCreditNoteNo',
740
+ 'TSR53InvoiceLocalAmount':'LocalAmount',
741
+ 'Status':'OverallRating',
742
+ 'ReportStatus':'OverallRating'
743
+ }
744
+ for wrong, correct in schema_map.items():
745
+ pat = r'\b' + re.escape(wrong) + r'\b'
746
+ if re.search(pat, fixed_sql, re.IGNORECASE):
747
+ fixed_sql = re.sub(pat, correct, fixed_sql, flags=re.IGNORECASE)
748
+ fixes_applied.append(f"映射 Schema: '{wrong}' -> '{correct}'")
749
+
750
+ status = "AI 生成並成功修正" if fixes_applied else "AI 生成且無需修正"
751
+ return self._finalize_sql(fixed_sql, status)
752
 
753
  def _generate_fallback_sql(self, prompt: str) -> str:
754
  """當模型不可用時的備用 SQL 生成"""
755
  prompt_lower = prompt.lower()
756
 
 
757
  if "統計" in prompt or "數量" in prompt or "多少" in prompt:
758
  if "月" in prompt:
759
  return "SELECT strftime('%Y-%m', completed_time) as month, COUNT(*) as count FROM jobtimeline GROUP BY month ORDER BY month;"
 
761
  return "SELECT applicant, COUNT(*) as count FROM tsr53sampledescription GROUP BY applicant ORDER BY count DESC;"
762
  else:
763
  return "SELECT COUNT(*) as total_count FROM jobtimeline WHERE completed_time IS NOT NULL;"
 
764
  elif "金額" in prompt or "總額" in prompt:
765
  return "SELECT SUM(amount) as total_amount FROM tsr53invoice;"
 
766
  elif "評級" in prompt or "pass" in prompt_lower or "fail" in prompt_lower:
767
  return "SELECT rating, COUNT(*) as count FROM tsr53sampledescription GROUP BY rating;"
 
768
  else:
769
  return "SELECT * FROM jobtimeline LIMIT 10;"
770
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
771
  def process_question(self, question: str) -> Tuple[str, str]:
772
+ """處理使用者問題"""
773
  # 檢查緩存
774
  if question in self.query_cache:
775
+ self._log("使用緩存結果")
776
  return self.query_cache[question]
777
 
778
  self.log_history = []
779
+ self._log(f"處理問題: {question}")
780
+ self._log(check_memory_usage())
781
+
782
+ # 0. 規則先行(命中則直接返回)
783
+ rb = self._rule_based_sql(question)
784
+ if rb:
785
+ self._log("規則命中,直接生成 SQL(跳過 LLM)")
786
+ self._log(f"最終 SQL: {rb}")
787
+ result = (rb, "規則生成")
788
+ self.query_cache[question] = result
789
+ gc.collect()
790
+ return result
791
+
792
+ # 1. 檢索相似範例
793
+ self._log("尋找相似範例...")
794
+ examples = self.find_most_similar(question, FEW_SHOT_EXAMPLES_COUNT)
795
+ if examples:
796
+ self._log(f"找到 {len(examples)} 個相似範例")
797
 
798
+ # 2. 建立提示詞
799
+ self._log("建立 Prompt...")
800
+ prompt = self._build_prompt(question, examples)
 
 
801
 
802
+ # 3. 生成 AI 回應
803
+ self._log("開始生成 AI 回應...")
804
+ response = self.huggingface_api_call(prompt)
805
 
806
+ # 4. 驗證/修正 SQL
807
+ fixed_sql, status_message = self._validate_and_fix_sql(question, response)
808
+ if not fixed_sql:
809
+ fixed_sql = "SELECT '未能生成有效的SQL,請嘗試換個問題描述';"
810
+ status_message = status_message or "生成失敗"
811
 
812
+ self._log(f"最終 SQL: {fixed_sql}")
813
+ result = (fixed_sql, status_message)
814
 
815
+ # 緩存結果
816
+ self.query_cache[question] = result
 
 
 
817
 
818
+ # 清理內存
819
+ gc.collect()
820
 
821
+ return result
 
 
 
822
 
823
+ # ==================== Gradio 介面與 API ====================
824
+ print("正在初始化 Text-to-SQL 系統...")
825
  text_to_sql_system = TextToSQLSystem()
826
 
827
def process_query(q: str, prompt_override: str = ""):
    """Gradio handler: map (question, optional prompt override) to a 3-tuple.

    Args:
        q: the user's natural-language question.
        prompt_override: optional; either a ready-made SQL statement
            (returned verbatim) or a complete prompt sent straight to the LLM.

    Returns:
        (sql, status, logs) — generated SQL, a status string, and the tail
        of the system's processing log.
    """
    # Guard against missing input. Strip each field separately: the old
    # `(q or prompt_override).strip()` short-circuited on a whitespace-only
    # question and wrongly rejected a valid prompt_override. Also None-safe.
    if not ((q or "").strip() or (prompt_override or "").strip()):
        return "", "等待輸入", "請輸入問題或提供 prompt_override"

    # Override path (used by desktop clients via the hidden textbox).
    if prompt_override and prompt_override.strip():
        po = prompt_override.strip()
        # If the override is already SQL, return it verbatim (terminated).
        if po.upper().startswith("SELECT"):
            if not po.endswith(";"):
                po += ";"
            text_to_sql_system._log("使用 prompt_override 直接回傳 SQL")
            logs = "\n".join(text_to_sql_system.log_history[-15:])
            return po, "override", logs
        # Otherwise treat the override as a complete prompt for the LLM.
        text_to_sql_system._log("使用 prompt_override 直接調用 LLM")
        response = text_to_sql_system.huggingface_api_call(po)
        fixed_sql, status_message = text_to_sql_system._validate_and_fix_sql(q or "", response)
        if not fixed_sql:
            fixed_sql = text_to_sql_system._generate_fallback_sql(po)
            status_message = status_message or "override 回退"
        text_to_sql_system._log(f"最終 SQL: {fixed_sql}")
        logs = "\n".join(text_to_sql_system.log_history[-15:])
        return fixed_sql, "override", logs

    # Normal path: full retrieval + generation pipeline.
    sql, status = text_to_sql_system.process_question(q)
    logs = "\n".join(text_to_sql_system.log_history[-15:])  # last 15 log lines
    return sql, status, logs
855
 
856
  # 範例問題
 
862
  "A組昨天完成了多少個測試項目?"
863
  ]
864
 
865
# Build the Gradio UI: question input + button on the left, generated SQL on
# the right, processing logs in a collapsible accordion.
with gr.Blocks(theme=gr.themes.Soft(), title="Text-to-SQL 智能助手 (HF Space)") as demo:
    gr.Markdown("# Text-to-SQL 智能助手 (Hugging Face Space)")
    gr.Markdown("輸入自然語言問題,自動生成SQL查詢語句。使用 /tmp 暫存,每次啟動重新下載模型。支援桌面端透過 /predict API 呼叫。")

    with gr.Row():
        with gr.Column(scale=2):
            inp = gr.Textbox(lines=3, label="您的問題", placeholder="例如:2024年每月完成多少份報告?")
            btn = gr.Button("生成 SQL", variant="primary")
            status = gr.Textbox(label="狀態", interactive=False)
            # Hidden prompt_override field so desktop clients can pass a raw
            # SQL statement or full prompt through the /predict API.
            prompt_override = gr.Textbox(label="prompt_override", visible=False)

        with gr.Column(scale=3):
            sql_out = gr.Code(label="生成的 SQL", language="sql", lines=8)

    with gr.Accordion("處理日誌", open=False):
        logs = gr.Textbox(lines=10, label="日誌", interactive=False)

    # Clickable example questions (the `examples` list is defined above).
    gr.Examples(
        examples=examples,
        inputs=inp,
        label="點擊試用範例問題"
    )

    # Wire events; the button click is also exposed as the /predict endpoint.
    btn.click(process_query, inputs=[inp, prompt_override], outputs=[sql_out, status, logs], api_name="/predict")
    inp.submit(process_query, inputs=[inp, prompt_override], outputs=[sql_out, status, logs])
893
 
894
if __name__ == "__main__":
    # Launch the Gradio server.
    # NOTE(review): share=True is typically ignored (with a warning) when
    # running on Hugging Face Spaces — confirm it is intended for local runs.
    demo.launch(
        server_name="0.0.0.0",  # listen on all interfaces (container-friendly)
        server_port=7860,       # standard HF Spaces port
        share=True,
        show_error=True
    )