Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -4,6 +4,9 @@ import re
|
|
| 4 |
import json
|
| 5 |
import torch
|
| 6 |
import numpy as np
|
|
|
|
|
|
|
|
|
|
| 7 |
from datetime import datetime
|
| 8 |
from datasets import load_dataset
|
| 9 |
from huggingface_hub import hf_hub_download
|
|
@@ -11,30 +14,51 @@ from llama_cpp import Llama
|
|
| 11 |
from typing import List, Dict, Tuple, Optional
|
| 12 |
import faiss
|
| 13 |
from functools import lru_cache
|
| 14 |
-
import re
|
| 15 |
|
| 16 |
# 使用 transformers 替代 sentence-transformers
|
| 17 |
from transformers import AutoModel, AutoTokenizer
|
| 18 |
import torch.nn.functional as F
|
| 19 |
|
| 20 |
-
# ====================
|
| 21 |
DATASET_REPO_ID = "Paul720810/Text-to-SQL-Softline"
|
| 22 |
GGUF_REPO_ID = "Paul720810/gguf-models"
|
| 23 |
-
GGUF_FILENAME = "qwen2
|
| 24 |
-
#GGUF_FILENAME = "qwen2.5-coder-1.5b-sql-finetuned.q8_0.gguf"
|
| 25 |
-
|
| 26 |
-
# 添加這一行:你的原始微調模型路徑
|
| 27 |
-
FINETUNED_MODEL_PATH = "Paul720810/qwen2.5-coder-1.5b-sql-finetuned" # ← 新增這行
|
| 28 |
|
| 29 |
-
FEW_SHOT_EXAMPLES_COUNT = 2
|
| 30 |
-
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 31 |
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
print("=" * 60)
|
| 34 |
-
print("
|
| 35 |
-
print(f"
|
| 36 |
-
print(f"
|
| 37 |
-
print(f"
|
|
|
|
|
|
|
| 38 |
print("=" * 60)
|
| 39 |
|
| 40 |
# ==================== 工具函數 ====================
|
|
@@ -44,57 +68,70 @@ def get_current_time():
|
|
| 44 |
def format_log(message: str, level: str = "INFO") -> str:
|
| 45 |
return f"[{get_current_time()}] [{level.upper()}] {message}"
|
| 46 |
|
| 47 |
-
def
|
| 48 |
-
"""
|
| 49 |
-
|
| 50 |
-
|
|
|
|
|
|
|
| 51 |
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
for block in code_blocks:
|
| 58 |
-
b = block.strip()
|
| 59 |
-
if 'select' in b.lower():
|
| 60 |
-
candidates.append(b)
|
| 61 |
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
candidates.append(m.group(0).strip())
|
| 67 |
|
| 68 |
-
|
|
|
|
|
|
|
| 69 |
return None
|
| 70 |
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
return None
|
| 99 |
|
| 100 |
# ==================== Text-to-SQL 核心類 ====================
|
|
@@ -103,15 +140,21 @@ class TextToSQLSystem:
|
|
| 103 |
self.log_history = []
|
| 104 |
self._log("初始化系統...")
|
| 105 |
self.query_cache = {}
|
| 106 |
-
self.
|
| 107 |
-
|
|
|
|
|
|
|
| 108 |
|
| 109 |
# 1. 載入嵌入模型
|
| 110 |
self._log(f"載入嵌入模型: {embed_model_name}")
|
| 111 |
self.embed_tokenizer = AutoTokenizer.from_pretrained(embed_model_name)
|
| 112 |
self.embed_model = AutoModel.from_pretrained(embed_model_name)
|
| 113 |
-
|
| 114 |
-
self.embed_model
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
|
| 116 |
# 2. 載入數據庫結構
|
| 117 |
self.schema = self._load_schema()
|
|
@@ -119,220 +162,122 @@ class TextToSQLSystem:
|
|
| 119 |
# 3. 載入數據集並建立索引
|
| 120 |
self.dataset, self.faiss_index = self._load_and_index_dataset()
|
| 121 |
|
| 122 |
-
# 4. 載入 GGUF
|
| 123 |
self._load_gguf_model()
|
| 124 |
|
| 125 |
-
self._log("
|
| 126 |
-
# 載入數據庫結構
|
| 127 |
-
self.schema = self._load_schema()
|
| 128 |
-
|
| 129 |
-
# 暫時添加:打印 schema 信息
|
| 130 |
-
if self.schema:
|
| 131 |
-
print("=" * 50)
|
| 132 |
-
print("數據庫 Schema 信息:")
|
| 133 |
-
for table_name, columns in self.schema.items():
|
| 134 |
-
print(f"\n表格: {table_name}")
|
| 135 |
-
print(f"欄位數: {len(columns)}")
|
| 136 |
-
print("欄位列表:")
|
| 137 |
-
for col in columns[:5]: # 只顯示前5個
|
| 138 |
-
print(f" - {col['name']} ({col['type']})")
|
| 139 |
-
print("=" * 50)
|
| 140 |
|
| 141 |
-
|
|
|
|
|
|
|
| 142 |
|
| 143 |
def _load_gguf_model(self):
|
| 144 |
-
"""載入 GGUF
|
| 145 |
try:
|
| 146 |
-
self._log("
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
|
| 153 |
-
#
|
| 154 |
-
|
|
|
|
|
|
|
| 155 |
model_path=model_path,
|
| 156 |
-
n_ctx=
|
| 157 |
-
n_threads=
|
| 158 |
-
n_batch=
|
| 159 |
-
verbose=False,
|
| 160 |
-
n_gpu_layers=
|
|
|
|
|
|
|
|
|
|
| 161 |
)
|
| 162 |
|
| 163 |
-
#
|
| 164 |
-
self.
|
| 165 |
-
self.
|
| 166 |
-
|
|
|
|
|
|
|
| 167 |
|
| 168 |
except Exception as e:
|
| 169 |
-
self._log(f"
|
| 170 |
-
self._log("系統將無法生成 SQL
|
| 171 |
self.llm = None
|
| 172 |
|
| 173 |
-
def
|
| 174 |
-
"""
|
| 175 |
try:
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
filename=GGUF_FILENAME,
|
| 179 |
-
repo_type="dataset"
|
| 180 |
-
)
|
| 181 |
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
verbose=False,
|
| 187 |
-
n_gpu_layers=0
|
| 188 |
-
)
|
| 189 |
|
| 190 |
-
#
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
|
|
|
| 194 |
|
| 195 |
-
|
| 196 |
-
|
| 197 |
return False
|
| 198 |
|
| 199 |
-
def
|
| 200 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
try:
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
self._log(f"載入 Transformers 模型: {FINETUNED_MODEL_PATH}")
|
| 206 |
-
|
| 207 |
-
# 載入你的微調模型
|
| 208 |
-
self.transformers_tokenizer = AutoTokenizer.from_pretrained(FINETUNED_MODEL_PATH)
|
| 209 |
-
self.transformers_model = AutoModelForCausalLM.from_pretrained(
|
| 210 |
-
FINETUNED_MODEL_PATH,
|
| 211 |
-
torch_dtype=torch.float32, # CPU 使用 float32
|
| 212 |
-
device_map="cpu", # 強制使用 CPU
|
| 213 |
-
trust_remote_code=True # Qwen 模型可能需要
|
| 214 |
-
)
|
| 215 |
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
model=self.transformers_model,
|
| 220 |
-
tokenizer=self.transformers_tokenizer,
|
| 221 |
-
device=-1, # CPU
|
| 222 |
-
max_length=512,
|
| 223 |
-
do_sample=True,
|
| 224 |
temperature=0.1,
|
| 225 |
top_p=0.9,
|
| 226 |
-
|
|
|
|
| 227 |
)
|
| 228 |
|
| 229 |
-
|
| 230 |
-
self.backend = "transformers"
|
| 231 |
-
self._log("✅ Transformers 模型載入成功")
|
| 232 |
-
|
| 233 |
-
except Exception as e:
|
| 234 |
-
self._log(f"❌ Transformers 載入也失敗: {e}", "ERROR")
|
| 235 |
-
|
| 236 |
-
def huggingface_api_call(self, prompt: str) -> str:
|
| 237 |
-
"""生成 SQL:優先使用 transformers,其次 gguf,最後 fallback"""
|
| 238 |
-
# transformers 後端
|
| 239 |
-
if self.backend == "transformers" and hasattr(self, "generation_pipeline"):
|
| 240 |
-
try:
|
| 241 |
-
gen = self.generation_pipeline(
|
| 242 |
-
prompt,
|
| 243 |
-
max_new_tokens=350,
|
| 244 |
-
do_sample=True,
|
| 245 |
-
temperature=0.05,
|
| 246 |
-
top_p=0.9
|
| 247 |
-
)
|
| 248 |
-
# 盡量從 pipeline 結果提取文字
|
| 249 |
-
generated_text = ""
|
| 250 |
-
try:
|
| 251 |
-
if isinstance(gen, list) and gen:
|
| 252 |
-
first = gen[0]
|
| 253 |
-
if isinstance(first, dict) and "generated_text" in first:
|
| 254 |
-
generated_text = str(first["generated_text"]) # type: ignore[index]
|
| 255 |
-
else:
|
| 256 |
-
generated_text = str(first)
|
| 257 |
-
else:
|
| 258 |
-
generated_text = str(gen)
|
| 259 |
-
except Exception:
|
| 260 |
-
generated_text = str(gen)
|
| 261 |
-
# 若包含 prompt,裁切前綴
|
| 262 |
-
if isinstance(generated_text, str) and generated_text.startswith(prompt):
|
| 263 |
-
generated_text = generated_text[len(prompt):]
|
| 264 |
-
self._log(f"📝 提取出的生成文本: {generated_text.strip()}", "DEBUG")
|
| 265 |
-
|
| 266 |
-
lines = generated_text.strip().split('\n')
|
| 267 |
-
non_comment_lines = [line for line in lines if not line.strip().startswith('--')]
|
| 268 |
-
cleaned_text = "\n".join(non_comment_lines).strip()
|
| 269 |
-
if cleaned_text != generated_text.strip():
|
| 270 |
-
self._log(f"🧹 清理掉註解後的文本: {cleaned_text}", "DEBUG")
|
| 271 |
-
if cleaned_text and not re.match(r"^\s*select\b", cleaned_text, flags=re.IGNORECASE):
|
| 272 |
-
self._log("⚙️ 補上缺失的 'SELECT ' 起手以形成完整查詢", "DEBUG")
|
| 273 |
-
cleaned_text = "SELECT " + cleaned_text.lstrip()
|
| 274 |
-
return cleaned_text
|
| 275 |
-
except Exception as e:
|
| 276 |
-
self._log(f"❌ Transformers 生成失敗: {e}", "ERROR")
|
| 277 |
-
return ""
|
| 278 |
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
temperature=0.05,
|
| 286 |
-
top_p=0.9,
|
| 287 |
-
echo=False,
|
| 288 |
-
stop=["```"]
|
| 289 |
-
)
|
| 290 |
-
self._log(f"🧠 模型原始輸出 (Raw Output): {output}", "DEBUG")
|
| 291 |
-
if output and "choices" in output and len(output["choices"]) > 0:
|
| 292 |
-
generated_text = output["choices"][0]["text"]
|
| 293 |
-
self._log(f"📝 提取出的生成文本: {generated_text.strip()}", "DEBUG")
|
| 294 |
-
lines = str(generated_text).strip().split('\n')
|
| 295 |
-
non_comment_lines = [line for line in lines if not line.strip().startswith('--')]
|
| 296 |
-
cleaned_text = "\n".join(non_comment_lines).strip()
|
| 297 |
-
if cleaned_text != str(generated_text).strip():
|
| 298 |
-
self._log(f"🧹 清理掉註解後的文本: {cleaned_text}", "DEBUG")
|
| 299 |
-
if cleaned_text and not re.match(r"^\s*select\b", cleaned_text, flags=re.IGNORECASE):
|
| 300 |
-
self._log("⚙️ 補上缺失的 'SELECT ' 起手以形成完整查詢", "DEBUG")
|
| 301 |
-
cleaned_text = "SELECT " + cleaned_text.lstrip()
|
| 302 |
-
return cleaned_text
|
| 303 |
-
else:
|
| 304 |
-
self._log("❌ 模型的原始輸出格式不正確或為空。", "ERROR")
|
| 305 |
-
return ""
|
| 306 |
-
except Exception as e:
|
| 307 |
-
self._log(f"❌ GGUF 生成失敗: {e}", "ERROR")
|
| 308 |
return ""
|
| 309 |
|
| 310 |
-
# 後備:都不可用時,回退
|
| 311 |
-
self._log("模型未載入或不可用,返回 fallback SQL。", "ERROR")
|
| 312 |
-
return self._generate_fallback_sql(prompt)
|
| 313 |
-
|
| 314 |
-
def _load_gguf_model_fallback(self, model_path):
|
| 315 |
-
"""備用載入方式"""
|
| 316 |
-
try:
|
| 317 |
-
# 嘗試不同的參數組合
|
| 318 |
-
self.gguf_llm = Llama(
|
| 319 |
-
model_path=model_path,
|
| 320 |
-
n_ctx=512, # 更小的上下文
|
| 321 |
-
n_threads=4,
|
| 322 |
-
n_batch=128,
|
| 323 |
-
vocab_only=False,
|
| 324 |
-
use_mmap=True,
|
| 325 |
-
use_mlock=False,
|
| 326 |
-
verbose=True
|
| 327 |
-
)
|
| 328 |
-
self._log("✅ 備用方式載入成功")
|
| 329 |
except Exception as e:
|
| 330 |
-
self._log(f"
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
self.log_history.append(format_log(message, level))
|
| 335 |
-
print(format_log(message, level))
|
| 336 |
|
| 337 |
def _load_schema(self) -> Dict:
|
| 338 |
"""載入數據庫結構"""
|
|
@@ -340,91 +285,58 @@ class TextToSQLSystem:
|
|
| 340 |
schema_path = hf_hub_download(
|
| 341 |
repo_id=DATASET_REPO_ID,
|
| 342 |
filename="sqlite_schema_FULL.json",
|
| 343 |
-
repo_type="dataset"
|
|
|
|
| 344 |
)
|
| 345 |
with open(schema_path, "r", encoding="utf-8") as f:
|
| 346 |
schema_data = json.load(f)
|
| 347 |
|
| 348 |
-
|
| 349 |
-
self._log(f"📊 Schema 載入成功,包含 {len(schema_data)} 個表格:")
|
| 350 |
for table_name, columns in schema_data.items():
|
| 351 |
self._log(f" - {table_name}: {len(columns)} 個欄位")
|
| 352 |
-
# 顯示前3個欄位作為範例
|
| 353 |
-
sample_cols = [col['name'] for col in columns[:3]]
|
| 354 |
-
self._log(f" 範例欄位: {', '.join(sample_cols)}")
|
| 355 |
|
| 356 |
-
self._log("
|
| 357 |
return schema_data
|
| 358 |
|
| 359 |
except Exception as e:
|
| 360 |
-
self._log(f"
|
| 361 |
return {}
|
| 362 |
|
| 363 |
-
# 也可以添加一個方法來檢查生成的 SQL 是否使用了正確的表格和欄位
|
| 364 |
-
def _analyze_sql_correctness(self, sql: str) -> Dict:
|
| 365 |
-
"""分析 SQL 的正確性"""
|
| 366 |
-
analysis = {
|
| 367 |
-
'valid_tables': [],
|
| 368 |
-
'invalid_tables': [],
|
| 369 |
-
'valid_columns': [],
|
| 370 |
-
'invalid_columns': [],
|
| 371 |
-
'suggestions': []
|
| 372 |
-
}
|
| 373 |
-
|
| 374 |
-
if not self.schema:
|
| 375 |
-
return analysis
|
| 376 |
-
|
| 377 |
-
# 提取 SQL 中的表格名稱
|
| 378 |
-
table_pattern = r'FROM\s+(\w+)|JOIN\s+(\w+)'
|
| 379 |
-
table_matches = re.findall(table_pattern, sql, re.IGNORECASE)
|
| 380 |
-
used_tables = [match[0] or match[1] for match in table_matches]
|
| 381 |
-
|
| 382 |
-
# 檢查表格是否存在
|
| 383 |
-
valid_tables = list(self.schema.keys())
|
| 384 |
-
for table in used_tables:
|
| 385 |
-
if table in valid_tables:
|
| 386 |
-
analysis['valid_tables'].append(table)
|
| 387 |
-
else:
|
| 388 |
-
analysis['invalid_tables'].append(table)
|
| 389 |
-
# 尋找相似的表格名稱
|
| 390 |
-
for valid_table in valid_tables:
|
| 391 |
-
if table.lower() in valid_table.lower() or valid_table.lower() in table.lower():
|
| 392 |
-
analysis['suggestions'].append(f"{table} -> {valid_table}")
|
| 393 |
-
|
| 394 |
-
# 提取欄位名稱(簡單版本)
|
| 395 |
-
column_pattern = r'SELECT\s+(.*?)\s+FROM|WHERE\s+(\w+)\s*[=<>]|GROUP BY\s+(\w+)|ORDER BY\s+(\w+)'
|
| 396 |
-
column_matches = re.findall(column_pattern, sql, re.IGNORECASE)
|
| 397 |
-
|
| 398 |
-
return analysis
|
| 399 |
-
|
| 400 |
def _encode_texts(self, texts):
|
| 401 |
"""編碼文本為嵌入向量"""
|
| 402 |
if isinstance(texts, str):
|
| 403 |
texts = [texts]
|
| 404 |
-
|
| 405 |
inputs = self.embed_tokenizer(texts, padding=True, truncation=True,
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
|
|
|
|
|
|
|
|
|
| 409 |
|
| 410 |
with torch.no_grad():
|
| 411 |
outputs = self.embed_model(**inputs)
|
| 412 |
|
| 413 |
# 使用平均池化
|
| 414 |
embeddings = outputs.last_hidden_state.mean(dim=1)
|
| 415 |
-
return embeddings.cpu()
|
| 416 |
|
| 417 |
def _load_and_index_dataset(self):
|
| 418 |
"""載入數據集並建立 FAISS 索引"""
|
| 419 |
try:
|
| 420 |
-
|
|
|
|
|
|
|
| 421 |
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
|
|
|
| 427 |
|
|
|
|
|
|
|
| 428 |
dataset = dataset.filter(
|
| 429 |
lambda ex: isinstance(ex.get("messages"), list)
|
| 430 |
and len(ex["messages"]) >= 2
|
|
@@ -434,10 +346,7 @@ class TextToSQLSystem:
|
|
| 434 |
)
|
| 435 |
)
|
| 436 |
|
| 437 |
-
|
| 438 |
-
self._log(
|
| 439 |
-
f"資料集清理: 原始 {original_count} 筆, 過濾後 {len(dataset)} 筆, 移除 {original_count - len(dataset)} 筆"
|
| 440 |
-
)
|
| 441 |
|
| 442 |
if len(dataset) == 0:
|
| 443 |
self._log("清理後資料集為空,無法建立索引。", "ERROR")
|
|
@@ -446,14 +355,19 @@ class TextToSQLSystem:
|
|
| 446 |
corpus = [item['messages'][0]['content'] for item in dataset]
|
| 447 |
self._log(f"正在編碼 {len(corpus)} 個問題...")
|
| 448 |
|
| 449 |
-
#
|
| 450 |
embeddings_list = []
|
| 451 |
-
batch_size =
|
| 452 |
|
| 453 |
for i in range(0, len(corpus), batch_size):
|
| 454 |
batch_texts = corpus[i:i+batch_size]
|
| 455 |
batch_embeddings = self._encode_texts(batch_texts)
|
| 456 |
embeddings_list.append(batch_embeddings)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 457 |
self._log(f"已編碼 {min(i+batch_size, len(corpus))}/{len(corpus)}")
|
| 458 |
|
| 459 |
all_embeddings = torch.cat(embeddings_list, dim=0).numpy()
|
|
@@ -462,11 +376,15 @@ class TextToSQLSystem:
|
|
| 462 |
index = faiss.IndexFlatIP(all_embeddings.shape[1])
|
| 463 |
index.add(all_embeddings.astype('float32'))
|
| 464 |
|
| 465 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 466 |
return dataset, index
|
| 467 |
|
| 468 |
except Exception as e:
|
| 469 |
-
self._log(f"
|
| 470 |
return None, None
|
| 471 |
|
| 472 |
def _identify_relevant_tables(self, question: str) -> List[str]:
|
|
@@ -497,12 +415,8 @@ class TextToSQLSystem:
|
|
| 497 |
|
| 498 |
return relevant_tables[:3] # 最多返回3個相關表格
|
| 499 |
|
| 500 |
-
# 請將這整個函數複製到您的 TextToSQLSystem class 內部
|
| 501 |
-
|
| 502 |
def _format_relevant_schema(self, table_names: List[str]) -> str:
|
| 503 |
-
"""
|
| 504 |
-
生成一個簡化的、不易被模型錯誤模仿的 Schema 字符串。
|
| 505 |
-
"""
|
| 506 |
if not self.schema:
|
| 507 |
return "No schema available.\n"
|
| 508 |
|
|
@@ -522,257 +436,17 @@ class TextToSQLSystem:
|
|
| 522 |
formatted = ""
|
| 523 |
for table in real_table_names:
|
| 524 |
if table in self.schema:
|
| 525 |
-
# 使用簡單的 "Table: ..." 和 "Columns: ..." 格式
|
| 526 |
formatted += f"Table: {table}\n"
|
| 527 |
cols_str = []
|
| 528 |
-
# 只顯示前
|
| 529 |
-
for col in self.schema[table][:
|
| 530 |
col_name = col['name']
|
| 531 |
col_type = col['type']
|
| 532 |
-
|
| 533 |
-
# 將描述信息放在括號裡
|
| 534 |
-
if col_desc:
|
| 535 |
-
cols_str.append(f"{col_name} ({col_type}, {col_desc})")
|
| 536 |
-
else:
|
| 537 |
-
cols_str.append(f"{col_name} ({col_type})")
|
| 538 |
formatted += f"Columns: {', '.join(cols_str)}\n\n"
|
| 539 |
|
| 540 |
return formatted.strip()
|
| 541 |
|
| 542 |
-
# 在 class TextToSQLSystem 內
|
| 543 |
-
|
| 544 |
-
def _validate_and_fix_sql(self, question: str, raw_response: str) -> Tuple[Optional[str], str]:
|
| 545 |
-
"""
|
| 546 |
-
(V29 / 穩健正則 + 智能計數 最終版)
|
| 547 |
-
一個多層次的SQL生成引擎。它優先使用基於規則的動態模板生成器,
|
| 548 |
-
如果無法匹配,則回退到解析和修正AI模型的輸出。
|
| 549 |
-
- 使用更簡潔、穩健的正則表達式來捕獲實體名稱。
|
| 550 |
-
- 根據問題是關於「報告」還是「測試項目」來智能地決定計數目標。
|
| 551 |
-
"""
|
| 552 |
-
q_lower = question.lower()
|
| 553 |
-
|
| 554 |
-
# ==============================================================================
|
| 555 |
-
# 第零層:統一實體識別引擎 (Unified Entity Recognition Engine)
|
| 556 |
-
# ==============================================================================
|
| 557 |
-
entity_match_data = None
|
| 558 |
-
# 包含了繁簡體兼容和更穩健的模式
|
| 559 |
-
entity_patterns = [
|
| 560 |
-
# 模式1: 匹配 "类型 + ID" - (保持不變)
|
| 561 |
-
{'pattern': r"(買家|买家|buyer)\s*(?:id|代號|代碼|代号|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.BuyerID', 'type': '買家ID'},
|
| 562 |
-
{'pattern': r"(申請方|申请方|申請廠商|申请厂商|applicant)\s*(?:id|代號|代碼|代号|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.ApplicantID', 'type': '申請方ID'},
|
| 563 |
-
{'pattern': r"(付款方|付款厂商|invoiceto)\s*(?:id|代號|代碼|代号|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.InvoiceToID', 'type': '付款方ID'},
|
| 564 |
-
{'pattern': r"(代理商|agent)\s*(?:id|代號|代碼|代号|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.AgentID', 'type': '代理商ID'},
|
| 565 |
-
|
| 566 |
-
# 模式2: 匹配 "類型 + 名稱" - (簡化了模式,使其更穩健)
|
| 567 |
-
{'pattern': r"(買家|买家|buyer|客戶)\s+([a-zA-Z0-9&.-]+)", 'column': 'sd.BuyerName', 'type': '買家'},
|
| 568 |
-
{'pattern': r"(申請方|申请方|申請廠商|申请厂商|applicant)\s+([a-zA-Z0-9&.-]+)", 'column': 'sd.ApplicantName', 'type': '申請方'},
|
| 569 |
-
{'pattern': r"(付款方|付款厂商|invoiceto)\s+([a-zA-Z0-9&.-]+)", 'column': 'sd.InvoiceToName', 'type': '付款方'},
|
| 570 |
-
{'pattern': r"(代理商|agent)\s+([a-zA-Z0-9&.-]+)", 'column': 'sd.AgentName', 'type': '代理商'},
|
| 571 |
-
|
| 572 |
-
# 模式3: 单独匹配一个 ID - (保持不變)
|
| 573 |
-
{'pattern': r"\b([A-Z]\d{4}[A-Z])\b", 'column': 'sd.ApplicantID', 'type': 'ID'}
|
| 574 |
-
]
|
| 575 |
-
|
| 576 |
-
for p in entity_patterns:
|
| 577 |
-
match = re.search(p['pattern'], question, re.IGNORECASE)
|
| 578 |
-
if match:
|
| 579 |
-
entity_value = match.group(2) if len(match.groups()) > 1 else match.group(1)
|
| 580 |
-
entity_match_data = {
|
| 581 |
-
"type": p['type'],
|
| 582 |
-
"name": entity_value.strip().upper(),
|
| 583 |
-
"column": p['column']
|
| 584 |
-
}
|
| 585 |
-
break
|
| 586 |
-
|
| 587 |
-
# ==============================================================================
|
| 588 |
-
# 第一層:模組化意圖偵測與動態SQL組合
|
| 589 |
-
# ==============================================================================
|
| 590 |
-
|
| 591 |
-
intents = {}
|
| 592 |
-
sql_components = {
|
| 593 |
-
'select': [], 'from': "", 'joins': [], 'where': [],
|
| 594 |
-
'group_by': [], 'order_by': [], 'log_parts': []
|
| 595 |
-
}
|
| 596 |
-
|
| 597 |
-
# --- 運行一系列獨立的意圖偵測器 ---
|
| 598 |
-
|
| 599 |
-
# 偵測器 2.1: 核心動作意圖
|
| 600 |
-
if any(kw in q_lower for kw in ['幾份', '多少', '數量', '總數', 'how many', 'count']):
|
| 601 |
-
intents['action'] = 'count'
|
| 602 |
-
# 智能決定計數目標
|
| 603 |
-
if "測試項目" in question or "test item" in q_lower:
|
| 604 |
-
sql_components['select'].append("COUNT(jip.ItemCode) AS item_count")
|
| 605 |
-
sql_components['log_parts'].append("測試項目總數")
|
| 606 |
-
else: # 預設是計數報告
|
| 607 |
-
sql_components['select'].append("COUNT(DISTINCT jt.JobNo) AS report_count")
|
| 608 |
-
sql_components['log_parts'].append("報告總數")
|
| 609 |
-
elif any(kw in q_lower for kw in ['報告號碼', '報告清單', '列出報告', 'report number', 'list of reports']):
|
| 610 |
-
intents['action'] = 'list'
|
| 611 |
-
sql_components['select'].append("jt.JobNo, jt.ReportAuthorization")
|
| 612 |
-
sql_components['order_by'].append("jt.ReportAuthorization DESC")
|
| 613 |
-
sql_components['log_parts'].append("報告列表")
|
| 614 |
-
|
| 615 |
-
# 偵測器 2.2: 時間意圖
|
| 616 |
-
year_match = re.search(r'(\d{4})\s*年?', question)
|
| 617 |
-
month_match = re.search(r'(\d{1,2})\s*月', question)
|
| 618 |
-
if year_match:
|
| 619 |
-
year = year_match.group(1)
|
| 620 |
-
sql_components['where'].append(f"strftime('%Y', jt.ReportAuthorization) = '{year}'")
|
| 621 |
-
sql_components['log_parts'].append(f"{year}年")
|
| 622 |
-
if month_match:
|
| 623 |
-
month = month_match.group(1).zfill(2)
|
| 624 |
-
sql_components['where'].append(f"strftime('%m', jt.ReportAuthorization) = '{month}'")
|
| 625 |
-
sql_components['log_parts'].append(f"{month}月")
|
| 626 |
-
|
| 627 |
-
# 偵測器 2.3: 實體意圖
|
| 628 |
-
if entity_match_data:
|
| 629 |
-
if "TSR53SampleDescription" not in " ".join(sql_components['joins']):
|
| 630 |
-
sql_components['joins'].append("JOIN TSR53SampleDescription AS sd ON jt.JobNo = sd.JobNo")
|
| 631 |
-
entity_name, column_name = entity_match_data["name"], entity_match_data["column"]
|
| 632 |
-
match_operator = "=" if column_name.endswith("ID") else "LIKE"
|
| 633 |
-
entity_value = f"'%{entity_name}%'" if match_operator == "LIKE" else f"'{entity_name}'"
|
| 634 |
-
sql_components['where'].append(f"{column_name} {match_operator} {entity_value}")
|
| 635 |
-
sql_components['log_parts'].append(entity_match_data["type"] + ":" + entity_name)
|
| 636 |
-
if intents.get('action') == 'list':
|
| 637 |
-
sql_components['select'].append("sd.BuyerName")
|
| 638 |
-
|
| 639 |
-
# 偵測器 2.4: 評級意圖
|
| 640 |
-
if 'fail' in q_lower or '失敗' in q_lower:
|
| 641 |
-
if "TSR53SampleDescription" not in " ".join(sql_components['joins']):
|
| 642 |
-
sql_components['joins'].append("JOIN TSR53SampleDescription AS sd ON jt.JobNo = sd.JobNo")
|
| 643 |
-
sql_components['where'].append("sd.OverallRating = 'Fail'")
|
| 644 |
-
sql_components['log_parts'].append("Fail")
|
| 645 |
-
elif 'pass' in q_lower or '通過' in q_lower:
|
| 646 |
-
if "TSR53SampleDescription" not in " ".join(sql_components['joins']):
|
| 647 |
-
sql_components['joins'].append("JOIN TSR53SampleDescription AS sd ON jt.JobNo = sd.JobNo")
|
| 648 |
-
sql_components['where'].append("sd.OverallRating = 'Pass'")
|
| 649 |
-
sql_components['log_parts'].append("Pass")
|
| 650 |
-
|
| 651 |
-
# 偵測器 2.5: 實驗組 (LabGroup) 意圖 (帶有別名映射)
|
| 652 |
-
lab_group_mapping = {'A': 'TA', 'B': 'TB', 'C': 'TC', 'D': 'TD', 'E': 'TE', 'Y': 'TY'}
|
| 653 |
-
lab_group_match = re.search(r'([A-Z]{1,2})組', question, re.IGNORECASE)
|
| 654 |
-
if lab_group_match:
|
| 655 |
-
user_input_group = lab_group_match.group(1).upper()
|
| 656 |
-
db_lab_group = lab_group_mapping.get(user_input_group, user_input_group)
|
| 657 |
-
sql_components['joins'].append("JOIN JobItemsInProgress AS jip ON jt.JobNo = jip.JobNo")
|
| 658 |
-
sql_components['where'].append(f"jip.LabGroup = '{db_lab_group}'")
|
| 659 |
-
sql_components['log_parts'].append(f"{user_input_group}組(->{db_lab_group})")
|
| 660 |
-
|
| 661 |
-
# --- 2.6: 兩年份比較模板(優先級:高) ---
|
| 662 |
-
# 偵測『比較/vs/對比/相較/相比』字樣,擷取兩個年份與(可選)買家名稱
|
| 663 |
-
compare_hit = any(kw in q_lower for kw in ["比較", "對比", "相較", "相比", "vs", "versus"])
|
| 664 |
-
years_found = re.findall(r"(20\d{2})", question)
|
| 665 |
-
years_unique = []
|
| 666 |
-
for y in years_found:
|
| 667 |
-
if y not in years_unique:
|
| 668 |
-
years_unique.append(y)
|
| 669 |
-
if compare_hit and len(years_unique) >= 2:
|
| 670 |
-
year_a, year_b = years_unique[0], years_unique[1]
|
| 671 |
-
# 嘗試抓買家名稱(英文/數字/符號),若沒有則不加 buyer 條件
|
| 672 |
-
buyer_name = None
|
| 673 |
-
# 1) 優先解析明確條件:BuyerName LIKE '%...%'
|
| 674 |
-
m_like = re.search(r"BuyerName\s+LIKE\s*'%([^']+)%'", question, re.IGNORECASE)
|
| 675 |
-
if m_like:
|
| 676 |
-
buyer_name = m_like.group(1).strip()
|
| 677 |
-
else:
|
| 678 |
-
# 2) 解析自然語言:避免 'BuyerName' 被誤判成 'buyer'
|
| 679 |
-
buyer_match = re.search(r"(?:買家|买家|客戶|客户|\bbuyer\b(?!name))\s*[::]?\s*([A-Za-z0-9&.\- ]+)", question, re.IGNORECASE)
|
| 680 |
-
if buyer_match:
|
| 681 |
-
buyer_name = buyer_match.group(1).strip()
|
| 682 |
-
|
| 683 |
-
# 判斷偏向金額或件數
|
| 684 |
-
amount_intent = any(kw in q_lower for kw in ["金額", "金钱", "amount", "營收", "業績", "營業額", "銷售額", "revenue"])
|
| 685 |
-
|
| 686 |
-
if amount_intent:
|
| 687 |
-
# 金額版:需要發票表,依架構命名使用 TSR53Invoice 與 LocalAmount;與樣本描述以 JobNo 關聯
|
| 688 |
-
sql = (
|
| 689 |
-
"SELECT strftime('%Y', jt.ReportAuthorization) AS year, "
|
| 690 |
-
"SUM(COALESCE(inv.LocalAmount, 0)) AS total_amount "
|
| 691 |
-
"FROM JobTimeline AS jt "
|
| 692 |
-
"JOIN TSR53SampleDescription AS sd ON sd.JobNo = jt.JobNo "
|
| 693 |
-
"LEFT JOIN TSR53Invoice AS inv ON inv.JobNo = jt.JobNo "
|
| 694 |
-
"WHERE jt.ReportAuthorization IS NOT NULL "
|
| 695 |
-
f"AND strftime('%Y', jt.ReportAuthorization) IN ('{year_a}', '{year_b}') "
|
| 696 |
-
)
|
| 697 |
-
if buyer_name:
|
| 698 |
-
sql += f"AND sd.BuyerName LIKE '%{buyer_name}%' "
|
| 699 |
-
sql += "GROUP BY year ORDER BY year;"
|
| 700 |
-
return self._finalize_sql(sql, f"模板覆寫: 兩年份金額比較 {year_a} vs {year_b}" )
|
| 701 |
-
else:
|
| 702 |
-
# 件數版:以報告數量為主,去重 JobNo
|
| 703 |
-
sql = (
|
| 704 |
-
"SELECT strftime('%Y', jt.ReportAuthorization) AS year, "
|
| 705 |
-
"COUNT(DISTINCT jt.JobNo) AS report_count "
|
| 706 |
-
"FROM JobTimeline AS jt "
|
| 707 |
-
"JOIN TSR53SampleDescription AS sd ON sd.JobNo = jt.JobNo "
|
| 708 |
-
"WHERE jt.ReportAuthorization IS NOT NULL "
|
| 709 |
-
f"AND strftime('%Y', jt.ReportAuthorization) IN ('{year_a}', '{year_b}') "
|
| 710 |
-
)
|
| 711 |
-
if buyer_name:
|
| 712 |
-
sql += f"AND sd.BuyerName LIKE '%{buyer_name}%' "
|
| 713 |
-
sql += "GROUP BY year ORDER BY year;"
|
| 714 |
-
return self._finalize_sql(sql, f"模板覆寫: 兩年份件數比較 {year_a} vs {year_b}" )
|
| 715 |
-
|
| 716 |
-
# --- 3. 判斷是否觸發了模板,並動態組合 SQL ---
|
| 717 |
-
if 'action' in intents:
|
| 718 |
-
sql_components['from'] = "FROM JobTimeline AS jt"
|
| 719 |
-
# 只要有任何篩選條件,就加上報告已授權的基礎限制
|
| 720 |
-
if sql_components['where']:
|
| 721 |
-
sql_components['where'].insert(0, "jt.ReportAuthorization IS NOT NULL")
|
| 722 |
-
|
| 723 |
-
select_clause = "SELECT " + ", ".join(sorted(list(set(sql_components['select']))))
|
| 724 |
-
from_clause = sql_components['from']
|
| 725 |
-
joins_clause = " ".join(sql_components['joins'])
|
| 726 |
-
where_clause = "WHERE " + " AND ".join(sql_components['where']) if sql_components['where'] else ""
|
| 727 |
-
orderby_clause = "ORDER BY " + ", ".join(sql_components['order_by']) if sql_components['order_by'] else ""
|
| 728 |
-
|
| 729 |
-
template_sql = f"{select_clause} {from_clause} {joins_clause} {where_clause} {orderby_clause};"
|
| 730 |
-
|
| 731 |
-
query_log = " ".join(sql_components['log_parts'])
|
| 732 |
-
self._log(f"🔄 偵測到組合意圖【{query_log}】,啟用動態模板。", "INFO")
|
| 733 |
-
return self._finalize_sql(template_sql, f"模板覆寫: {query_log} 查詢")
|
| 734 |
-
|
| 735 |
-
# ==============================================================================
|
| 736 |
-
# 第二层:AI 生成修正流程 (Fallback)
|
| 737 |
-
# ==============================================================================
|
| 738 |
-
self._log("未觸發任何模板,嘗試解析並修正 AI 輸出...", "INFO")
|
| 739 |
-
|
| 740 |
-
parsed_sql = parse_sql_from_response(raw_response)
|
| 741 |
-
if not parsed_sql:
|
| 742 |
-
self._log(f"❌ 未能從模型回應中解析出任何 SQL。原始回應: {raw_response}", "ERROR")
|
| 743 |
-
return None, f"無法解析SQL。原始回應:\n{raw_response}"
|
| 744 |
-
|
| 745 |
-
self._log(f"📊 解析出的原始 SQL: {parsed_sql}", "DEBUG")
|
| 746 |
-
|
| 747 |
-
fixed_sql = " " + parsed_sql.strip() + " "
|
| 748 |
-
fixes_applied_fallback = []
|
| 749 |
-
|
| 750 |
-
dialect_corrections = {r'YEAR\s*\(([^)]+)\)': r"strftime('%Y', \1)"}
|
| 751 |
-
for pattern, replacement in dialect_corrections.items():
|
| 752 |
-
if re.search(pattern, fixed_sql, re.IGNORECASE):
|
| 753 |
-
fixed_sql = re.sub(pattern, replacement, fixed_sql, flags=re.IGNORECASE)
|
| 754 |
-
fixes_applied_fallback.append(f"修正方言: {pattern}")
|
| 755 |
-
|
| 756 |
-
schema_corrections = {'TSR53Report':'TSR53SampleDescription', 'TSR53InvoiceReportNo':'JobNo', 'TSR53ReportNo':'JobNo', 'TSR53InvoiceNo':'JobNo', 'TSR53InvoiceCreditNoteNo':'InvoiceCreditNoteNo', 'TSR53InvoiceLocalAmount':'LocalAmount', 'Status':'OverallRating', 'ReportStatus':'OverallRating'}
|
| 757 |
-
for wrong, correct in schema_corrections.items():
|
| 758 |
-
pattern = r'\b' + re.escape(wrong) + r'\b'
|
| 759 |
-
if re.search(pattern, fixed_sql, re.IGNORECASE):
|
| 760 |
-
fixed_sql = re.sub(pattern, correct, fixed_sql, flags=re.IGNORECASE)
|
| 761 |
-
fixes_applied_fallback.append(f"映射 Schema: '{wrong}' -> '{correct}'")
|
| 762 |
-
|
| 763 |
-
log_msg = "AI 生成並成功修正" if fixes_applied_fallback else "AI 生成且無需修正"
|
| 764 |
-
return self._finalize_sql(fixed_sql, log_msg)
|
| 765 |
-
|
| 766 |
-
def _finalize_sql(self, sql: str, log_message: str) -> Tuple[str, str]:
|
| 767 |
-
"""一個輔助函數,用於清理最終的SQL並記錄成功日誌。"""
|
| 768 |
-
final_sql = sql.strip()
|
| 769 |
-
if not final_sql.endswith(';'):
|
| 770 |
-
final_sql += ';'
|
| 771 |
-
final_sql = re.sub(r'\s+', ' ', final_sql).strip()
|
| 772 |
-
self._log(f"✅ SQL 已生成 ({log_message})", "INFO")
|
| 773 |
-
self._log(f" - 最終 SQL: {final_sql}", "DEBUG")
|
| 774 |
-
return final_sql, "生成成功"
|
| 775 |
-
|
| 776 |
def find_most_similar(self, question: str, top_k: int) -> List[Dict]:
|
| 777 |
"""使用 FAISS 快速檢索相似問題"""
|
| 778 |
if self.faiss_index is None or self.dataset is None:
|
|
@@ -792,16 +466,14 @@ class TextToSQLSystem:
|
|
| 792 |
if len(results) >= top_k:
|
| 793 |
break
|
| 794 |
|
| 795 |
-
|
| 796 |
-
idx
|
| 797 |
-
|
| 798 |
-
if idx >= len(self.dataset): # 確保索引有效
|
| 799 |
continue
|
| 800 |
|
| 801 |
item = self.dataset[idx]
|
| 802 |
-
# 防呆:若樣本不完整則跳過
|
| 803 |
if not isinstance(item.get('messages'), list) or len(item['messages']) < 2:
|
| 804 |
continue
|
|
|
|
| 805 |
q_content = (item['messages'][0].get('content') or '').strip()
|
| 806 |
a_content = (item['messages'][1].get('content') or '').strip()
|
| 807 |
if not q_content or not a_content:
|
|
@@ -824,18 +496,12 @@ class TextToSQLSystem:
|
|
| 824 |
return results
|
| 825 |
|
| 826 |
except Exception as e:
|
| 827 |
-
self._log(f"
|
| 828 |
return []
|
| 829 |
|
| 830 |
-
# in class TextToSQLSystem:
|
| 831 |
-
|
| 832 |
def _build_prompt(self, user_q: str, examples: List[Dict]) -> str:
|
| 833 |
-
"""
|
| 834 |
-
建立一個高度結構化、以任務為導向的提示詞,使用清晰的標題分隔符。
|
| 835 |
-
"""
|
| 836 |
relevant_tables = self._identify_relevant_tables(user_q)
|
| 837 |
-
|
| 838 |
-
# 使用我們新的、更簡單的 schema 格式化函數
|
| 839 |
schema_str = self._format_relevant_schema(relevant_tables)
|
| 840 |
|
| 841 |
example_str = "No example available."
|
|
@@ -843,8 +509,9 @@ class TextToSQLSystem:
|
|
| 843 |
best_example = examples[0]
|
| 844 |
example_str = f"Question: {best_example['question']}\nSQL:\n```sql\n{best_example['sql']}\n```"
|
| 845 |
|
| 846 |
-
#
|
| 847 |
-
prompt = f"""
|
|
|
|
| 848 |
|
| 849 |
### SCHEMA ###
|
| 850 |
{schema_str}
|
|
@@ -852,22 +519,241 @@ class TextToSQLSystem:
|
|
| 852 |
### EXAMPLE ###
|
| 853 |
{example_str}
|
| 854 |
|
| 855 |
-
###
|
| 856 |
-
|
| 857 |
-
|
|
|
|
| 858 |
```sql
|
| 859 |
SELECT
|
| 860 |
"""
|
| 861 |
-
self._log(f"📏 Prompt 長度: {len(prompt)} 字符")
|
| 862 |
-
# 不再需要複雜的長度截斷邏輯,因為 schema 已經被簡化
|
| 863 |
return prompt
|
| 864 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 865 |
|
| 866 |
def _generate_fallback_sql(self, prompt: str) -> str:
|
| 867 |
"""當模型不可用時的備用 SQL 生成"""
|
| 868 |
prompt_lower = prompt.lower()
|
| 869 |
|
| 870 |
-
# 簡單的關鍵詞匹配生成基本 SQL
|
| 871 |
if "統計" in prompt or "數量" in prompt or "多少" in prompt:
|
| 872 |
if "月" in prompt:
|
| 873 |
return "SELECT strftime('%Y-%m', completed_time) as month, COUNT(*) as count FROM jobtimeline GROUP BY month ORDER BY month;"
|
|
@@ -875,100 +761,96 @@ SELECT
|
|
| 875 |
return "SELECT applicant, COUNT(*) as count FROM tsr53sampledescription GROUP BY applicant ORDER BY count DESC;"
|
| 876 |
else:
|
| 877 |
return "SELECT COUNT(*) as total_count FROM jobtimeline WHERE completed_time IS NOT NULL;"
|
| 878 |
-
|
| 879 |
elif "金額" in prompt or "總額" in prompt:
|
| 880 |
return "SELECT SUM(amount) as total_amount FROM tsr53invoice;"
|
| 881 |
-
|
| 882 |
elif "評級" in prompt or "pass" in prompt_lower or "fail" in prompt_lower:
|
| 883 |
return "SELECT rating, COUNT(*) as count FROM tsr53sampledescription GROUP BY rating;"
|
| 884 |
-
|
| 885 |
else:
|
| 886 |
return "SELECT * FROM jobtimeline LIMIT 10;"
|
| 887 |
|
| 888 |
-
def _validate_model_file(self, model_path):
|
| 889 |
-
"""驗證模型檔案完整性"""
|
| 890 |
-
try:
|
| 891 |
-
if not os.path.exists(model_path):
|
| 892 |
-
return False
|
| 893 |
-
|
| 894 |
-
# 檢查檔案大小(至少應該有幾MB)
|
| 895 |
-
file_size = os.path.getsize(model_path)
|
| 896 |
-
if file_size < 10 * 1024 * 1024: # 小於 10MB 可能有問題
|
| 897 |
-
return False
|
| 898 |
-
|
| 899 |
-
# 檢查 GGUF 檔案頭部
|
| 900 |
-
with open(model_path, 'rb') as f:
|
| 901 |
-
header = f.read(8)
|
| 902 |
-
if not header.startswith(b'GGUF'):
|
| 903 |
-
return False
|
| 904 |
-
|
| 905 |
-
return True
|
| 906 |
-
except Exception:
|
| 907 |
-
return False
|
| 908 |
-
|
| 909 |
-
# in class TextToSQLSystem:
|
| 910 |
-
|
| 911 |
def process_question(self, question: str) -> Tuple[str, str]:
|
| 912 |
-
"""處理使用者問題
|
| 913 |
# 檢查緩存
|
| 914 |
if question in self.query_cache:
|
| 915 |
-
self._log("
|
| 916 |
return self.query_cache[question]
|
| 917 |
|
| 918 |
self.log_history = []
|
| 919 |
-
self._log(f"
|
| 920 |
-
|
| 921 |
-
|
| 922 |
-
|
| 923 |
-
|
| 924 |
-
|
| 925 |
-
|
| 926 |
-
|
| 927 |
-
|
| 928 |
-
|
| 929 |
-
|
| 930 |
-
|
| 931 |
-
|
| 932 |
-
|
| 933 |
-
|
|
|
|
|
|
|
|
|
|
| 934 |
|
| 935 |
-
|
| 936 |
-
|
| 937 |
-
|
| 938 |
-
# 將原本 prompt 的結尾替換成我們的修正指令
|
| 939 |
-
prompt = prompt.rsplit("SQL:\n```sql", 1)[0] + correction_prompt
|
| 940 |
|
|
|
|
|
|
|
|
|
|
| 941 |
|
| 942 |
-
|
| 943 |
-
|
| 944 |
-
|
|
|
|
|
|
|
| 945 |
|
| 946 |
-
|
| 947 |
-
|
| 948 |
|
| 949 |
-
|
| 950 |
-
|
| 951 |
-
result = (final_sql, status_message)
|
| 952 |
-
self.query_cache[question] = result # 緩存成功結果
|
| 953 |
-
return result
|
| 954 |
|
| 955 |
-
|
|
|
|
| 956 |
|
| 957 |
-
|
| 958 |
-
self._log("❌ 所有嘗試均失敗,返回錯誤訊息。", "ERROR")
|
| 959 |
-
final_fallback_message = "模型多次嘗試後仍無法生成有效的SQL。"
|
| 960 |
-
return (final_fallback_message, "生成失敗")
|
| 961 |
|
| 962 |
-
# ==================== Gradio
|
|
|
|
| 963 |
text_to_sql_system = TextToSQLSystem()
|
| 964 |
|
| 965 |
-
def process_query(q: str):
|
| 966 |
-
if not q.strip():
|
| 967 |
-
return "", "等待輸入", "
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 968 |
|
| 969 |
sql, status = text_to_sql_system.process_question(q)
|
| 970 |
-
logs = "\n".join(text_to_sql_system.log_history[-
|
| 971 |
-
|
| 972 |
return sql, status, logs
|
| 973 |
|
| 974 |
# 範例問題
|
|
@@ -980,36 +862,39 @@ examples = [
|
|
| 980 |
"A組昨天完成了多少個測試項目?"
|
| 981 |
]
|
| 982 |
|
| 983 |
-
with gr.Blocks(theme=gr.themes.Soft(), title="Text-to-SQL 智能助手") as demo:
|
| 984 |
-
gr.Markdown("#
|
| 985 |
-
gr.Markdown("輸入自然語言問題,自動生成SQL
|
| 986 |
|
| 987 |
with gr.Row():
|
| 988 |
with gr.Column(scale=2):
|
| 989 |
-
inp = gr.Textbox(lines=3, label="
|
| 990 |
-
btn = gr.Button("
|
| 991 |
status = gr.Textbox(label="狀態", interactive=False)
|
|
|
|
|
|
|
| 992 |
|
| 993 |
with gr.Column(scale=3):
|
| 994 |
-
sql_out = gr.Code(label="
|
| 995 |
|
| 996 |
-
with gr.Accordion("
|
| 997 |
-
logs = gr.Textbox(lines=
|
| 998 |
|
| 999 |
# 範例區
|
| 1000 |
gr.Examples(
|
| 1001 |
examples=examples,
|
| 1002 |
inputs=inp,
|
| 1003 |
-
label="
|
| 1004 |
)
|
| 1005 |
|
| 1006 |
# 綁定事件
|
| 1007 |
-
btn.click(process_query, inputs=[inp], outputs=[sql_out, status, logs])
|
| 1008 |
-
inp.submit(process_query, inputs=[inp], outputs=[sql_out, status, logs])
|
| 1009 |
|
| 1010 |
if __name__ == "__main__":
|
| 1011 |
demo.launch(
|
| 1012 |
server_name="0.0.0.0",
|
| 1013 |
server_port=7860,
|
| 1014 |
-
share=
|
|
|
|
| 1015 |
)
|
|
|
|
| 4 |
import json
|
| 5 |
import torch
|
| 6 |
import numpy as np
|
| 7 |
+
import psutil
|
| 8 |
+
import gc
|
| 9 |
+
import tempfile
|
| 10 |
from datetime import datetime
|
| 11 |
from datasets import load_dataset
|
| 12 |
from huggingface_hub import hf_hub_download
|
|
|
|
| 14 |
from typing import List, Dict, Tuple, Optional
|
| 15 |
import faiss
|
| 16 |
from functools import lru_cache
|
|
|
|
| 17 |
|
| 18 |
# 使用 transformers 替代 sentence-transformers
|
| 19 |
from transformers import AutoModel, AutoTokenizer
|
| 20 |
import torch.nn.functional as F
|
| 21 |
|
| 22 |
+
# ==================== 配置參數 ====================
|
| 23 |
DATASET_REPO_ID = "Paul720810/Text-to-SQL-Softline"
|
| 24 |
GGUF_REPO_ID = "Paul720810/gguf-models"
|
| 25 |
+
GGUF_FILENAME = "qwen2-7b-instruct-sql-finetuned-stable.q4_k_m.gguf"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
|
|
|
|
|
|
| 27 |
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
|
| 28 |
|
| 29 |
+
# 可配置 GPU(HF 免費方案通常只有 CPU)
|
| 30 |
+
USE_GPU = str(os.getenv("USE_GPU", "0")).lower() in {"1", "true", "yes", "y"}
|
| 31 |
+
try:
|
| 32 |
+
N_GPU_LAYERS = int(os.getenv("N_GPU_LAYERS", "0"))
|
| 33 |
+
except Exception:
|
| 34 |
+
N_GPU_LAYERS = 0
|
| 35 |
+
DEVICE = "cuda" if (USE_GPU and torch.cuda.is_available()) else "cpu"
|
| 36 |
+
|
| 37 |
+
# CPU 專用優化(可由環境變數覆蓋)
|
| 38 |
+
def _int_env(name: str, default_val: int) -> int:
|
| 39 |
+
try:
|
| 40 |
+
return int(os.getenv(name, str(default_val)))
|
| 41 |
+
except Exception:
|
| 42 |
+
return default_val
|
| 43 |
+
|
| 44 |
+
THREADS = _int_env("THREADS", min(4, os.cpu_count() or 2)) # llama.cpp 執行緒數
|
| 45 |
+
CTX = _int_env("CTX", 768 if DEVICE == "cpu" else 1024) # 上下文長度
|
| 46 |
+
MAX_TOKENS = _int_env("MAX_TOKENS", 60) # 生成 token 上限
|
| 47 |
+
FEW_SHOT_EXAMPLES_COUNT = _int_env("FEW_SHOT", 0 if DEVICE == "cpu" else 1)
|
| 48 |
+
ENABLE_INDEX = str(os.getenv("ENABLE_INDEX", "0" if DEVICE == "cpu" else "1")).lower() in {"1", "true", "yes", "y"}
|
| 49 |
+
EMBED_BATCH = _int_env("EMBED_BATCH", 8 if DEVICE == "cpu" else 16)
|
| 50 |
+
|
| 51 |
+
# 使用 /tmp 作為暫存目錄
|
| 52 |
+
TEMP_DIR = "/tmp/text_to_sql_cache"
|
| 53 |
+
os.makedirs(TEMP_DIR, exist_ok=True)
|
| 54 |
+
|
| 55 |
print("=" * 60)
|
| 56 |
+
print("Text-to-SQL 系統啟動中 (HF 版本)...")
|
| 57 |
+
print(f"數據集: {DATASET_REPO_ID}")
|
| 58 |
+
print(f"嵌入模型: {EMBED_MODEL_NAME}")
|
| 59 |
+
print(f"設備: {DEVICE} (USE_GPU={USE_GPU}, N_GPU_LAYERS={N_GPU_LAYERS})")
|
| 60 |
+
print(f"THREADS={THREADS}, CTX={CTX}, MAX_TOKENS={MAX_TOKENS}, FEW_SHOT={FEW_SHOT_EXAMPLES_COUNT}, ENABLE_INDEX={ENABLE_INDEX}, EMBED_BATCH={EMBED_BATCH}")
|
| 61 |
+
print(f"暫存目錄: {TEMP_DIR}")
|
| 62 |
print("=" * 60)
|
| 63 |
|
| 64 |
# ==================== 工具函數 ====================
|
|
|
|
| 68 |
def format_log(message: str, level: str = "INFO") -> str:
|
| 69 |
return f"[{get_current_time()}] [{level.upper()}] {message}"
|
| 70 |
|
| 71 |
+
def check_memory_usage():
|
| 72 |
+
"""檢查內存使用情況 - 簡化版本不依賴 psutil"""
|
| 73 |
+
try:
|
| 74 |
+
# 使用 /proc/meminfo 獲取內存信息 (Linux 環境)
|
| 75 |
+
with open('/proc/meminfo', 'r') as f:
|
| 76 |
+
lines = f.readlines()
|
| 77 |
|
| 78 |
+
mem_info = {}
|
| 79 |
+
for line in lines:
|
| 80 |
+
if line.startswith(('MemTotal:', 'MemFree:', 'MemAvailable:')):
|
| 81 |
+
key, value = line.split(':')
|
| 82 |
+
mem_info[key.strip()] = int(value.strip().split()[0])
|
| 83 |
|
| 84 |
+
total_gb = mem_info.get('MemTotal', 0) / (1024**2)
|
| 85 |
+
available_gb = mem_info.get('MemAvailable', mem_info.get('MemFree', 0)) / (1024**2)
|
| 86 |
+
used_percent = ((total_gb - available_gb) / total_gb * 100) if total_gb > 0 else 0
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
|
| 88 |
+
return f"內存使用率: {used_percent:.1f}% (可用: {available_gb:.1f}GB/{total_gb:.1f}GB)"
|
| 89 |
+
except:
|
| 90 |
+
# 如果無法讀取 /proc/meminfo,返回簡單信息
|
| 91 |
+
return "內存信息: 無法獲取詳細信息"
|
|
|
|
| 92 |
|
| 93 |
+
def parse_sql_from_response(response_text: str) -> Optional[str]:
|
| 94 |
+
"""從模型輸出提取 SQL"""
|
| 95 |
+
if not response_text:
|
| 96 |
return None
|
| 97 |
|
| 98 |
+
response_text = response_text.strip()
|
| 99 |
+
|
| 100 |
+
# 1. 先找 ```sql ... ```
|
| 101 |
+
match = re.search(r"```sql\s*\n(.*?)\n```", response_text, re.DOTALL | re.IGNORECASE)
|
| 102 |
+
if match:
|
| 103 |
+
return match.group(1).strip()
|
| 104 |
+
|
| 105 |
+
# 2. 找任何 ``` 包圍的內容
|
| 106 |
+
match = re.search(r"```\s*\n?(.*?)\n?```", response_text, re.DOTALL)
|
| 107 |
+
if match:
|
| 108 |
+
sql_candidate = match.group(1).strip()
|
| 109 |
+
if sql_candidate.upper().startswith('SELECT'):
|
| 110 |
+
return sql_candidate
|
| 111 |
+
|
| 112 |
+
# 3. 找 SQL 語句(更寬鬆的匹配)
|
| 113 |
+
match = re.search(r"(SELECT\s+.*?;)", response_text, re.DOTALL | re.IGNORECASE)
|
| 114 |
+
if match:
|
| 115 |
+
return match.group(1).strip()
|
| 116 |
+
|
| 117 |
+
# 4. 找沒有分號的 SQL
|
| 118 |
+
match = re.search(r"(SELECT\s+.*?)(?=\n\n|\n```|$|\n[^,\s])", response_text, re.DOTALL | re.IGNORECASE)
|
| 119 |
+
if match:
|
| 120 |
+
sql = match.group(1).strip()
|
| 121 |
+
if not sql.endswith(';'):
|
| 122 |
+
sql += ';'
|
| 123 |
+
return sql
|
| 124 |
+
|
| 125 |
+
# 5. 如果包含 SELECT,嘗試提取整行
|
| 126 |
+
if 'SELECT' in response_text.upper():
|
| 127 |
+
lines = response_text.split('\n')
|
| 128 |
+
for line in lines:
|
| 129 |
+
line = line.strip()
|
| 130 |
+
if line.upper().startswith('SELECT'):
|
| 131 |
+
if not line.endswith(';'):
|
| 132 |
+
line += ';'
|
| 133 |
+
return line
|
| 134 |
+
|
| 135 |
return None
|
| 136 |
|
| 137 |
# ==================== Text-to-SQL 核心類 ====================
|
|
|
|
| 140 |
self.log_history = []
|
| 141 |
self._log("初始化系統...")
|
| 142 |
self.query_cache = {}
|
| 143 |
+
self.embed_device = DEVICE
|
| 144 |
+
|
| 145 |
+
# 檢查內存狀況
|
| 146 |
+
self._log(check_memory_usage())
|
| 147 |
|
| 148 |
# 1. 載入嵌入模型
|
| 149 |
self._log(f"載入嵌入模型: {embed_model_name}")
|
| 150 |
self.embed_tokenizer = AutoTokenizer.from_pretrained(embed_model_name)
|
| 151 |
self.embed_model = AutoModel.from_pretrained(embed_model_name)
|
| 152 |
+
try:
|
| 153 |
+
self.embed_model.to(self.embed_device)
|
| 154 |
+
self._log(f"嵌入模型設備: {self.embed_device}")
|
| 155 |
+
except Exception as e:
|
| 156 |
+
self._log(f"將嵌入模型移動到設備失敗: {e}", "WARNING")
|
| 157 |
+
self.embed_device = "cpu"
|
| 158 |
|
| 159 |
# 2. 載入數據庫結構
|
| 160 |
self.schema = self._load_schema()
|
|
|
|
| 162 |
# 3. 載入數據集並建立索引
|
| 163 |
self.dataset, self.faiss_index = self._load_and_index_dataset()
|
| 164 |
|
| 165 |
+
# 4. 載入 GGUF 模型(新增錯誤處理)
|
| 166 |
self._load_gguf_model()
|
| 167 |
|
| 168 |
+
self._log("系統初始化完成")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
+
def _log(self, message: str, level: str = "INFO"):
|
| 171 |
+
self.log_history.append(format_log(message, level))
|
| 172 |
+
print(format_log(message, level))
|
| 173 |
|
| 174 |
def _load_gguf_model(self):
|
| 175 |
+
"""載入 GGUF 模型,針對 Paperspace 環境優化"""
|
| 176 |
try:
|
| 177 |
+
self._log("開始下載 GGUF 模���到 /tmp...")
|
| 178 |
+
|
| 179 |
+
# 檢查模型是否已存在於 /tmp
|
| 180 |
+
model_cache_path = os.path.join(TEMP_DIR, GGUF_FILENAME)
|
| 181 |
+
|
| 182 |
+
if os.path.exists(model_cache_path) and self._validate_model_file(model_cache_path):
|
| 183 |
+
self._log(f"發現快取模型: {model_cache_path}")
|
| 184 |
+
model_path = model_cache_path
|
| 185 |
+
else:
|
| 186 |
+
self._log("下載新模型...")
|
| 187 |
+
model_path = hf_hub_download(
|
| 188 |
+
repo_id=GGUF_REPO_ID,
|
| 189 |
+
filename=GGUF_FILENAME,
|
| 190 |
+
repo_type="dataset",
|
| 191 |
+
cache_dir=TEMP_DIR,
|
| 192 |
+
resume_download=True
|
| 193 |
+
)
|
| 194 |
+
self._log(f"模型下載完成: {model_path}")
|
| 195 |
+
|
| 196 |
+
# 檢查內存情況
|
| 197 |
+
self._log(check_memory_usage())
|
| 198 |
|
| 199 |
+
# 使用 CPU 友好的參數載入模型(可選 GPU layers)
|
| 200 |
+
ngl = N_GPU_LAYERS if (DEVICE == "cuda" and N_GPU_LAYERS > 0) else 0
|
| 201 |
+
self._log(f"載入 GGUF 模型 (n_gpu_layers={ngl}, n_threads={THREADS}, n_ctx={CTX})...")
|
| 202 |
+
self.llm = Llama(
|
| 203 |
model_path=model_path,
|
| 204 |
+
n_ctx=CTX, # 上下文長度(CPU 默認更小)
|
| 205 |
+
n_threads=THREADS, # 使用多執行緒
|
| 206 |
+
n_batch=256, # 批處理大小
|
| 207 |
+
verbose=False,
|
| 208 |
+
n_gpu_layers=ngl, # 可選 GPU 加速
|
| 209 |
+
use_mmap=True, # 使用內存映射減少內存占用
|
| 210 |
+
use_mlock=False, # 不鎖定內存
|
| 211 |
+
low_vram=True # 啟用低內存模式
|
| 212 |
)
|
| 213 |
|
| 214 |
+
# 簡單測試模型
|
| 215 |
+
test_result = self.llm("SELECT", max_tokens=3)
|
| 216 |
+
self._log("GGUF 模型載入成功")
|
| 217 |
+
|
| 218 |
+
# 再次檢查內存
|
| 219 |
+
self._log(check_memory_usage())
|
| 220 |
|
| 221 |
except Exception as e:
|
| 222 |
+
self._log(f"GGUF 載入失敗: {e}", "ERROR")
|
| 223 |
+
self._log("系統將無法生成 SQL。請檢查模型檔案或內存情況。", "CRITICAL")
|
| 224 |
self.llm = None
|
| 225 |
|
| 226 |
+
def _validate_model_file(self, model_path):
|
| 227 |
+
"""驗證模型檔案完整性"""
|
| 228 |
try:
|
| 229 |
+
if not os.path.exists(model_path):
|
| 230 |
+
return False
|
|
|
|
|
|
|
|
|
|
| 231 |
|
| 232 |
+
# 檢查檔案大小(至少應該有幾百MB)
|
| 233 |
+
file_size = os.path.getsize(model_path)
|
| 234 |
+
if file_size < 50 * 1024 * 1024: # 小於 50MB 可能有問題
|
| 235 |
+
return False
|
|
|
|
|
|
|
|
|
|
| 236 |
|
| 237 |
+
# 檢查 GGUF 檔案頭部
|
| 238 |
+
with open(model_path, 'rb') as f:
|
| 239 |
+
header = f.read(8)
|
| 240 |
+
if not header.startswith(b'GGUF'):
|
| 241 |
+
return False
|
| 242 |
|
| 243 |
+
return True
|
| 244 |
+
except Exception:
|
| 245 |
return False
|
| 246 |
|
| 247 |
+
def huggingface_api_call(self, prompt: str) -> str:
|
| 248 |
+
"""調用 GGUF 模型,並加入詳細的原始輸出日誌"""
|
| 249 |
+
if self.llm is None:
|
| 250 |
+
self._log("模型未載入,返回 fallback SQL。", "ERROR")
|
| 251 |
+
return self._generate_fallback_sql(prompt)
|
| 252 |
+
|
| 253 |
try:
|
| 254 |
+
# 清理垃圾收集
|
| 255 |
+
gc.collect()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 256 |
|
| 257 |
+
output = self.llm(
|
| 258 |
+
prompt,
|
| 259 |
+
max_tokens=MAX_TOKENS, # 生成長度可配置
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 260 |
temperature=0.1,
|
| 261 |
top_p=0.9,
|
| 262 |
+
echo=False,
|
| 263 |
+
stop=["```", ";", "\n\n", "</s>"],
|
| 264 |
)
|
| 265 |
|
| 266 |
+
self._log(f"模型原始輸出: {str(output)[:200]}...", "DEBUG")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 267 |
|
| 268 |
+
if output and "choices" in output and len(output["choices"]) > 0:
|
| 269 |
+
generated_text = output["choices"][0]["text"]
|
| 270 |
+
self._log(f"提取出的生成文本: {generated_text.strip()}", "DEBUG")
|
| 271 |
+
return generated_text.strip()
|
| 272 |
+
else:
|
| 273 |
+
self._log("模型的原始輸出格式不正確或為空。", "ERROR")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
return ""
|
| 275 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 276 |
except Exception as e:
|
| 277 |
+
self._log(f"模型生成過程中發生嚴重錯誤: {e}", "CRITICAL")
|
| 278 |
+
import traceback
|
| 279 |
+
self._log(traceback.format_exc(), "DEBUG")
|
| 280 |
+
return ""
|
|
|
|
|
|
|
| 281 |
|
| 282 |
def _load_schema(self) -> Dict:
|
| 283 |
"""載入數據庫結構"""
|
|
|
|
| 285 |
schema_path = hf_hub_download(
|
| 286 |
repo_id=DATASET_REPO_ID,
|
| 287 |
filename="sqlite_schema_FULL.json",
|
| 288 |
+
repo_type="dataset",
|
| 289 |
+
cache_dir=TEMP_DIR
|
| 290 |
)
|
| 291 |
with open(schema_path, "r", encoding="utf-8") as f:
|
| 292 |
schema_data = json.load(f)
|
| 293 |
|
| 294 |
+
self._log(f"Schema 載入成功,包含 {len(schema_data)} 個表格:")
|
|
|
|
| 295 |
for table_name, columns in schema_data.items():
|
| 296 |
self._log(f" - {table_name}: {len(columns)} 個欄位")
|
|
|
|
|
|
|
|
|
|
| 297 |
|
| 298 |
+
self._log("數據庫結構載入完成")
|
| 299 |
return schema_data
|
| 300 |
|
| 301 |
except Exception as e:
|
| 302 |
+
self._log(f"載入 schema 失敗: {e}", "ERROR")
|
| 303 |
return {}
|
| 304 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 305 |
def _encode_texts(self, texts):
|
| 306 |
"""編碼文本為嵌入向量"""
|
| 307 |
if isinstance(texts, str):
|
| 308 |
texts = [texts]
|
|
|
|
| 309 |
inputs = self.embed_tokenizer(texts, padding=True, truncation=True,
|
| 310 |
+
return_tensors="pt", max_length=512)
|
| 311 |
+
# 移動到對應設備
|
| 312 |
+
try:
|
| 313 |
+
inputs = {k: v.to(self.embed_device) for k, v in inputs.items()}
|
| 314 |
+
except Exception:
|
| 315 |
+
pass
|
| 316 |
|
| 317 |
with torch.no_grad():
|
| 318 |
outputs = self.embed_model(**inputs)
|
| 319 |
|
| 320 |
# 使用平均池化
|
| 321 |
embeddings = outputs.last_hidden_state.mean(dim=1)
|
| 322 |
+
return embeddings.detach().cpu()
|
| 323 |
|
| 324 |
def _load_and_index_dataset(self):
|
| 325 |
"""載入數據集並建立 FAISS 索引"""
|
| 326 |
try:
|
| 327 |
+
if not ENABLE_INDEX:
|
| 328 |
+
self._log("已禁用相似範例索引(ENABLE_INDEX=0)。啟動更快,將不使用 few-shot。")
|
| 329 |
+
return None, None
|
| 330 |
|
| 331 |
+
dataset = load_dataset(
|
| 332 |
+
DATASET_REPO_ID,
|
| 333 |
+
data_files="training_data.jsonl",
|
| 334 |
+
split="train",
|
| 335 |
+
cache_dir=TEMP_DIR
|
| 336 |
+
)
|
| 337 |
|
| 338 |
+
# 過濾不完整樣本
|
| 339 |
+
original_count = len(dataset)
|
| 340 |
dataset = dataset.filter(
|
| 341 |
lambda ex: isinstance(ex.get("messages"), list)
|
| 342 |
and len(ex["messages"]) >= 2
|
|
|
|
| 346 |
)
|
| 347 |
)
|
| 348 |
|
| 349 |
+
self._log(f"資料集清理: 原始 {original_count} 筆, 過濾後 {len(dataset)} 筆")
|
|
|
|
|
|
|
|
|
|
| 350 |
|
| 351 |
if len(dataset) == 0:
|
| 352 |
self._log("清理後資料集為空,無法建立索引。", "ERROR")
|
|
|
|
| 355 |
corpus = [item['messages'][0]['content'] for item in dataset]
|
| 356 |
self._log(f"正在編碼 {len(corpus)} 個問題...")
|
| 357 |
|
| 358 |
+
# 批量編碼以節省內存
|
| 359 |
embeddings_list = []
|
| 360 |
+
batch_size = EMBED_BATCH # 可配置的批次大小(CPU 預設更小)
|
| 361 |
|
| 362 |
for i in range(0, len(corpus), batch_size):
|
| 363 |
batch_texts = corpus[i:i+batch_size]
|
| 364 |
batch_embeddings = self._encode_texts(batch_texts)
|
| 365 |
embeddings_list.append(batch_embeddings)
|
| 366 |
+
|
| 367 |
+
# 清理內存
|
| 368 |
+
if i % (batch_size * 4) == 0:
|
| 369 |
+
gc.collect()
|
| 370 |
+
|
| 371 |
self._log(f"已編碼 {min(i+batch_size, len(corpus))}/{len(corpus)}")
|
| 372 |
|
| 373 |
all_embeddings = torch.cat(embeddings_list, dim=0).numpy()
|
|
|
|
| 376 |
index = faiss.IndexFlatIP(all_embeddings.shape[1])
|
| 377 |
index.add(all_embeddings.astype('float32'))
|
| 378 |
|
| 379 |
+
# 清理內存
|
| 380 |
+
del embeddings_list, all_embeddings
|
| 381 |
+
gc.collect()
|
| 382 |
+
|
| 383 |
+
self._log("向量索引建立完成")
|
| 384 |
return dataset, index
|
| 385 |
|
| 386 |
except Exception as e:
|
| 387 |
+
self._log(f"載入數據失敗: {e}", "ERROR")
|
| 388 |
return None, None
|
| 389 |
|
| 390 |
def _identify_relevant_tables(self, question: str) -> List[str]:
|
|
|
|
| 415 |
|
| 416 |
return relevant_tables[:3] # 最多返回3個相關表格
|
| 417 |
|
|
|
|
|
|
|
| 418 |
def _format_relevant_schema(self, table_names: List[str]) -> str:
|
| 419 |
+
"""生成一個簡化的 Schema 字符串"""
|
|
|
|
|
|
|
| 420 |
if not self.schema:
|
| 421 |
return "No schema available.\n"
|
| 422 |
|
|
|
|
| 436 |
formatted = ""
|
| 437 |
for table in real_table_names:
|
| 438 |
if table in self.schema:
|
|
|
|
| 439 |
formatted += f"Table: {table}\n"
|
| 440 |
cols_str = []
|
| 441 |
+
# 只顯示前 8 個關鍵欄位以節省內存
|
| 442 |
+
for col in self.schema[table][:8]:
|
| 443 |
col_name = col['name']
|
| 444 |
col_type = col['type']
|
| 445 |
+
cols_str.append(f"{col_name} ({col_type})")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 446 |
formatted += f"Columns: {', '.join(cols_str)}\n\n"
|
| 447 |
|
| 448 |
return formatted.strip()
|
| 449 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 450 |
def find_most_similar(self, question: str, top_k: int) -> List[Dict]:
|
| 451 |
"""使用 FAISS 快速檢索相似問題"""
|
| 452 |
if self.faiss_index is None or self.dataset is None:
|
|
|
|
| 466 |
if len(results) >= top_k:
|
| 467 |
break
|
| 468 |
|
| 469 |
+
idx = int(idx)
|
| 470 |
+
if idx >= len(self.dataset):
|
|
|
|
|
|
|
| 471 |
continue
|
| 472 |
|
| 473 |
item = self.dataset[idx]
|
|
|
|
| 474 |
if not isinstance(item.get('messages'), list) or len(item['messages']) < 2:
|
| 475 |
continue
|
| 476 |
+
|
| 477 |
q_content = (item['messages'][0].get('content') or '').strip()
|
| 478 |
a_content = (item['messages'][1].get('content') or '').strip()
|
| 479 |
if not q_content or not a_content:
|
|
|
|
| 496 |
return results
|
| 497 |
|
| 498 |
except Exception as e:
|
| 499 |
+
self._log(f"檢索失敗: {e}", "ERROR")
|
| 500 |
return []
|
| 501 |
|
|
|
|
|
|
|
| 502 |
def _build_prompt(self, user_q: str, examples: List[Dict]) -> str:
|
| 503 |
+
"""建立簡化的提示詞"""
|
|
|
|
|
|
|
| 504 |
relevant_tables = self._identify_relevant_tables(user_q)
|
|
|
|
|
|
|
| 505 |
schema_str = self._format_relevant_schema(relevant_tables)
|
| 506 |
|
| 507 |
example_str = "No example available."
|
|
|
|
| 509 |
best_example = examples[0]
|
| 510 |
example_str = f"Question: {best_example['question']}\nSQL:\n```sql\n{best_example['sql']}\n```"
|
| 511 |
|
| 512 |
+
# 簡化的 prompt,減少 token 使用
|
| 513 |
+
prompt = f"""### TASK ###
|
| 514 |
+
Generate SQLite query for the question below.
|
| 515 |
|
| 516 |
### SCHEMA ###
|
| 517 |
{schema_str}
|
|
|
|
| 519 |
### EXAMPLE ###
|
| 520 |
{example_str}
|
| 521 |
|
| 522 |
+
### QUESTION ###
|
| 523 |
+
{user_q}
|
| 524 |
+
|
| 525 |
+
SQL:
|
| 526 |
```sql
|
| 527 |
SELECT
|
| 528 |
"""
|
|
|
|
|
|
|
| 529 |
return prompt
|
| 530 |
|
| 531 |
+
def _rule_based_sql(self, question: str) -> Optional[str]:
|
| 532 |
+
"""規則先行:對常見查詢用模板直接生成 SQL,繞過 LLM。"""
|
| 533 |
+
q = (question or "").strip()
|
| 534 |
+
q_lower = q.lower()
|
| 535 |
+
|
| 536 |
+
# 兩年比較(完成數量、每月)
|
| 537 |
+
m = re.search(r"(20\d{2}).{0,6}(?:與|和|跟)\s*(20\d{2}).{0,10}(比較|對比).{0,10}(完成|報告|數量|件|工單)", q)
|
| 538 |
+
if m:
|
| 539 |
+
y1, y2 = m.group(1), m.group(2)
|
| 540 |
+
return (
|
| 541 |
+
"SELECT strftime('%Y-%m', completed_time) AS month, "
|
| 542 |
+
f"SUM(CASE WHEN strftime('%Y', completed_time)='{y1}' THEN 1 ELSE 0 END) AS count_{y1}, "
|
| 543 |
+
f"SUM(CASE WHEN strftime('%Y', completed_time)='{y2}' THEN 1 ELSE 0 END) AS count_{y2} "
|
| 544 |
+
"FROM jobtimeline "
|
| 545 |
+
f"WHERE strftime('%Y', completed_time) IN ('{y1}','{y2}') "
|
| 546 |
+
"GROUP BY month ORDER BY month;"
|
| 547 |
+
)
|
| 548 |
+
|
| 549 |
+
# 指定年份每月完成數量
|
| 550 |
+
m = re.search(r"(20\d{2})年.*每月.*(完成|報告|數量|件|工單)", q)
|
| 551 |
+
if m:
|
| 552 |
+
year = m.group(1)
|
| 553 |
+
return (
|
| 554 |
+
"SELECT strftime('%Y-%m', completed_time) AS month, COUNT(*) AS count "
|
| 555 |
+
"FROM jobtimeline "
|
| 556 |
+
f"WHERE strftime('%Y', completed_time)='{year}' "
|
| 557 |
+
"GROUP BY month ORDER BY month;"
|
| 558 |
+
)
|
| 559 |
+
|
| 560 |
+
# 評級分布(Pass/Fail)
|
| 561 |
+
if ("評級" in q) or ("pass" in q_lower) or ("fail" in q_lower):
|
| 562 |
+
return "SELECT rating, COUNT(*) AS count FROM tsr53sampledescription GROUP BY rating;"
|
| 563 |
+
|
| 564 |
+
# 金額最高 Top N(預設 10)
|
| 565 |
+
m = re.search(r"金額.*?(?:最高|前|top)\s*(\d+)?", q_lower)
|
| 566 |
+
if m:
|
| 567 |
+
n = m.group(1) or "10"
|
| 568 |
+
return f"SELECT * FROM tsr53invoice ORDER BY amount DESC LIMIT {n};"
|
| 569 |
+
|
| 570 |
+
# 客戶工作單數量最多 Top N
|
| 571 |
+
m = re.search(r"客戶.*?(?:最多|top|前)\s*(\d+)?", q_lower)
|
| 572 |
+
if m:
|
| 573 |
+
n = m.group(1) or "10"
|
| 574 |
+
return f"SELECT applicant, COUNT(*) AS count FROM tsr53sampledescription GROUP BY applicant ORDER BY count DESC LIMIT {n};"
|
| 575 |
+
|
| 576 |
+
# 昨天完成多少
|
| 577 |
+
if "昨天" in q:
|
| 578 |
+
return (
|
| 579 |
+
"SELECT COUNT(*) AS count FROM jobtimeline "
|
| 580 |
+
"WHERE date(completed_time)=date('now','-1 day');"
|
| 581 |
+
)
|
| 582 |
+
|
| 583 |
+
return None
|
| 584 |
+
|
| 585 |
+
def _finalize_sql(self, sql_text: str, status: str) -> Tuple[str, str]:
|
| 586 |
+
"""最終整理 SQL:補分號、去除多餘空白並回傳 (sql, 狀態)。"""
|
| 587 |
+
try:
|
| 588 |
+
sql_clean = (sql_text or "").strip()
|
| 589 |
+
if sql_clean and not sql_clean.endswith(";"):
|
| 590 |
+
sql_clean += ";"
|
| 591 |
+
return sql_clean, status
|
| 592 |
+
except Exception as e:
|
| 593 |
+
self._log(f"最終整理 SQL 失敗: {e}", "ERROR")
|
| 594 |
+
return (sql_text or ""), status
|
| 595 |
+
|
| 596 |
+
def _validate_and_fix_sql(self, question: str, raw_response: str) -> Tuple[Optional[str], str]:
|
| 597 |
+
"""
|
| 598 |
+
(V29 / 穩健正則 + 智能計數) 多層次 SQL 生成:
|
| 599 |
+
1) 嘗試規則/模板動態組合
|
| 600 |
+
2) 失敗則解析 AI 輸出並做方言/Schema 修正
|
| 601 |
+
回傳: (sql 或 None, 狀態描述)
|
| 602 |
+
"""
|
| 603 |
+
q = question or ""
|
| 604 |
+
q_lower = q.lower()
|
| 605 |
+
|
| 606 |
+
# 先嘗試內建的規則先行器
|
| 607 |
+
rb = self._rule_based_sql(q)
|
| 608 |
+
if rb:
|
| 609 |
+
self._log("_validate_and_fix_sql 命中規則模板")
|
| 610 |
+
return self._finalize_sql(rb, "規則生成")
|
| 611 |
+
|
| 612 |
+
# 統一實體識別(簡化版)
|
| 613 |
+
entity_match_data = None
|
| 614 |
+
entity_patterns = [
|
| 615 |
+
{'pattern': r"(買家|买家|buyer)\s*(?:id|代號|代碼|代号|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.BuyerID', 'type': '買家ID'},
|
| 616 |
+
{'pattern': r"(申請方|申请方|申請廠商|申请厂商|applicant)\s*(?:id|代號|代碼|代号|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.ApplicantID', 'type': '申請方ID'},
|
| 617 |
+
{'pattern': r"(付款方|付款厂商|invoiceto)\s*(?:id|代號|代碼|代号|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.InvoiceToID', 'type': '付款方ID'},
|
| 618 |
+
{'pattern': r"(代理商|agent)\s*(?:id|代號|代碼|代号|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.AgentID', 'type': '代理商ID'},
|
| 619 |
+
{'pattern': r"(買家|买家|buyer|客戶)\s+([a-zA-Z0-9&.-]+)", 'column': 'sd.BuyerName', 'type': '買家'},
|
| 620 |
+
{'pattern': r"(申請方|申请方|申請廠商|申请厂商|applicant)\s+([a-zA-Z0-9&.-]+)", 'column': 'sd.ApplicantName', 'type': '申請方'},
|
| 621 |
+
{'pattern': r"(付款方|付款厂商|invoiceto)\s+([a-zA-Z0-9&.-]+)", 'column': 'sd.InvoiceToName', 'type': '付款方'},
|
| 622 |
+
{'pattern': r"(代理商|agent)\s+([a-zA-Z0-9&.-]+)", 'column': 'sd.AgentName', 'type': '代理商'},
|
| 623 |
+
{'pattern': r"\b([A-Z]\d{4}[A-Z])\b", 'column': 'sd.ApplicantID', 'type': 'ID'}
|
| 624 |
+
]
|
| 625 |
+
for p in entity_patterns:
|
| 626 |
+
m = re.search(p['pattern'], q, re.IGNORECASE)
|
| 627 |
+
if m:
|
| 628 |
+
entity_value = m.group(2) if len(m.groups()) > 1 else m.group(1)
|
| 629 |
+
entity_match_data = {"type": p['type'], "name": entity_value.strip().upper(), "column": p['column']}
|
| 630 |
+
break
|
| 631 |
+
|
| 632 |
+
# 模組化意圖偵測與動態 SQL 組合
|
| 633 |
+
intents: Dict[str, str] = {}
|
| 634 |
+
sql = {
|
| 635 |
+
'select': [], 'from': '', 'joins': [], 'where': [],
|
| 636 |
+
'group_by': [], 'order_by': [], 'log_parts': []
|
| 637 |
+
}
|
| 638 |
+
|
| 639 |
+
# 動作意圖:count / list
|
| 640 |
+
if any(kw in q_lower for kw in ['幾份', '多少', '數量', '總數', 'how many', 'count']):
|
| 641 |
+
intents['action'] = 'count'
|
| 642 |
+
if ("測試項目" in q) or ("test item" in q_lower):
|
| 643 |
+
sql['select'].append("COUNT(jip.ItemCode) AS item_count")
|
| 644 |
+
sql['log_parts'].append("測試項目總數")
|
| 645 |
+
else:
|
| 646 |
+
sql['select'].append("COUNT(DISTINCT jt.JobNo) AS report_count")
|
| 647 |
+
sql['log_parts'].append("報告總數")
|
| 648 |
+
elif any(kw in q_lower for kw in ['報告號碼', '報告清單', '列出報告', 'report number', 'list of reports']):
|
| 649 |
+
intents['action'] = 'list'
|
| 650 |
+
sql['select'].append("jt.JobNo, jt.ReportAuthorization")
|
| 651 |
+
sql['order_by'].append("jt.ReportAuthorization DESC")
|
| 652 |
+
sql['log_parts'].append("報告列表")
|
| 653 |
+
|
| 654 |
+
# 時間意圖:年/月
|
| 655 |
+
ym = re.search(r'(\d{4})\s*年?', q)
|
| 656 |
+
mm = re.search(r'(\d{1,2})\s*月', q)
|
| 657 |
+
if ym:
|
| 658 |
+
year = ym.group(1)
|
| 659 |
+
sql['where'].append(f"strftime('%Y', jt.ReportAuthorization) = '{year}'")
|
| 660 |
+
sql['log_parts'].append(f"{year}年")
|
| 661 |
+
if mm:
|
| 662 |
+
month = mm.group(1).zfill(2)
|
| 663 |
+
sql['where'].append(f"strftime('%m', jt.ReportAuthorization) = '{month}'")
|
| 664 |
+
sql['log_parts'].append(f"{month}月")
|
| 665 |
+
|
| 666 |
+
# 實體意圖
|
| 667 |
+
if entity_match_data:
|
| 668 |
+
if "TSR53SampleDescription" not in " ".join(sql['joins']):
|
| 669 |
+
sql['joins'].append("JOIN TSR53SampleDescription AS sd ON jt.JobNo = sd.JobNo")
|
| 670 |
+
entity_name, column_name = entity_match_data['name'], entity_match_data['column']
|
| 671 |
+
match_op = '=' if column_name.endswith('ID') else 'LIKE'
|
| 672 |
+
entity_val = f"'%{entity_name}%'" if match_op == 'LIKE' else f"'{entity_name}'"
|
| 673 |
+
sql['where'].append(f"{column_name} {match_op} {entity_val}")
|
| 674 |
+
sql['log_parts'].append(entity_match_data['type'] + ":" + entity_name)
|
| 675 |
+
if intents.get('action') == 'list':
|
| 676 |
+
sql['select'].append("sd.BuyerName")
|
| 677 |
+
|
| 678 |
+
# 評級意圖
|
| 679 |
+
if ('fail' in q_lower) or ('失敗' in q_lower):
|
| 680 |
+
if "TSR53SampleDescription" not in " ".join(sql['joins']):
|
| 681 |
+
sql['joins'].append("JOIN TSR53SampleDescription AS sd ON jt.JobNo = sd.JobNo")
|
| 682 |
+
sql['where'].append("sd.OverallRating = 'Fail'")
|
| 683 |
+
sql['log_parts'].append("Fail")
|
| 684 |
+
elif ('pass' in q_lower) or ('通過' in q_lower):
|
| 685 |
+
if "TSR53SampleDescription" not in " ".join(sql['joins']):
|
| 686 |
+
sql['joins'].append("JOIN TSR53SampleDescription AS sd ON jt.JobNo = sd.JobNo")
|
| 687 |
+
sql['where'].append("sd.OverallRating = 'Pass'")
|
| 688 |
+
sql['log_parts'].append("Pass")
|
| 689 |
+
|
| 690 |
+
# 實驗組 (LabGroup)
|
| 691 |
+
lab_group_mapping = {'A': 'TA', 'B': 'TB', 'C': 'TC', 'D': 'TD', 'E': 'TE', 'Y': 'TY'}
|
| 692 |
+
lgm = re.search(r'([A-Z]{1,2})組', q, re.IGNORECASE)
|
| 693 |
+
if lgm:
|
| 694 |
+
user_group = lgm.group(1).upper()
|
| 695 |
+
db_group = lab_group_mapping.get(user_group, user_group)
|
| 696 |
+
sql['joins'].append("JOIN JobItemsInProgress AS jip ON jt.JobNo = jip.JobNo")
|
| 697 |
+
sql['where'].append(f"jip.LabGroup = '{db_group}'")
|
| 698 |
+
sql['log_parts'].append(f"{user_group}組(->{db_group})")
|
| 699 |
+
|
| 700 |
+
# 若動作已決定,組裝模板 SQL
|
| 701 |
+
if 'action' in intents:
|
| 702 |
+
sql['from'] = "FROM JobTimeline AS jt"
|
| 703 |
+
if sql['where']:
|
| 704 |
+
sql['where'].insert(0, "jt.ReportAuthorization IS NOT NULL")
|
| 705 |
+
select_clause = "SELECT " + ", ".join(sorted(list(set(sql['select'])))) if sql['select'] else "SELECT *"
|
| 706 |
+
from_clause = sql['from']
|
| 707 |
+
joins_clause = " ".join(sql['joins'])
|
| 708 |
+
where_clause = ("WHERE " + " AND ".join(sql['where'])) if sql['where'] else ""
|
| 709 |
+
orderby_clause = ("ORDER BY " + ", ".join(sql['order_by'])) if sql['order_by'] else ""
|
| 710 |
+
template_sql = f"{select_clause} {from_clause} {joins_clause} {where_clause} {orderby_clause};"
|
| 711 |
+
query_log = " ".join(sql['log_parts'])
|
| 712 |
+
self._log(f"🔄 偵測到組合意圖【{query_log}】,啟用動態模板。")
|
| 713 |
+
return self._finalize_sql(template_sql, f"模板覆寫: {query_log} 查詢")
|
| 714 |
+
|
| 715 |
+
# 第二層:解析 AI 輸出並修正
|
| 716 |
+
self._log("未觸發任何模板,嘗試解析並修正 AI 輸出…")
|
| 717 |
+
parsed_sql = parse_sql_from_response(raw_response)
|
| 718 |
+
if not parsed_sql:
|
| 719 |
+
self._log(f"❌ 未能從模型回應中解析出任何 SQL。原始回應: {raw_response}", "ERROR")
|
| 720 |
+
return None, f"無法解析SQL。原始回應:\n{raw_response}"
|
| 721 |
+
|
| 722 |
+
self._log(f"📊 解析出的原始 SQL: {parsed_sql}", "DEBUG")
|
| 723 |
+
fixed_sql = " " + parsed_sql.strip() + " "
|
| 724 |
+
fixes_applied = []
|
| 725 |
+
|
| 726 |
+
# 方言修正
|
| 727 |
+
dialect_corrections = {r'YEAR\s*\(([^)]+)\)': r"strftime('%Y', \1)"}
|
| 728 |
+
for pat, rep in dialect_corrections.items():
|
| 729 |
+
if re.search(pat, fixed_sql, re.IGNORECASE):
|
| 730 |
+
fixed_sql = re.sub(pat, rep, fixed_sql, flags=re.IGNORECASE)
|
| 731 |
+
fixes_applied.append(f"修正方言: {pat}")
|
| 732 |
+
|
| 733 |
+
# Schema 名稱修正(常見別名 => 真實欄位)
|
| 734 |
+
schema_map = {
|
| 735 |
+
'TSR53Report':'TSR53SampleDescription',
|
| 736 |
+
'TSR53InvoiceReportNo':'JobNo',
|
| 737 |
+
'TSR53ReportNo':'JobNo',
|
| 738 |
+
'TSR53InvoiceNo':'JobNo',
|
| 739 |
+
'TSR53InvoiceCreditNoteNo':'InvoiceCreditNoteNo',
|
| 740 |
+
'TSR53InvoiceLocalAmount':'LocalAmount',
|
| 741 |
+
'Status':'OverallRating',
|
| 742 |
+
'ReportStatus':'OverallRating'
|
| 743 |
+
}
|
| 744 |
+
for wrong, correct in schema_map.items():
|
| 745 |
+
pat = r'\b' + re.escape(wrong) + r'\b'
|
| 746 |
+
if re.search(pat, fixed_sql, re.IGNORECASE):
|
| 747 |
+
fixed_sql = re.sub(pat, correct, fixed_sql, flags=re.IGNORECASE)
|
| 748 |
+
fixes_applied.append(f"映射 Schema: '{wrong}' -> '{correct}'")
|
| 749 |
+
|
| 750 |
+
status = "AI 生成並成功修正" if fixes_applied else "AI 生成且無需修正"
|
| 751 |
+
return self._finalize_sql(fixed_sql, status)
|
| 752 |
|
| 753 |
def _generate_fallback_sql(self, prompt: str) -> str:
|
| 754 |
"""當模型不可用時的備用 SQL 生成"""
|
| 755 |
prompt_lower = prompt.lower()
|
| 756 |
|
|
|
|
| 757 |
if "統計" in prompt or "數量" in prompt or "多少" in prompt:
|
| 758 |
if "月" in prompt:
|
| 759 |
return "SELECT strftime('%Y-%m', completed_time) as month, COUNT(*) as count FROM jobtimeline GROUP BY month ORDER BY month;"
|
|
|
|
| 761 |
return "SELECT applicant, COUNT(*) as count FROM tsr53sampledescription GROUP BY applicant ORDER BY count DESC;"
|
| 762 |
else:
|
| 763 |
return "SELECT COUNT(*) as total_count FROM jobtimeline WHERE completed_time IS NOT NULL;"
|
|
|
|
| 764 |
elif "金額" in prompt or "總額" in prompt:
|
| 765 |
return "SELECT SUM(amount) as total_amount FROM tsr53invoice;"
|
|
|
|
| 766 |
elif "評級" in prompt or "pass" in prompt_lower or "fail" in prompt_lower:
|
| 767 |
return "SELECT rating, COUNT(*) as count FROM tsr53sampledescription GROUP BY rating;"
|
|
|
|
| 768 |
else:
|
| 769 |
return "SELECT * FROM jobtimeline LIMIT 10;"
|
| 770 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 771 |
def process_question(self, question: str) -> Tuple[str, str]:
|
| 772 |
+
"""處理使用者問題"""
|
| 773 |
# 檢查緩存
|
| 774 |
if question in self.query_cache:
|
| 775 |
+
self._log("使用緩存結果")
|
| 776 |
return self.query_cache[question]
|
| 777 |
|
| 778 |
self.log_history = []
|
| 779 |
+
self._log(f"處理問題: {question}")
|
| 780 |
+
self._log(check_memory_usage())
|
| 781 |
+
|
| 782 |
+
# 0. 規則先行(命中則直接返回)
|
| 783 |
+
rb = self._rule_based_sql(question)
|
| 784 |
+
if rb:
|
| 785 |
+
self._log("規則命中,直接生成 SQL(跳過 LLM)")
|
| 786 |
+
self._log(f"最終 SQL: {rb}")
|
| 787 |
+
result = (rb, "規則生成")
|
| 788 |
+
self.query_cache[question] = result
|
| 789 |
+
gc.collect()
|
| 790 |
+
return result
|
| 791 |
+
|
| 792 |
+
# 1. 檢索相似範例
|
| 793 |
+
self._log("尋找相似範例...")
|
| 794 |
+
examples = self.find_most_similar(question, FEW_SHOT_EXAMPLES_COUNT)
|
| 795 |
+
if examples:
|
| 796 |
+
self._log(f"找到 {len(examples)} 個相似範例")
|
| 797 |
|
| 798 |
+
# 2. 建立提示詞
|
| 799 |
+
self._log("建立 Prompt...")
|
| 800 |
+
prompt = self._build_prompt(question, examples)
|
|
|
|
|
|
|
| 801 |
|
| 802 |
+
# 3. 生成 AI 回應
|
| 803 |
+
self._log("開始生成 AI 回應...")
|
| 804 |
+
response = self.huggingface_api_call(prompt)
|
| 805 |
|
| 806 |
+
# 4. 驗證/修正 SQL
|
| 807 |
+
fixed_sql, status_message = self._validate_and_fix_sql(question, response)
|
| 808 |
+
if not fixed_sql:
|
| 809 |
+
fixed_sql = "SELECT '未能生成有效的SQL,請嘗試換個問題描述';"
|
| 810 |
+
status_message = status_message or "生成失敗"
|
| 811 |
|
| 812 |
+
self._log(f"最終 SQL: {fixed_sql}")
|
| 813 |
+
result = (fixed_sql, status_message)
|
| 814 |
|
| 815 |
+
# 緩存結果
|
| 816 |
+
self.query_cache[question] = result
|
|
|
|
|
|
|
|
|
|
| 817 |
|
| 818 |
+
# 清理內存
|
| 819 |
+
gc.collect()
|
| 820 |
|
| 821 |
+
return result
|
|
|
|
|
|
|
|
|
|
| 822 |
|
| 823 |
+
# ==================== Gradio 介面與 API ====================
|
| 824 |
+
print("正在初始化 Text-to-SQL 系統...")
|
| 825 |
text_to_sql_system = TextToSQLSystem()
|
| 826 |
|
| 827 |
+
def process_query(q: str, prompt_override: str = ""):
|
| 828 |
+
if not (q or prompt_override).strip():
|
| 829 |
+
return "", "等待輸入", "請輸入問題或提供 prompt_override"
|
| 830 |
+
|
| 831 |
+
# 若提供 prompt_override:
|
| 832 |
+
if prompt_override and prompt_override.strip():
|
| 833 |
+
po = prompt_override.strip()
|
| 834 |
+
# 如果 override 本身就是 SQL,直接回傳
|
| 835 |
+
if po.upper().startswith("SELECT"):
|
| 836 |
+
if not po.strip().endswith(";"):
|
| 837 |
+
po = po.strip() + ";"
|
| 838 |
+
text_to_sql_system._log("使用 prompt_override 直接回傳 SQL")
|
| 839 |
+
logs = "\n".join(text_to_sql_system.log_history[-15:])
|
| 840 |
+
return po, "override", logs
|
| 841 |
+
# 否則當作完整 prompt 丟給 LLM
|
| 842 |
+
text_to_sql_system._log("使用 prompt_override 直接調用 LLM")
|
| 843 |
+
response = text_to_sql_system.huggingface_api_call(po)
|
| 844 |
+
fixed_sql, status_message = text_to_sql_system._validate_and_fix_sql(q or "", response)
|
| 845 |
+
if not fixed_sql:
|
| 846 |
+
fixed_sql = text_to_sql_system._generate_fallback_sql(po)
|
| 847 |
+
status_message = status_message or "override 回退"
|
| 848 |
+
text_to_sql_system._log(f"最終 SQL: {fixed_sql}")
|
| 849 |
+
logs = "\n".join(text_to_sql_system.log_history[-15:])
|
| 850 |
+
return fixed_sql, "override", logs
|
| 851 |
|
| 852 |
sql, status = text_to_sql_system.process_question(q)
|
| 853 |
+
logs = "\n".join(text_to_sql_system.log_history[-15:]) # 顯示最後15條日誌
|
|
|
|
| 854 |
return sql, status, logs
|
| 855 |
|
| 856 |
# 範例問題
|
|
|
|
| 862 |
"A組昨天完成了多少個測試項目?"
|
| 863 |
]
|
| 864 |
|
| 865 |
+
with gr.Blocks(theme=gr.themes.Soft(), title="Text-to-SQL 智能助手 (HF Space)") as demo:
|
| 866 |
+
gr.Markdown("# Text-to-SQL 智能助手 (Hugging Face Space)")
|
| 867 |
+
gr.Markdown("輸入自然語言問題,自動生成SQL查詢語句。使用 /tmp 暫存,每次啟動重新下載模型。支援桌面端透過 /predict API 呼叫。")
|
| 868 |
|
| 869 |
with gr.Row():
|
| 870 |
with gr.Column(scale=2):
|
| 871 |
+
inp = gr.Textbox(lines=3, label="您的問題", placeholder="例如:2024年每月完成多少份報告?")
|
| 872 |
+
btn = gr.Button("生成 SQL", variant="primary")
|
| 873 |
status = gr.Textbox(label="狀態", interactive=False)
|
| 874 |
+
# 隱藏的 prompt_override 供桌面端呼叫
|
| 875 |
+
prompt_override = gr.Textbox(label="prompt_override", visible=False)
|
| 876 |
|
| 877 |
with gr.Column(scale=3):
|
| 878 |
+
sql_out = gr.Code(label="生成的 SQL", language="sql", lines=8)
|
| 879 |
|
| 880 |
+
with gr.Accordion("處理日誌", open=False):
|
| 881 |
+
logs = gr.Textbox(lines=10, label="日誌", interactive=False)
|
| 882 |
|
| 883 |
# 範例區
|
| 884 |
gr.Examples(
|
| 885 |
examples=examples,
|
| 886 |
inputs=inp,
|
| 887 |
+
label="點擊試用範例問題"
|
| 888 |
)
|
| 889 |
|
| 890 |
# 綁定事件
|
| 891 |
+
btn.click(process_query, inputs=[inp, prompt_override], outputs=[sql_out, status, logs], api_name="/predict")
|
| 892 |
+
inp.submit(process_query, inputs=[inp, prompt_override], outputs=[sql_out, status, logs])
|
| 893 |
|
| 894 |
if __name__ == "__main__":
|
| 895 |
demo.launch(
|
| 896 |
server_name="0.0.0.0",
|
| 897 |
server_port=7860,
|
| 898 |
+
share=True,
|
| 899 |
+
show_error=True
|
| 900 |
)
|