Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,39 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import os
|
| 3 |
import re
|
| 4 |
import json
|
| 5 |
import torch
|
| 6 |
import numpy as np
|
|
|
|
|
|
|
| 7 |
from datetime import datetime
|
| 8 |
from datasets import load_dataset
|
| 9 |
from huggingface_hub import hf_hub_download
|
| 10 |
from llama_cpp import Llama
|
| 11 |
from typing import List, Dict, Tuple, Optional
|
| 12 |
import faiss
|
| 13 |
-
|
| 14 |
|
| 15 |
-
# 使用 transformers 替代 sentence-transformers
|
| 16 |
from transformers import AutoModel, AutoTokenizer
|
| 17 |
import torch.nn.functional as F
|
| 18 |
|
| 19 |
-
# ====================
|
| 20 |
-
|
| 21 |
-
GGUF_REPO_ID = "Paul720810/gguf-models"
|
| 22 |
-
#GGUF_FILENAME = "qwen2.5-coder-1.5b-sql-finetuned.q4_k_m.gguf"
|
| 23 |
GGUF_FILENAME = "qwen2.5-coder-1.5b-sql-finetuned.q8_0.gguf"
|
|
|
|
| 24 |
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
FEW_SHOT_EXAMPLES_COUNT = 1
|
| 29 |
-
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 30 |
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
|
| 31 |
|
|
|
|
|
|
|
|
|
|
| 32 |
print("=" * 60)
|
| 33 |
-
print("🤖 Text-to-SQL
|
| 34 |
-
print(f"
|
| 35 |
-
print(f"
|
| 36 |
-
print(f"💻 設備: {DEVICE}")
|
| 37 |
print("=" * 60)
|
| 38 |
|
| 39 |
# ==================== 工具函數 ====================
|
|
@@ -41,51 +46,32 @@ def get_current_time():
|
|
| 41 |
return datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
| 42 |
|
| 43 |
def format_log(message: str, level: str = "INFO") -> str:
|
| 44 |
-
|
|
|
|
|
|
|
| 45 |
|
| 46 |
def parse_sql_from_response(response_text: str) -> Optional[str]:
|
| 47 |
-
|
| 48 |
-
if not response_text:
|
| 49 |
-
return None
|
| 50 |
-
|
| 51 |
-
# 清理回應文本
|
| 52 |
response_text = response_text.strip()
|
| 53 |
-
|
| 54 |
-
# 1. 先找 ```sql ... ```
|
| 55 |
match = re.search(r"```sql\s*\n(.*?)\n```", response_text, re.DOTALL | re.IGNORECASE)
|
| 56 |
-
if match:
|
| 57 |
-
return match.group(1).strip()
|
| 58 |
-
|
| 59 |
-
# 2. 找任何 ``` 包圍的內容
|
| 60 |
match = re.search(r"```\s*\n?(.*?)\n?```", response_text, re.DOTALL)
|
| 61 |
if match:
|
| 62 |
sql_candidate = match.group(1).strip()
|
| 63 |
-
if sql_candidate.upper().startswith('SELECT'):
|
| 64 |
-
return sql_candidate
|
| 65 |
-
|
| 66 |
-
# 3. 找 SQL 語句(更寬鬆的匹配)
|
| 67 |
match = re.search(r"(SELECT\s+.*?;)", response_text, re.DOTALL | re.IGNORECASE)
|
| 68 |
-
if match:
|
| 69 |
-
return match.group(1).strip()
|
| 70 |
-
|
| 71 |
-
# 4. 找沒有分號的 SQL
|
| 72 |
match = re.search(r"(SELECT\s+.*?)(?=\n\n|\n```|$|\n[^,\s])", response_text, re.DOTALL | re.IGNORECASE)
|
| 73 |
if match:
|
| 74 |
sql = match.group(1).strip()
|
| 75 |
-
if not sql.endswith(';'):
|
| 76 |
-
sql += ';'
|
| 77 |
return sql
|
| 78 |
-
|
| 79 |
-
# 5. 如果包含 SELECT,嘗試提取整行
|
| 80 |
if 'SELECT' in response_text.upper():
|
| 81 |
-
|
| 82 |
-
for line in lines:
|
| 83 |
line = line.strip()
|
| 84 |
if line.upper().startswith('SELECT'):
|
| 85 |
-
if not line.endswith(';'):
|
| 86 |
-
line += ';'
|
| 87 |
return line
|
| 88 |
-
|
| 89 |
return None
|
| 90 |
|
| 91 |
# ==================== Text-to-SQL 核心類 ====================
|
|
@@ -94,445 +80,179 @@ class TextToSQLSystem:
|
|
| 94 |
self.log_history = []
|
| 95 |
self._log("初始化系統...")
|
| 96 |
self.query_cache = {}
|
| 97 |
-
|
| 98 |
-
# 1. 載入嵌入模型
|
| 99 |
-
self._log(f"載入嵌入模型: {embed_model_name}")
|
| 100 |
-
self.embed_tokenizer = AutoTokenizer.from_pretrained(embed_model_name)
|
| 101 |
-
self.embed_model = AutoModel.from_pretrained(embed_model_name)
|
| 102 |
-
if DEVICE == "cuda":
|
| 103 |
-
self.embed_model = self.embed_model.cuda()
|
| 104 |
-
|
| 105 |
-
# 2. 載入數據庫結構
|
| 106 |
-
self.schema = self._load_schema()
|
| 107 |
-
|
| 108 |
-
# 3. 載入數據集並建立索引
|
| 109 |
-
self.dataset, self.faiss_index = self._load_and_index_dataset()
|
| 110 |
-
|
| 111 |
-
# 4. 載入 GGUF 模型(添加錯誤處理)
|
| 112 |
-
self._load_gguf_model()
|
| 113 |
-
|
| 114 |
-
self._log("✅ 系統初始化完成")
|
| 115 |
-
# 載入數據庫結構
|
| 116 |
-
self.schema = self._load_schema()
|
| 117 |
-
|
| 118 |
-
# 暫時添加:打印 schema 信息
|
| 119 |
-
if self.schema:
|
| 120 |
-
print("=" * 50)
|
| 121 |
-
print("數據庫 Schema 信息:")
|
| 122 |
-
for table_name, columns in self.schema.items():
|
| 123 |
-
print(f"\n表格: {table_name}")
|
| 124 |
-
print(f"欄位數: {len(columns)}")
|
| 125 |
-
print("欄位列表:")
|
| 126 |
-
for col in columns[:5]: # 只顯示前5個
|
| 127 |
-
print(f" - {col['name']} ({col['type']})")
|
| 128 |
-
print("=" * 50)
|
| 129 |
-
|
| 130 |
-
# in class TextToSQLSystem:
|
| 131 |
-
|
| 132 |
-
def _load_gguf_model(self):
|
| 133 |
-
"""載入 GGUF 模型,使用更穩定、簡潔的參數"""
|
| 134 |
try:
|
| 135 |
-
self._log("
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
self.
|
| 144 |
-
|
| 145 |
-
n_ctx=2048, # 將上下文增加到 2048 以確保 Prompt 不會超長
|
| 146 |
-
n_threads=4, # 保持 4 線程
|
| 147 |
-
n_batch=512, # 建議值
|
| 148 |
-
verbose=False, # 設為 False 避免 llama.cpp 本身的日誌干擾
|
| 149 |
-
n_gpu_layers=0 # 確認在 CPU 上運行
|
| 150 |
-
)
|
| 151 |
-
|
| 152 |
-
# 簡單測試模型是否能回應
|
| 153 |
-
self.llm("你好", max_tokens=3)
|
| 154 |
-
self._log("✅ GGUF 模型載入成功")
|
| 155 |
-
|
| 156 |
except Exception as e:
|
| 157 |
-
self._log(f"❌
|
| 158 |
-
self._log(
|
| 159 |
self.llm = None
|
| 160 |
|
| 161 |
-
def
|
| 162 |
-
|
| 163 |
-
try:
|
| 164 |
-
model_path = hf_hub_download(
|
| 165 |
-
repo_id=GGUF_REPO_ID,
|
| 166 |
-
filename=GGUF_FILENAME,
|
| 167 |
-
repo_type="dataset"
|
| 168 |
-
)
|
| 169 |
-
|
| 170 |
-
self.llm = Llama(
|
| 171 |
-
model_path=model_path,
|
| 172 |
-
n_ctx=512,
|
| 173 |
-
n_threads=4,
|
| 174 |
-
verbose=False,
|
| 175 |
-
n_gpu_layers=0
|
| 176 |
-
)
|
| 177 |
-
|
| 178 |
-
# 測試生成
|
| 179 |
-
test_result = self.llm("SELECT", max_tokens=5)
|
| 180 |
-
self._log("✅ GGUF 模型載入成功")
|
| 181 |
-
return True
|
| 182 |
-
|
| 183 |
-
except Exception as e:
|
| 184 |
-
self._log(f"GGUF 載入失敗: {e}", "WARNING")
|
| 185 |
-
return False
|
| 186 |
|
| 187 |
-
def
|
| 188 |
-
"""使用 Transformers 載入你的微調模型"""
|
| 189 |
try:
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
self.
|
| 194 |
-
|
| 195 |
-
# 載入你的微調模型
|
| 196 |
-
self.transformers_tokenizer = AutoTokenizer.from_pretrained(FINETUNED_MODEL_PATH)
|
| 197 |
-
self.transformers_model = AutoModelForCausalLM.from_pretrained(
|
| 198 |
-
FINETUNED_MODEL_PATH,
|
| 199 |
-
torch_dtype=torch.float32, # CPU 使用 float32
|
| 200 |
-
device_map="cpu", # 強制使用 CPU
|
| 201 |
-
trust_remote_code=True # Qwen 模型可能需要
|
| 202 |
-
)
|
| 203 |
-
|
| 204 |
-
# 創建生成管道
|
| 205 |
-
self.generation_pipeline = pipeline(
|
| 206 |
-
"text-generation",
|
| 207 |
-
model=self.transformers_model,
|
| 208 |
-
tokenizer=self.transformers_tokenizer,
|
| 209 |
-
device=-1, # CPU
|
| 210 |
-
max_length=512,
|
| 211 |
-
do_sample=True,
|
| 212 |
-
temperature=0.1,
|
| 213 |
-
top_p=0.9,
|
| 214 |
-
pad_token_id=self.transformers_tokenizer.eos_token_id
|
| 215 |
-
)
|
| 216 |
-
|
| 217 |
-
self.llm = "transformers" # 標記使用 transformers
|
| 218 |
-
self._log("✅ Transformers 模型載入成功")
|
| 219 |
-
|
| 220 |
except Exception as e:
|
| 221 |
-
self._log(f"❌
|
| 222 |
self.llm = None
|
| 223 |
|
| 224 |
def huggingface_api_call(self, prompt: str) -> str:
|
| 225 |
-
|
| 226 |
-
if self.llm is None:
|
| 227 |
-
self._log("模型未載入,返回 fallback SQL。", "ERROR")
|
| 228 |
-
return self._generate_fallback_sql(prompt)
|
| 229 |
-
|
| 230 |
try:
|
| 231 |
-
output = self.llm(
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
top_p=0.9,
|
| 236 |
-
echo=False,
|
| 237 |
-
# --- 將 stop 參數加回來 ---
|
| 238 |
-
stop=["```", ";", "\n\n", "</s>"],
|
| 239 |
-
)
|
| 240 |
-
|
| 241 |
-
self._log(f"🧠 模型原始輸出 (Raw Output): {output}", "DEBUG")
|
| 242 |
-
|
| 243 |
-
if output and "choices" in output and len(output["choices"]) > 0:
|
| 244 |
-
generated_text = output["choices"][0]["text"]
|
| 245 |
-
self._log(f"📝 提取出的生成文本: {generated_text.strip()}", "DEBUG")
|
| 246 |
-
return generated_text.strip()
|
| 247 |
-
else:
|
| 248 |
-
self._log("❌ 模型的原始輸出格式不正確或為空。", "ERROR")
|
| 249 |
-
return ""
|
| 250 |
-
|
| 251 |
except Exception as e:
|
| 252 |
-
self._log(f"❌
|
| 253 |
-
import traceback
|
| 254 |
-
self._log(traceback.format_exc(), "DEBUG")
|
| 255 |
return ""
|
| 256 |
-
|
| 257 |
-
def _load_gguf_model_fallback(self, model_path):
|
| 258 |
-
"""備用載入方式"""
|
| 259 |
-
try:
|
| 260 |
-
# 嘗試不同的參數組合
|
| 261 |
-
self.llm = Llama(
|
| 262 |
-
model_path=model_path,
|
| 263 |
-
n_ctx=512, # 更小的上下文
|
| 264 |
-
n_threads=4,
|
| 265 |
-
n_batch=128,
|
| 266 |
-
vocab_only=False,
|
| 267 |
-
use_mmap=True,
|
| 268 |
-
use_mlock=False,
|
| 269 |
-
verbose=True
|
| 270 |
-
)
|
| 271 |
-
self._log("✅ 備用方式載入成功")
|
| 272 |
-
except Exception as e:
|
| 273 |
-
self._log(f"❌ 備用方式也失敗: {e}", "ERROR")
|
| 274 |
-
self.llm = None
|
| 275 |
-
|
| 276 |
-
def _log(self, message: str, level: str = "INFO"):
|
| 277 |
-
self.log_history.append(format_log(message, level))
|
| 278 |
-
print(format_log(message, level))
|
| 279 |
-
|
| 280 |
def _load_schema(self) -> Dict:
|
| 281 |
-
"""載入數據庫結構"""
|
| 282 |
try:
|
| 283 |
-
schema_path = hf_hub_download(
|
| 284 |
-
repo_id=DATASET_REPO_ID,
|
| 285 |
-
filename="sqlite_schema_FULL.json",
|
| 286 |
-
repo_type="dataset"
|
| 287 |
-
)
|
| 288 |
with open(schema_path, "r", encoding="utf-8") as f:
|
| 289 |
schema_data = json.load(f)
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
self._log(f"📊 Schema 載入成功,包含 {len(schema_data)} 個表格:")
|
| 293 |
-
for table_name, columns in schema_data.items():
|
| 294 |
-
self._log(f" - {table_name}: {len(columns)} 個欄位")
|
| 295 |
-
# 顯示前3個欄位作為範例
|
| 296 |
-
sample_cols = [col['name'] for col in columns[:3]]
|
| 297 |
-
self._log(f" 範例欄位: {', '.join(sample_cols)}")
|
| 298 |
-
|
| 299 |
-
self._log("✅ 數據庫結構載入完成")
|
| 300 |
-
return schema_data
|
| 301 |
-
|
| 302 |
except Exception as e:
|
| 303 |
self._log(f"❌ 載入 schema 失敗: {e}", "ERROR")
|
| 304 |
return {}
|
| 305 |
|
| 306 |
-
# 也可以添加一個方法來檢查生成的 SQL 是否使用了正確的表格和欄位
|
| 307 |
-
def _analyze_sql_correctness(self, sql: str) -> Dict:
|
| 308 |
-
"""分析 SQL 的正確性"""
|
| 309 |
-
analysis = {
|
| 310 |
-
'valid_tables': [],
|
| 311 |
-
'invalid_tables': [],
|
| 312 |
-
'valid_columns': [],
|
| 313 |
-
'invalid_columns': [],
|
| 314 |
-
'suggestions': []
|
| 315 |
-
}
|
| 316 |
-
|
| 317 |
-
if not self.schema:
|
| 318 |
-
return analysis
|
| 319 |
-
|
| 320 |
-
# 提取 SQL 中的表格名稱
|
| 321 |
-
table_pattern = r'FROM\s+(\w+)|JOIN\s+(\w+)'
|
| 322 |
-
table_matches = re.findall(table_pattern, sql, re.IGNORECASE)
|
| 323 |
-
used_tables = [match[0] or match[1] for match in table_matches]
|
| 324 |
-
|
| 325 |
-
# 檢查表格是否存在
|
| 326 |
-
valid_tables = list(self.schema.keys())
|
| 327 |
-
for table in used_tables:
|
| 328 |
-
if table in valid_tables:
|
| 329 |
-
analysis['valid_tables'].append(table)
|
| 330 |
-
else:
|
| 331 |
-
analysis['invalid_tables'].append(table)
|
| 332 |
-
# 尋找相似的表格名稱
|
| 333 |
-
for valid_table in valid_tables:
|
| 334 |
-
if table.lower() in valid_table.lower() or valid_table.lower() in table.lower():
|
| 335 |
-
analysis['suggestions'].append(f"{table} -> {valid_table}")
|
| 336 |
-
|
| 337 |
-
# 提取欄位名稱(簡單版本)
|
| 338 |
-
column_pattern = r'SELECT\s+(.*?)\s+FROM|WHERE\s+(\w+)\s*[=<>]|GROUP BY\s+(\w+)|ORDER BY\s+(\w+)'
|
| 339 |
-
column_matches = re.findall(column_pattern, sql, re.IGNORECASE)
|
| 340 |
-
|
| 341 |
-
return analysis
|
| 342 |
-
|
| 343 |
def _encode_texts(self, texts):
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
texts = [texts]
|
| 347 |
-
|
| 348 |
-
inputs = self.embed_tokenizer(texts, padding=True, truncation=True,
|
| 349 |
-
return_tensors="pt", max_length=512)
|
| 350 |
-
if DEVICE == "cuda":
|
| 351 |
-
inputs = {k: v.cuda() for k, v in inputs.items()}
|
| 352 |
-
|
| 353 |
with torch.no_grad():
|
| 354 |
outputs = self.embed_model(**inputs)
|
| 355 |
-
|
| 356 |
-
# 使用平均池化
|
| 357 |
embeddings = outputs.last_hidden_state.mean(dim=1)
|
| 358 |
return embeddings.cpu()
|
| 359 |
|
| 360 |
def _load_and_index_dataset(self):
|
| 361 |
-
"""載入數據集並建立 FAISS 索引"""
|
| 362 |
try:
|
| 363 |
dataset = load_dataset(DATASET_REPO_ID, data_files="training_data.jsonl", split="train")
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
try:
|
| 367 |
-
original_count = len(dataset)
|
| 368 |
-
except Exception:
|
| 369 |
-
original_count = None
|
| 370 |
-
|
| 371 |
-
dataset = dataset.filter(
|
| 372 |
-
lambda ex: isinstance(ex.get("messages"), list)
|
| 373 |
-
and len(ex["messages"]) >= 2
|
| 374 |
-
and all(
|
| 375 |
-
isinstance(m.get("content"), str) and m.get("content") and m["content"].strip()
|
| 376 |
-
for m in ex["messages"][:2]
|
| 377 |
-
)
|
| 378 |
-
)
|
| 379 |
-
|
| 380 |
-
if original_count is not None:
|
| 381 |
-
self._log(
|
| 382 |
-
f"資料集清理: 原始 {original_count} 筆, 過濾後 {len(dataset)} 筆, 移除 {original_count - len(dataset)} 筆"
|
| 383 |
-
)
|
| 384 |
-
|
| 385 |
-
if len(dataset) == 0:
|
| 386 |
-
self._log("清理後資料集為空,無法建立索引。", "ERROR")
|
| 387 |
-
return None, None
|
| 388 |
-
|
| 389 |
-
corpus = [item['messages'][0]['content'] for item in dataset]
|
| 390 |
self._log(f"正在編碼 {len(corpus)} 個問題...")
|
| 391 |
-
|
| 392 |
-
# 批量編碼
|
| 393 |
-
embeddings_list = []
|
| 394 |
-
batch_size = 32
|
| 395 |
-
|
| 396 |
-
for i in range(0, len(corpus), batch_size):
|
| 397 |
-
batch_texts = corpus[i:i+batch_size]
|
| 398 |
-
batch_embeddings = self._encode_texts(batch_texts)
|
| 399 |
-
embeddings_list.append(batch_embeddings)
|
| 400 |
-
self._log(f"已編碼 {min(i+batch_size, len(corpus))}/{len(corpus)}")
|
| 401 |
-
|
| 402 |
-
all_embeddings = torch.cat(embeddings_list, dim=0).numpy()
|
| 403 |
-
|
| 404 |
-
# 建立 FAISS 索引
|
| 405 |
index = faiss.IndexFlatIP(all_embeddings.shape[1])
|
| 406 |
index.add(all_embeddings.astype('float32'))
|
| 407 |
-
|
| 408 |
self._log("✅ 向量索引建立完成")
|
| 409 |
return dataset, index
|
| 410 |
-
|
| 411 |
except Exception as e:
|
| 412 |
self._log(f"❌ 載入數據失敗: {e}", "ERROR")
|
|
|
|
| 413 |
return None, None
|
| 414 |
-
|
| 415 |
def _identify_relevant_tables(self, question: str) -> List[str]:
|
| 416 |
-
"""根據實際 Schema 識別相關表格"""
|
| 417 |
question_lower = question.lower()
|
| 418 |
relevant_tables = []
|
| 419 |
-
|
| 420 |
-
# 根據實際表格的關鍵詞映射
|
| 421 |
-
keyword_to_table = {
|
| 422 |
-
'TSR53SampleDescription': ['客戶', '買方', '申請', '發票對象', 'customer', 'invoice', 'sample'],
|
| 423 |
-
'JobsInProgress': ['進行中', '買家', '申請方', 'buyer', 'applicant', 'progress', '工作狀態'],
|
| 424 |
-
'JobTimeline': ['時間', '完成', '創建', '實驗室', 'timeline', 'creation', 'lab'],
|
| 425 |
-
'TSR53Invoice': ['發票', '金額', '費用', 'invoice', 'credit', 'amount'],
|
| 426 |
-
'JobEventsLog': ['事件', '操作', '用戶', 'event', 'log', 'user'],
|
| 427 |
-
'calendar_days': ['工作日', '假期', 'workday', 'holiday', 'calendar']
|
| 428 |
-
}
|
| 429 |
-
|
| 430 |
for table, keywords in keyword_to_table.items():
|
| 431 |
-
if any(keyword in question_lower for keyword in keywords):
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
# 預設重要表格
|
| 435 |
-
if not relevant_tables:
|
| 436 |
-
if any(word in question_lower for word in ['客戶', '買家', '申請', '工作單', '數量']):
|
| 437 |
-
return ['TSR53SampleDescription', 'JobsInProgress']
|
| 438 |
-
else:
|
| 439 |
-
return ['JobTimeline', 'TSR53SampleDescription']
|
| 440 |
-
|
| 441 |
-
return relevant_tables[:3] # 最多返回3個相關表格
|
| 442 |
-
|
| 443 |
-
# 請將這整個函數複製到您的 TextToSQLSystem class 內部
|
| 444 |
|
| 445 |
def _format_relevant_schema(self, table_names: List[str]) -> str:
|
| 446 |
-
""
|
| 447 |
-
生成一個簡化的、不易被模型錯誤模仿的 Schema 字符串。
|
| 448 |
-
"""
|
| 449 |
-
if not self.schema:
|
| 450 |
-
return "No schema available.\n"
|
| 451 |
-
|
| 452 |
-
actual_table_names_map = {name.lower(): name for name in self.schema.keys()}
|
| 453 |
-
real_table_names = []
|
| 454 |
-
for table in table_names:
|
| 455 |
-
actual_name = actual_table_names_map.get(table.lower())
|
| 456 |
-
if actual_name:
|
| 457 |
-
real_table_names.append(actual_name)
|
| 458 |
-
elif table in self.schema:
|
| 459 |
-
real_table_names.append(table)
|
| 460 |
-
|
| 461 |
-
if not real_table_names:
|
| 462 |
-
self._log("未識別到相關表格,使用預設核心表格。", "WARNING")
|
| 463 |
-
real_table_names = ['TSR53SampleDescription', 'JobTimeline', 'JobsInProgress']
|
| 464 |
-
|
| 465 |
formatted = ""
|
| 466 |
-
for table in
|
| 467 |
if table in self.schema:
|
| 468 |
-
# 使用簡單的 "Table: ..." 和 "Columns: ..." 格式
|
| 469 |
formatted += f"Table: {table}\n"
|
| 470 |
cols_str = []
|
| 471 |
-
# 只顯示前 10 個關鍵欄位
|
| 472 |
for col in self.schema[table][:10]:
|
| 473 |
-
col_name = col['name']
|
| 474 |
-
col_type
|
| 475 |
-
|
| 476 |
-
# 將描述信息放在括號裡
|
| 477 |
-
if col_desc:
|
| 478 |
-
cols_str.append(f"{col_name} ({col_type}, {col_desc})")
|
| 479 |
-
else:
|
| 480 |
-
cols_str.append(f"{col_name} ({col_type})")
|
| 481 |
formatted += f"Columns: {', '.join(cols_str)}\n\n"
|
| 482 |
-
|
| 483 |
return formatted.strip()
|
| 484 |
|
| 485 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 486 |
|
| 487 |
def _validate_and_fix_sql(self, question: str, raw_response: str) -> Tuple[Optional[str], str]:
|
| 488 |
-
"""
|
| 489 |
-
(V23 / 统一实体识别版)
|
| 490 |
-
一個全面、多層次的 SQL 驗證與生成引擎。
|
| 491 |
-
引入了全新的、统一的实体识别引擎,能够准确解析 "买家 Gap", "c0761n",
|
| 492 |
-
"买家ID c0761n" 等多种复杂的实体提问模式。
|
| 493 |
-
"""
|
| 494 |
q_lower = question.lower()
|
| 495 |
-
|
| 496 |
-
# ==============================================================================
|
| 497 |
-
# 第一層:高價值意圖識別與模板覆寫 (Intent Recognition & Templating)
|
| 498 |
-
# ==============================================================================
|
| 499 |
-
|
| 500 |
-
# --- **全新的统一实体识别引擎** ---
|
| 501 |
entity_match_data = None
|
| 502 |
-
|
| 503 |
-
# 定义多种识别模式,【优先级从高到低】
|
| 504 |
entity_patterns = [
|
| 505 |
-
# 模式1: 匹配 "类型 + ID" (e.g., "买家ID C0761N") - 最高优先级
|
| 506 |
{'pattern': r"(买家|buyer)\s*(?:id|代號|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.BuyerID', 'type': '买家ID'},
|
| 507 |
{'pattern': r"(申请方|申请厂商|applicant)\s*(?:id|代號|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.ApplicantID', 'type': '申请方ID'},
|
| 508 |
{'pattern': r"(付款方|付款厂商|invoiceto)\s*(?:id|代號|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.InvoiceToID', 'type': '付款方ID'},
|
| 509 |
{'pattern': r"(代理商|agent)\s*(?:id|代號|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.AgentID', 'type': '代理商ID'},
|
| 510 |
-
|
| 511 |
-
# 模式2: 匹配 "类型 + 名称" (e.g., "买家 Gap")
|
| 512 |
{'pattern': r"(买家|buyer|客戶)\s*'\"?([a-zA-Z0-9&.\s-]+?)(?:\s*的|\s+|$|有)", 'column': 'sd.BuyerName', 'type': '买家'},
|
| 513 |
{'pattern': r"(申请方|申请厂商|applicant)\s*'\"?([a-zA-Z0-9&.\s-]+?)(?:\s*的|\s+|$|有)", 'column': 'sd.ApplicantName', 'type': '申请方'},
|
| 514 |
{'pattern': r"(付款方|付款厂商|invoiceto)\s*'\"?([a-zA-Z0-9&.\s-]+?)(?:\s*的|\s+|$|有)", 'column': 'sd.InvoiceToName', 'type': '付款方'},
|
| 515 |
{'pattern': r"(代理商|agent)\s*'\"?([a-zA-Z0-9&.\s-]+?)(?:\s*的|\s+|$|有)", 'column': 'sd.AgentName', 'type': '代理商'},
|
| 516 |
-
|
| 517 |
-
# 模式3: 单独匹配一个 ID (e.g., "c0761n") - 较低优先级
|
| 518 |
{'pattern': r"\b([A-Z]\d{4}[A-Z])\b", 'column': 'sd.ApplicantID', 'type': 'ID'}
|
| 519 |
]
|
| 520 |
-
|
| 521 |
for p in entity_patterns:
|
| 522 |
match = re.search(p['pattern'], question, re.IGNORECASE)
|
| 523 |
if match:
|
| 524 |
entity_value = match.group(2) if len(match.groups()) > 1 else match.group(1)
|
| 525 |
-
entity_match_data = {
|
| 526 |
-
"type": p['type'],
|
| 527 |
-
"name": entity_value.strip().upper(),
|
| 528 |
-
"column": p['column']
|
| 529 |
-
}
|
| 530 |
break
|
| 531 |
|
| 532 |
-
# --- 预先检测其他意图 ---
|
| 533 |
-
job_no_match = re.search(r"(?:工單|jobno)\s*'\"?([A-Z]{2,3}\d+)'\"?", question, re.IGNORECASE)
|
| 534 |
-
|
| 535 |
-
# --- 判断逻辑: 依优先级进入对应的模板 ---
|
| 536 |
if any(kw in q_lower for kw in ['報告號碼', '報告清單', '列出報告', 'report number', 'list of reports']):
|
| 537 |
year_match = re.search(r'(\d{4})\s*年?', question)
|
| 538 |
month_match = re.search(r'(\d{1,2})\s*月', question)
|
|
@@ -540,259 +260,71 @@ class TextToSQLSystem:
|
|
| 540 |
select_clause = "SELECT jt.JobNo, jt.ReportAuthorization"
|
| 541 |
where_conditions = ["jt.ReportAuthorization IS NOT NULL"]
|
| 542 |
log_parts = []
|
| 543 |
-
|
| 544 |
-
if
|
| 545 |
-
if month_match: month = month_match.group(1).zfill(2); where_conditions.append(f"strftime('%m', jt.ReportAuthorization) = '{month}'"); log_parts.append(f"{month}月")
|
| 546 |
-
|
| 547 |
if 'fail' in q_lower or '失敗' in q_lower:
|
| 548 |
-
if "JOIN TSR53SampleDescription" not in from_clause: from_clause
|
| 549 |
where_conditions.append("sd.OverallRating = 'Fail'"); log_parts.append("Fail")
|
| 550 |
elif 'pass' in q_lower or '通過' in q_lower:
|
| 551 |
-
if "JOIN TSR53SampleDescription" not in from_clause: from_clause
|
| 552 |
where_conditions.append("sd.OverallRating = 'Pass'"); log_parts.append("Pass")
|
| 553 |
-
|
| 554 |
if entity_match_data:
|
| 555 |
entity_name, column_name = entity_match_data["name"], entity_match_data["column"]
|
| 556 |
-
if "JOIN TSR53SampleDescription" not in from_clause: from_clause
|
| 557 |
match_operator = "=" if column_name.endswith("ID") else "LIKE"
|
| 558 |
entity_value = f"'{entity_name}'" if match_operator == "=" else f"'%{entity_name}%'"
|
| 559 |
where_conditions.append(f"{column_name} {match_operator} {entity_value}")
|
| 560 |
log_parts.append(entity_name)
|
| 561 |
select_clause = "SELECT jt.JobNo, sd.BuyerName, jt.ReportAuthorization"
|
| 562 |
-
|
| 563 |
-
final_where_clause = "WHERE " + " AND ".join(where_conditions)
|
| 564 |
time_log = " ".join(log_parts) if log_parts else "全部"
|
| 565 |
self._log(f"🔄 檢測到查詢【{time_log} 報告列表】意圖,啟用智能模板。", "INFO")
|
| 566 |
template_sql = f"{select_clause} {from_clause} {final_where_clause} ORDER BY jt.ReportAuthorization DESC;"
|
| 567 |
return self._finalize_sql(template_sql, f"模板覆寫: {time_log} 報告列表查詢")
|
| 568 |
|
| 569 |
-
|
| 570 |
-
elif '報告' in q_lower and any(kw in q_lower for kw in ['幾份', '多少', '數量', '總數']) and not entity_match_data:
|
| 571 |
year_match = re.search(r'(\d{4})\s*年?', question)
|
| 572 |
time_condition, time_log = "", "總"
|
| 573 |
if year_match:
|
| 574 |
-
|
| 575 |
-
|
| 576 |
-
time_log = f"{year}年"
|
| 577 |
else:
|
| 578 |
time_condition = "WHERE ReportAuthorization IS NOT NULL"
|
| 579 |
self._log(f"🔄 檢測到查詢【{time_log}全局報告總數】意圖,啟用模板。", "INFO")
|
| 580 |
template_sql = f"SELECT COUNT(DISTINCT JobNo) AS report_count FROM JobTimeline {time_condition};"
|
| 581 |
return self._finalize_sql(template_sql, f"模板覆寫: {time_log}全局報告總數查詢")
|
| 582 |
|
| 583 |
-
# ==============================================================================
|
| 584 |
-
# 第二层:常规修正流程 (Fallback Corrections)
|
| 585 |
-
# ==============================================================================
|
| 586 |
self._log("未觸發任何模板,嘗試解析並修正 AI 輸出...", "INFO")
|
| 587 |
-
|
| 588 |
parsed_sql = parse_sql_from_response(raw_response)
|
| 589 |
if not parsed_sql:
|
| 590 |
-
self._log(f"❌ 未能從模型回應中解析出任何 SQL。原始回應: {raw_response}", "ERROR")
|
| 591 |
return None, f"無法解析SQL。原始回應:\n{raw_response}"
|
| 592 |
-
|
| 593 |
-
self._log(f"📊 解析出的原始 SQL: {parsed_sql}", "DEBUG")
|
| 594 |
-
|
| 595 |
fixed_sql = " " + parsed_sql.strip() + " "
|
| 596 |
fixes_applied_fallback = []
|
| 597 |
-
|
| 598 |
dialect_corrections = {r'YEAR\s*\(([^)]+)\)': r"strftime('%Y', \1)"}
|
| 599 |
-
for
|
| 600 |
-
if re.search(
|
| 601 |
-
fixed_sql = re.sub(
|
| 602 |
-
|
| 603 |
-
|
| 604 |
-
|
| 605 |
-
for wrong, correct in schema_corrections.items():
|
| 606 |
-
pattern = r'\b' + re.escape(wrong) + r'\b'
|
| 607 |
if re.search(pattern, fixed_sql, re.IGNORECASE):
|
| 608 |
-
fixed_sql = re.sub(pattern,
|
| 609 |
-
fixes_applied_fallback.append(f"映射 Schema: '{wrong}' -> '{correct}'")
|
| 610 |
-
|
| 611 |
log_msg = "AI 生成並成功修正" if fixes_applied_fallback else "AI 生成且無需修正"
|
| 612 |
return self._finalize_sql(fixed_sql, log_msg)
|
| 613 |
|
| 614 |
-
def _finalize_sql(self, sql: str, log_message: str) -> Tuple[str, str]:
|
| 615 |
-
"""一個輔助函數,用於清理最終的SQL並記錄成功日誌。"""
|
| 616 |
-
final_sql = sql.strip()
|
| 617 |
-
if not final_sql.endswith(';'):
|
| 618 |
-
final_sql += ';'
|
| 619 |
-
final_sql = re.sub(r'\s+', ' ', final_sql).strip()
|
| 620 |
-
self._log(f"✅ SQL 已生成 ({log_message})", "INFO")
|
| 621 |
-
self._log(f" - 最終 SQL: {final_sql}", "DEBUG")
|
| 622 |
-
return final_sql, "生成成功"
|
| 623 |
-
|
| 624 |
-
def find_most_similar(self, question: str, top_k: int) -> List[Dict]:
|
| 625 |
-
"""使用 FAISS 快速檢索相似問題"""
|
| 626 |
-
if self.faiss_index is None or self.dataset is None:
|
| 627 |
-
return []
|
| 628 |
-
|
| 629 |
-
try:
|
| 630 |
-
# 編碼問題
|
| 631 |
-
q_embedding = self._encode_texts([question]).numpy().astype('float32')
|
| 632 |
-
|
| 633 |
-
# FAISS 搜索
|
| 634 |
-
distances, indices = self.faiss_index.search(q_embedding, min(top_k + 2, len(self.dataset)))
|
| 635 |
-
|
| 636 |
-
results = []
|
| 637 |
-
seen_questions = set()
|
| 638 |
-
|
| 639 |
-
for i, idx in enumerate(indices[0]):
|
| 640 |
-
if len(results) >= top_k:
|
| 641 |
-
break
|
| 642 |
-
|
| 643 |
-
# 修復:將 numpy.int64 轉換為 Python int
|
| 644 |
-
idx = int(idx) # ← 添加這行轉換
|
| 645 |
-
|
| 646 |
-
if idx >= len(self.dataset): # 確保索引有效
|
| 647 |
-
continue
|
| 648 |
-
|
| 649 |
-
item = self.dataset[idx]
|
| 650 |
-
# 防呆:若樣本不完整則跳過
|
| 651 |
-
if not isinstance(item.get('messages'), list) or len(item['messages']) < 2:
|
| 652 |
-
continue
|
| 653 |
-
q_content = (item['messages'][0].get('content') or '').strip()
|
| 654 |
-
a_content = (item['messages'][1].get('content') or '').strip()
|
| 655 |
-
if not q_content or not a_content:
|
| 656 |
-
continue
|
| 657 |
-
|
| 658 |
-
# 提取純淨問題
|
| 659 |
-
clean_q = re.sub(r"以下是一個SQL查詢任務:\s*指令:\s*", "", q_content).strip()
|
| 660 |
-
if clean_q in seen_questions:
|
| 661 |
-
continue
|
| 662 |
-
|
| 663 |
-
seen_questions.add(clean_q)
|
| 664 |
-
sql = parse_sql_from_response(a_content) or "無法解析範例SQL"
|
| 665 |
-
|
| 666 |
-
results.append({
|
| 667 |
-
"similarity": float(distances[0][i]),
|
| 668 |
-
"question": clean_q,
|
| 669 |
-
"sql": sql
|
| 670 |
-
})
|
| 671 |
-
|
| 672 |
-
return results
|
| 673 |
-
|
| 674 |
-
except Exception as e:
|
| 675 |
-
self._log(f"❌ 檢索失敗: {e}", "ERROR")
|
| 676 |
-
return []
|
| 677 |
-
|
| 678 |
-
# in class TextToSQLSystem:
|
| 679 |
-
|
| 680 |
-
def _build_prompt(self, user_q: str, examples: List[Dict]) -> str:
|
| 681 |
-
"""
|
| 682 |
-
建立一個高度結構化、以任務為導向的提示詞,使用清晰的標題分隔符。
|
| 683 |
-
"""
|
| 684 |
-
relevant_tables = self._identify_relevant_tables(user_q)
|
| 685 |
-
|
| 686 |
-
# 使用我們新的、更簡單的 schema 格式化函數
|
| 687 |
-
schema_str = self._format_relevant_schema(relevant_tables)
|
| 688 |
-
|
| 689 |
-
example_str = "No example available."
|
| 690 |
-
if examples:
|
| 691 |
-
best_example = examples[0]
|
| 692 |
-
example_str = f"Question: {best_example['question']}\nSQL:\n```sql\n{best_example['sql']}\n```"
|
| 693 |
-
|
| 694 |
-
# 使用強分隔符和清晰的標題來構建 prompt
|
| 695 |
-
prompt = f"""### INSTRUCTIONS ###
|
| 696 |
-
You are a SQLite expert. Your only job is to generate a single, valid SQLite query based on the provided schema and question.
|
| 697 |
-
- ONLY use the tables and columns from the schema below.
|
| 698 |
-
- ALWAYS use SQLite syntax (e.g., `strftime('%Y', date_column)` for years).
|
| 699 |
-
- The report completion date is the `ReportAuthorization` column in the `JobTimeline` table.
|
| 700 |
-
- Your output MUST be ONLY the SQL query inside a ```sql code block.
|
| 701 |
-
|
| 702 |
-
### SCHEMA ###
|
| 703 |
-
{schema_str}
|
| 704 |
-
|
| 705 |
-
### EXAMPLE ###
|
| 706 |
-
{example_str}
|
| 707 |
-
|
| 708 |
-
### TASK ###
|
| 709 |
-
Generate a SQLite query for the following question.
|
| 710 |
-
Question: {user_q}
|
| 711 |
-
SQL:
|
| 712 |
-
```sql
|
| 713 |
-
"""
|
| 714 |
-
self._log(f"📏 Prompt 長度: {len(prompt)} 字符")
|
| 715 |
-
# 不再需要複雜的長度截斷邏輯,因為 schema 已經被簡化
|
| 716 |
-
return prompt
|
| 717 |
-
|
| 718 |
-
|
| 719 |
-
def _generate_fallback_sql(self, prompt: str) -> str:
|
| 720 |
-
"""當模型不可用時的備用 SQL 生成"""
|
| 721 |
-
prompt_lower = prompt.lower()
|
| 722 |
-
|
| 723 |
-
# 簡單的關鍵詞匹配生成基本 SQL
|
| 724 |
-
if "統計" in prompt or "數量" in prompt or "多少" in prompt:
|
| 725 |
-
if "月" in prompt:
|
| 726 |
-
return "SELECT strftime('%Y-%m', completed_time) as month, COUNT(*) as count FROM jobtimeline GROUP BY month ORDER BY month;"
|
| 727 |
-
elif "客戶" in prompt:
|
| 728 |
-
return "SELECT applicant, COUNT(*) as count FROM tsr53sampledescription GROUP BY applicant ORDER BY count DESC;"
|
| 729 |
-
else:
|
| 730 |
-
return "SELECT COUNT(*) as total_count FROM jobtimeline WHERE completed_time IS NOT NULL;"
|
| 731 |
-
|
| 732 |
-
elif "金額" in prompt or "總額" in prompt:
|
| 733 |
-
return "SELECT SUM(amount) as total_amount FROM tsr53invoice;"
|
| 734 |
-
|
| 735 |
-
elif "評級" in prompt or "pass" in prompt_lower or "fail" in prompt_lower:
|
| 736 |
-
return "SELECT rating, COUNT(*) as count FROM tsr53sampledescription GROUP BY rating;"
|
| 737 |
-
|
| 738 |
-
else:
|
| 739 |
-
return "SELECT * FROM jobtimeline LIMIT 10;"
|
| 740 |
-
|
| 741 |
-
def _validate_model_file(self, model_path):
|
| 742 |
-
"""驗證模型檔案完整性"""
|
| 743 |
-
try:
|
| 744 |
-
if not os.path.exists(model_path):
|
| 745 |
-
return False
|
| 746 |
-
|
| 747 |
-
# 檢查檔案大小(至少應該有幾MB)
|
| 748 |
-
file_size = os.path.getsize(model_path)
|
| 749 |
-
if file_size < 10 * 1024 * 1024: # 小於 10MB 可能有問題
|
| 750 |
-
return False
|
| 751 |
-
|
| 752 |
-
# 檢查 GGUF 檔案頭部
|
| 753 |
-
with open(model_path, 'rb') as f:
|
| 754 |
-
header = f.read(8)
|
| 755 |
-
if not header.startswith(b'GGUF'):
|
| 756 |
-
return False
|
| 757 |
-
|
| 758 |
-
return True
|
| 759 |
-
except Exception:
|
| 760 |
-
return False
|
| 761 |
-
|
| 762 |
-
# in class TextToSQLSystem:
|
| 763 |
-
|
| 764 |
def process_question(self, question: str) -> Tuple[str, str]:
|
| 765 |
-
|
| 766 |
-
# 檢查緩存
|
| 767 |
-
if question in self.query_cache:
|
| 768 |
-
self._log("⚡ 使用緩存結果")
|
| 769 |
-
return self.query_cache[question]
|
| 770 |
-
|
| 771 |
self.log_history = []
|
| 772 |
self._log(f"⏰ 處理問題: {question}")
|
| 773 |
-
|
| 774 |
-
# 1. 檢索相似範例
|
| 775 |
-
self._log("🔍 尋找相似範例...")
|
| 776 |
examples = self.find_most_similar(question, FEW_SHOT_EXAMPLES_COUNT)
|
| 777 |
if examples: self._log(f"✅ 找到 {len(examples)} 個相似範例")
|
| 778 |
-
|
| 779 |
-
# 2. 建立提示詞
|
| 780 |
-
self._log("📝 建立 Prompt...")
|
| 781 |
prompt = self._build_prompt(question, examples)
|
| 782 |
-
|
| 783 |
-
# 3. 生成 AI 回應
|
| 784 |
self._log("🧠 開始生成 AI 回應...")
|
| 785 |
response = self.huggingface_api_call(prompt)
|
| 786 |
-
|
| 787 |
-
# 4. **新的核心步驟**: 呼叫決策引擎來生成最終 SQL
|
| 788 |
final_sql, status_message = self._validate_and_fix_sql(question, response)
|
| 789 |
-
|
| 790 |
-
|
| 791 |
-
result = (final_sql, status_message)
|
| 792 |
-
else:
|
| 793 |
-
result = (status_message, "生成失敗")
|
| 794 |
-
|
| 795 |
-
# 緩存結果
|
| 796 |
self.query_cache[question] = result
|
| 797 |
return result
|
| 798 |
|
|
@@ -800,53 +332,36 @@ SQL:
|
|
| 800 |
text_to_sql_system = TextToSQLSystem()
|
| 801 |
|
| 802 |
def process_query(q: str):
|
| 803 |
-
if not q.strip():
|
| 804 |
-
|
| 805 |
-
|
| 806 |
sql, status = text_to_sql_system.process_question(q)
|
| 807 |
-
logs = "\n".join(text_to_sql_system.log_history[-
|
| 808 |
-
|
| 809 |
return sql, status, logs
|
| 810 |
|
| 811 |
-
# 範例問題
|
| 812 |
examples = [
|
| 813 |
-
"2024
|
| 814 |
-
"
|
| 815 |
-
"
|
| 816 |
-
"
|
| 817 |
-
"A
|
|
|
|
| 818 |
]
|
| 819 |
-
|
| 820 |
with gr.Blocks(theme=gr.themes.Soft(), title="Text-to-SQL 智能助手") as demo:
|
| 821 |
-
gr.Markdown("# ⚡ Text-to-SQL 智能助手")
|
| 822 |
-
gr.Markdown("
|
| 823 |
-
|
| 824 |
with gr.Row():
|
| 825 |
with gr.Column(scale=2):
|
| 826 |
inp = gr.Textbox(lines=3, label="💬 您的問題", placeholder="例如:2024年每月完成多少份報告?")
|
| 827 |
btn = gr.Button("🚀 生成 SQL", variant="primary")
|
| 828 |
status = gr.Textbox(label="狀態", interactive=False)
|
| 829 |
-
|
| 830 |
with gr.Column(scale=3):
|
| 831 |
sql_out = gr.Code(label="🤖 生成的 SQL", language="sql", lines=8)
|
| 832 |
-
|
| 833 |
with gr.Accordion("📋 處理日誌", open=False):
|
| 834 |
-
logs = gr.Textbox(lines=
|
| 835 |
-
|
| 836 |
-
# 範例區
|
| 837 |
-
gr.Examples(
|
| 838 |
-
examples=examples,
|
| 839 |
-
inputs=inp,
|
| 840 |
-
label="💡 點擊試用範例問題"
|
| 841 |
-
)
|
| 842 |
-
|
| 843 |
-
# 綁定事件
|
| 844 |
btn.click(process_query, inputs=[inp], outputs=[sql_out, status, logs])
|
| 845 |
inp.submit(process_query, inputs=[inp], outputs=[sql_out, status, logs])
|
| 846 |
|
| 847 |
if __name__ == "__main__":
|
| 848 |
-
demo.launch(
|
| 849 |
-
server_name="0.0.0.0",
|
| 850 |
-
server_port=7860,
|
| 851 |
-
share=False
|
| 852 |
-
)
|
|
|
|
| 1 |
+
# ==============================================================================
|
| 2 |
+
# Text-to-SQL 智能助手 - Hugging Face CPU 最终版 v6
|
| 3 |
+
# (融合模板引擎 + 强化 Prompt + 修复所有 Bug)
|
| 4 |
+
# ==============================================================================
|
| 5 |
import gradio as gr
|
| 6 |
import os
|
| 7 |
import re
|
| 8 |
import json
|
| 9 |
import torch
|
| 10 |
import numpy as np
|
| 11 |
+
import gc
|
| 12 |
+
import tempfile
|
| 13 |
from datetime import datetime
|
| 14 |
from datasets import load_dataset
|
| 15 |
from huggingface_hub import hf_hub_download
|
| 16 |
from llama_cpp import Llama
|
| 17 |
from typing import List, Dict, Tuple, Optional
|
| 18 |
import faiss
|
| 19 |
+
import traceback
|
| 20 |
|
|
|
|
| 21 |
from transformers import AutoModel, AutoTokenizer
|
| 22 |
import torch.nn.functional as F
|
| 23 |
|
| 24 |
+
# ==================== 配置參數 ====================
|
| 25 |
+
# --- Hugging Face CPU 部署配置 ---
|
|
|
|
|
|
|
| 26 |
GGUF_FILENAME = "qwen2.5-coder-1.5b-sql-finetuned.q8_0.gguf"
|
| 27 |
+
N_GPU_LAYERS = 0 # 在 Hugging Face CPU 环境下设置为 0
|
| 28 |
|
| 29 |
+
DATASET_REPO_ID = "Paul720810/Text-to-SQL-Softline"
|
| 30 |
+
GGUF_REPO_ID = "Paul720810/gguf-models"
|
|
|
|
| 31 |
FEW_SHOT_EXAMPLES_COUNT = 1
|
| 32 |
+
DEVICE = "cuda" if torch.cuda.is_available() and N_GPU_LAYERS != 0 else "cpu"
|
| 33 |
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
|
| 34 |
|
| 35 |
+
TEMP_DIR = tempfile.gettempdir()
|
| 36 |
+
os.makedirs(os.path.join(TEMP_DIR, 'text_to_sql_cache'), exist_ok=True)
|
| 37 |
+
|
| 38 |
print("=" * 60)
|
| 39 |
+
print("🤖 Text-to-SQL 智能助手 v6.0 (Hugging Face CPU 版)...")
|
| 40 |
+
print(f"🚀 模型: {GGUF_FILENAME}")
|
| 41 |
+
print(f"💻 設備: {DEVICE} (GPU Layers: {N_GPU_LAYERS})")
|
|
|
|
| 42 |
print("=" * 60)
|
| 43 |
|
| 44 |
# ==================== 工具函數 ====================
|
|
|
|
| 46 |
return datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
| 47 |
|
| 48 |
def format_log(message: str, level: str = "INFO") -> str:
    """Build a timestamped log line, echo it to stdout, and return it."""
    entry = f"[{get_current_time()}] [{level.upper()}] {message}"
    print(entry)
    return entry
|
| 52 |
|
| 53 |
def parse_sql_from_response(response_text: str) -> Optional[str]:
    """Extract a single SQL statement from a model response.

    Tries progressively looser formats: a ```sql fence, any fence whose body
    starts with SELECT, a bare semicolon-terminated SELECT, a SELECT cut off at
    a natural break, and finally any line starting with SELECT. Returns None
    when nothing SQL-like is found.
    """
    if not response_text:
        return None
    text = response_text.strip()

    # 1. Explicit ```sql fenced block.
    fenced = re.search(r"```sql\s*\n(.*?)\n```", text, re.DOTALL | re.IGNORECASE)
    if fenced:
        return fenced.group(1).strip()

    # 2. Any fenced block whose content begins with SELECT.
    fenced = re.search(r"```\s*\n?(.*?)\n?```", text, re.DOTALL)
    if fenced:
        candidate = fenced.group(1).strip()
        if candidate.upper().startswith('SELECT'):
            return candidate

    # 3. Bare SELECT terminated by a semicolon.
    bare = re.search(r"(SELECT\s+.*?;)", text, re.DOTALL | re.IGNORECASE)
    if bare:
        return bare.group(1).strip()

    # 4. Bare SELECT without semicolon, stopped at a natural break.
    bare = re.search(r"(SELECT\s+.*?)(?=\n\n|\n```|$|\n[^,\s])", text, re.DOTALL | re.IGNORECASE)
    if bare:
        sql = bare.group(1).strip()
        return sql if sql.endswith(';') else sql + ';'

    # 5. Last resort: first line that starts with SELECT.
    if 'SELECT' in text.upper():
        for raw_line in text.split('\n'):
            candidate = raw_line.strip()
            if candidate.upper().startswith('SELECT'):
                return candidate if candidate.endswith(';') else candidate + ';'
    return None
|
| 76 |
|
| 77 |
# ==================== Text-to-SQL 核心類 ====================
|
|
|
|
| 80 |
self.log_history = []
|
| 81 |
self._log("初始化系統...")
|
| 82 |
self.query_cache = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
try:
|
| 84 |
+
self._log(f"載入嵌入模型: {embed_model_name}")
|
| 85 |
+
self.embed_tokenizer = AutoTokenizer.from_pretrained(embed_model_name)
|
| 86 |
+
self.embed_model = AutoModel.from_pretrained(embed_model_name)
|
| 87 |
+
if DEVICE == "cuda":
|
| 88 |
+
self.embed_model.to(DEVICE)
|
| 89 |
+
|
| 90 |
+
self.schema = self._load_schema()
|
| 91 |
+
self.dataset, self.faiss_index = self._load_and_index_dataset()
|
| 92 |
+
self._load_gguf_model()
|
| 93 |
+
self._log("✅ 系統初始化完成")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
except Exception as e:
|
| 95 |
+
self._log(f"❌ 系統初始化過程中發生嚴重錯誤: {e}", "CRITICAL")
|
| 96 |
+
self._log(traceback.format_exc(), "DEBUG")
|
| 97 |
self.llm = None
|
| 98 |
|
| 99 |
+
def _log(self, message: str, level: str = "INFO"):
    """Record *message* via format_log (which also prints it) into log_history."""
    entry = format_log(message, level)
    self.log_history.append(entry)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
|
| 102 |
+
def _load_gguf_model(self):
    """Download the GGUF checkpoint and initialise llama.cpp.

    On success sets ``self.llm`` to a ready ``Llama`` instance; on any failure
    logs the error and sets ``self.llm`` to None so callers can degrade.
    """
    try:
        model_path = hf_hub_download(
            repo_id=GGUF_REPO_ID,
            filename=GGUF_FILENAME,
            repo_type="dataset",
            cache_dir=TEMP_DIR,
        )
        self._log(f"模型路徑: {model_path}")
        self._log(f"載入 GGUF 模型 (GPU Layers: {N_GPU_LAYERS})...")
        # n_ctx/n_threads/n_batch are tuned for the Hugging Face CPU Space.
        self.llm = Llama(
            model_path=model_path,
            n_ctx=2048,
            n_threads=4,
            n_batch=512,
            verbose=False,
            n_gpu_layers=N_GPU_LAYERS,
        )
        self._log("✅ GGUF 模型成功載入")
    except Exception as e:
        self._log(f"❌ GGUF 載入失敗: {e}", "CRITICAL")
        self.llm = None
|
| 112 |
|
| 113 |
def huggingface_api_call(self, prompt: str) -> str:
    """Run the local GGUF model on *prompt* and return the stripped completion.

    Returns an empty string when the model is unavailable or generation fails.
    """
    if self.llm is None:
        return ""
    try:
        output = self.llm(
            prompt,
            max_tokens=150,
            temperature=0.1,
            top_p=0.9,
            echo=False,
            stop=["```", ";", "\n\n", "</s>", "###", "Q:"],
            repeat_penalty=1.1,
        )
        text = ""
        if output and "choices" in output and len(output["choices"]) > 0:
            text = output["choices"][0]["text"]
        text = text.strip()
        self._log(f"🧠 模型原始輸出: {text}", "DEBUG")
        return text
    except Exception as e:
        self._log(f"❌ 模型生成錯誤: {e}", "CRITICAL")
        return ""
|
| 123 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
def _load_schema(self) -> Dict:
    """Download and parse the SQLite schema JSON from the dataset repo.

    Returns an empty dict on any failure so the rest of the pipeline can
    still run (with "No schema available" prompts).
    """
    try:
        schema_path = hf_hub_download(
            repo_id=DATASET_REPO_ID,
            filename="sqlite_schema_FULL.json",
            repo_type="dataset",
        )
        with open(schema_path, "r", encoding="utf-8") as fh:
            schema_data = json.load(fh)
        self._log(f"📊 Schema 載入成功,包含 {len(schema_data)} 個表格。")
        return schema_data
    except Exception as e:
        self._log(f"❌ 載入 schema 失敗: {e}", "ERROR")
        return {}
|
| 134 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
def _encode_texts(self, texts):
    """Embed a string or list of strings with the sentence-transformer.

    Returns a CPU tensor of shape (len(texts), hidden_dim).
    NOTE(review): pooling is a plain mean over all token positions and
    ignores the attention mask, so padding tokens dilute the embedding for
    mixed-length batches — confirm this matches how the index was built.
    """
    batch = [texts] if isinstance(texts, str) else texts
    encoded = self.embed_tokenizer(
        batch, padding=True, truncation=True, return_tensors="pt", max_length=512
    ).to(DEVICE)
    with torch.no_grad():
        model_out = self.embed_model(**encoded)
    # Mean-pool over the sequence dimension to get one vector per text.
    pooled = model_out.last_hidden_state.mean(dim=1)
    return pooled.cpu()
|
| 142 |
|
| 143 |
def _load_and_index_dataset(self):
    """Load the few-shot training set and build a FAISS inner-product index.

    Each dataset row carries a ``messages`` list of chat turns (user question
    first, assistant answer second); rows without at least two turns are
    filtered out. Returns ``(dataset, index)``, or ``(None, None)`` when
    loading or indexing fails.
    """
    try:
        dataset = load_dataset(DATASET_REPO_ID, data_files="training_data.jsonl", split="train")
        # Keep only well-formed rows so indexing never sees a short list.
        dataset = dataset.filter(lambda ex: isinstance(ex.get("messages"), list) and len(ex["messages"]) >= 2)
        # BUG FIX: 'messages' is a list of chat turns; the user question is the
        # first turn. Indexing the list with a string key raised TypeError.
        corpus = [item['messages'][0]['content'] for item in dataset]
        self._log(f"正在編碼 {len(corpus)} 個問題...")
        # Encode in batches of 32 to bound peak memory on the CPU Space.
        all_embeddings = torch.cat(
            [self._encode_texts(corpus[i:i + 32]) for i in range(0, len(corpus), 32)], dim=0
        ).numpy()
        index = faiss.IndexFlatIP(all_embeddings.shape[1])
        index.add(all_embeddings.astype('float32'))
        self._log("✅ 向量索引建立完成")
        return dataset, index
    except Exception as e:
        self._log(f"❌ 載入數據失敗: {e}", "ERROR")
        self._log(traceback.format_exc(), "DEBUG")
        return None, None
|
| 158 |
+
|
| 159 |
def _identify_relevant_tables(self, question: str) -> List[str]:
|
|
|
|
| 160 |
question_lower = question.lower()
|
| 161 |
relevant_tables = []
|
| 162 |
+
keyword_to_table = {'TSR53SampleDescription': ['客戶', '買方', '申請', '發票對象'], 'JobsInProgress': ['進行中', '買家', '申請方'], 'JobTimeline': ['時間', '完成', '創建', '實驗室'], 'TSR53Invoice': ['發票', '金額', '費用']}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
for table, keywords in keyword_to_table.items():
|
| 164 |
+
if any(keyword in question_lower for keyword in keywords): relevant_tables.append(table)
|
| 165 |
+
if not relevant_tables: return ['TSR53SampleDescription', 'JobsInProgress', 'JobTimeline']
|
| 166 |
+
return relevant_tables[:3]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
|
| 168 |
def _format_relevant_schema(self, table_names: List[str]) -> str:
|
| 169 |
+
if not self.schema: return "No schema available.\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
formatted = ""
|
| 171 |
+
for table in table_names:
|
| 172 |
if table in self.schema:
|
|
|
|
| 173 |
formatted += f"Table: {table}\n"
|
| 174 |
cols_str = []
|
|
|
|
| 175 |
for col in self.schema[table][:10]:
|
| 176 |
+
col_name, col_type, col_desc = col['name'], col['type'], col.get('description', '').replace('\n', ' ')
|
| 177 |
+
if col_desc: cols_str.append(f"{col_name} ({col_type}, {col_desc})")
|
| 178 |
+
else: cols_str.append(f"{col_name} ({col_type})")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
formatted += f"Columns: {', '.join(cols_str)}\n\n"
|
|
|
|
| 180 |
return formatted.strip()
|
| 181 |
|
| 182 |
+
def find_most_similar(self, question: str, top_k: int) -> List[Dict]:
    """Retrieve up to *top_k* distinct, well-formed examples similar to *question*.

    Searches the FAISS index (over-fetching by 2 to survive skips and
    duplicates) and returns dicts with 'similarity', the cleaned 'question'
    and the parsed example 'sql'. Returns [] when no index is loaded or on
    any retrieval error.
    """
    if self.faiss_index is None:
        return []
    try:
        q_embedding = self._encode_texts([question]).numpy().astype('float32')
        distances, indices = self.faiss_index.search(q_embedding, min(top_k + 2, len(self.dataset)))
        results, seen_questions = [], set()
        for i, idx in enumerate(indices[0]):
            if len(results) >= top_k:
                break
            idx = int(idx)
            if idx >= len(self.dataset):
                continue
            item = self.dataset[idx]
            if not (isinstance(item.get('messages'), list) and len(item['messages']) >= 2):
                continue
            # BUG FIX: 'messages' is a list of chat turns — the user question
            # is turn 0 and the assistant answer turn 1. The previous code
            # indexed the list with a string key (TypeError) and called .get
            # on a list (AttributeError).
            q_content = (item['messages'][0].get('content') or '').strip()
            a_content = (item['messages'][1].get('content') or '').strip()
            if not q_content or not a_content:
                continue
            # Strip the instruction boilerplate from the stored question.
            clean_q = re.sub(r"以下是一個SQL查詢任務:\s*指令:\s*", "", q_content).strip()
            if clean_q in seen_questions:
                continue
            seen_questions.add(clean_q)
            sql = parse_sql_from_response(a_content) or "無法解析範例SQL"
            results.append({"similarity": float(distances[0][i]), "question": clean_q, "sql": sql})
        return results
    except Exception as e:
        self._log(f"❌ 檢索失敗: {e}", "ERROR")
        return []
|
| 206 |
+
|
| 207 |
+
def _build_prompt(self, user_q: str, examples: List[Dict]) -> str:
|
| 208 |
+
schema_str = self._format_relevant_schema(self._identify_relevant_tables(user_q))
|
| 209 |
+
example_str = ""
|
| 210 |
+
if examples:
|
| 211 |
+
example_prompts = [f"Q: {ex['question']}\nA: ```sql\n{ex['sql']}\n```" for ex in examples]
|
| 212 |
+
example_str = "\n---\n".join(example_prompts)
|
| 213 |
+
prompt = f"""You are an expert SQLite programmer. Your task is to generate a SQL query based on the database schema and a user's question.
|
| 214 |
+
|
| 215 |
+
## Database Schema
|
| 216 |
+
{schema_str.strip()}
|
| 217 |
+
|
| 218 |
+
## Examples
|
| 219 |
+
{example_str.strip()}
|
| 220 |
+
|
| 221 |
+
## Task
|
| 222 |
+
Based on the schema and examples, generate the SQL query for the following question.
|
| 223 |
+
Q: {user_q}
|
| 224 |
+
A: ```sql
|
| 225 |
+
"""
|
| 226 |
+
return prompt
|
| 227 |
+
|
| 228 |
+
def _finalize_sql(self, sql: str, log_message: str) -> Tuple[str, str]:
|
| 229 |
+
final_sql = re.sub(r'\s+', ' ', sql.strip())
|
| 230 |
+
if not final_sql.endswith(';'): final_sql += ';'
|
| 231 |
+
self._log(f"✅ SQL 已生成 ({log_message})", "INFO")
|
| 232 |
+
self._log(f" - 最終 SQL: {final_sql}", "DEBUG")
|
| 233 |
+
return final_sql, "生成成功"
|
| 234 |
|
| 235 |
def _validate_and_fix_sql(self, question: str, raw_response: str) -> Tuple[Optional[str], str]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
q_lower = question.lower()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
entity_match_data = None
|
|
|
|
|
|
|
| 238 |
entity_patterns = [
|
|
|
|
| 239 |
{'pattern': r"(买家|buyer)\s*(?:id|代號|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.BuyerID', 'type': '买家ID'},
|
| 240 |
{'pattern': r"(申请方|申请厂商|applicant)\s*(?:id|代號|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.ApplicantID', 'type': '申请方ID'},
|
| 241 |
{'pattern': r"(付款方|付款厂商|invoiceto)\s*(?:id|代號|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.InvoiceToID', 'type': '付款方ID'},
|
| 242 |
{'pattern': r"(代理商|agent)\s*(?:id|代號|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.AgentID', 'type': '代理商ID'},
|
|
|
|
|
|
|
| 243 |
{'pattern': r"(买家|buyer|客戶)\s*'\"?([a-zA-Z0-9&.\s-]+?)(?:\s*的|\s+|$|有)", 'column': 'sd.BuyerName', 'type': '买家'},
|
| 244 |
{'pattern': r"(申请方|申请厂商|applicant)\s*'\"?([a-zA-Z0-9&.\s-]+?)(?:\s*的|\s+|$|有)", 'column': 'sd.ApplicantName', 'type': '申请方'},
|
| 245 |
{'pattern': r"(付款方|付款厂商|invoiceto)\s*'\"?([a-zA-Z0-9&.\s-]+?)(?:\s*的|\s+|$|有)", 'column': 'sd.InvoiceToName', 'type': '付款方'},
|
| 246 |
{'pattern': r"(代理商|agent)\s*'\"?([a-zA-Z0-9&.\s-]+?)(?:\s*的|\s+|$|有)", 'column': 'sd.AgentName', 'type': '代理商'},
|
|
|
|
|
|
|
| 247 |
{'pattern': r"\b([A-Z]\d{4}[A-Z])\b", 'column': 'sd.ApplicantID', 'type': 'ID'}
|
| 248 |
]
|
|
|
|
| 249 |
for p in entity_patterns:
|
| 250 |
match = re.search(p['pattern'], question, re.IGNORECASE)
|
| 251 |
if match:
|
| 252 |
entity_value = match.group(2) if len(match.groups()) > 1 else match.group(1)
|
| 253 |
+
entity_match_data = {"type": p['type'], "name": entity_value.strip().upper(), "column": p['column']}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
break
|
| 255 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 256 |
if any(kw in q_lower for kw in ['報告號碼', '報告清單', '列出報告', 'report number', 'list of reports']):
|
| 257 |
year_match = re.search(r'(\d{4})\s*年?', question)
|
| 258 |
month_match = re.search(r'(\d{1,2})\s*月', question)
|
|
|
|
| 260 |
select_clause = "SELECT jt.JobNo, jt.ReportAuthorization"
|
| 261 |
where_conditions = ["jt.ReportAuthorization IS NOT NULL"]
|
| 262 |
log_parts = []
|
| 263 |
+
if year_match: where_conditions.append(f"strftime('%Y', jt.ReportAuthorization) = '{year_match.group(1)}'"); log_parts.append(f"{year_match.group(1)}年")
|
| 264 |
+
if month_match: where_conditions.append(f"strftime('%m', jt.ReportAuthorization) = '{month_match.group(1).zfill(2)}'"); log_parts.append(f"{month_match.group(1)}月")
|
|
|
|
|
|
|
| 265 |
if 'fail' in q_lower or '失敗' in q_lower:
|
| 266 |
+
if "JOIN TSR53SampleDescription" not in from_clause: from_clause += " JOIN TSR53SampleDescription AS sd ON jt.JobNo = sd.JobNo"
|
| 267 |
where_conditions.append("sd.OverallRating = 'Fail'"); log_parts.append("Fail")
|
| 268 |
elif 'pass' in q_lower or '通過' in q_lower:
|
| 269 |
+
if "JOIN TSR53SampleDescription" not in from_clause: from_clause += " JOIN TSR53SampleDescription AS sd ON jt.JobNo = sd.JobNo"
|
| 270 |
where_conditions.append("sd.OverallRating = 'Pass'"); log_parts.append("Pass")
|
|
|
|
| 271 |
if entity_match_data:
|
| 272 |
entity_name, column_name = entity_match_data["name"], entity_match_data["column"]
|
| 273 |
+
if "JOIN TSR53SampleDescription" not in from_clause: from_clause += " JOIN TSR53SampleDescription AS sd ON jt.JobNo = sd.JobNo"
|
| 274 |
match_operator = "=" if column_name.endswith("ID") else "LIKE"
|
| 275 |
entity_value = f"'{entity_name}'" if match_operator == "=" else f"'%{entity_name}%'"
|
| 276 |
where_conditions.append(f"{column_name} {match_operator} {entity_value}")
|
| 277 |
log_parts.append(entity_name)
|
| 278 |
select_clause = "SELECT jt.JobNo, sd.BuyerName, jt.ReportAuthorization"
|
| 279 |
+
final_where_clause = "WHERE " + " AND ".join(where_conditions) if where_conditions else ""
|
|
|
|
| 280 |
time_log = " ".join(log_parts) if log_parts else "全部"
|
| 281 |
self._log(f"🔄 檢測到查詢【{time_log} 報告列表】意圖,啟用智能模板。", "INFO")
|
| 282 |
template_sql = f"{select_clause} {from_clause} {final_where_clause} ORDER BY jt.ReportAuthorization DESC;"
|
| 283 |
return self._finalize_sql(template_sql, f"模板覆寫: {time_log} 報告列表查詢")
|
| 284 |
|
| 285 |
+
if '報告' in q_lower and any(kw in q_lower for kw in ['幾份', '多少', '數量', '總數']) and not entity_match_data:
|
|
|
|
| 286 |
year_match = re.search(r'(\d{4})\s*年?', question)
|
| 287 |
time_condition, time_log = "", "總"
|
| 288 |
if year_match:
|
| 289 |
+
time_condition = f"WHERE ReportAuthorization IS NOT NULL AND strftime('%Y', ReportAuthorization) = '{year_match.group(1)}'"
|
| 290 |
+
time_log = f"{year_match.group(1)}年"
|
|
|
|
| 291 |
else:
|
| 292 |
time_condition = "WHERE ReportAuthorization IS NOT NULL"
|
| 293 |
self._log(f"🔄 檢測到查詢【{time_log}全局報告總數】意圖,啟用模板。", "INFO")
|
| 294 |
template_sql = f"SELECT COUNT(DISTINCT JobNo) AS report_count FROM JobTimeline {time_condition};"
|
| 295 |
return self._finalize_sql(template_sql, f"模板覆寫: {time_log}全局報告總數查詢")
|
| 296 |
|
|
|
|
|
|
|
|
|
|
| 297 |
self._log("未觸發任何模板,嘗試解析並修正 AI 輸出...", "INFO")
|
|
|
|
| 298 |
parsed_sql = parse_sql_from_response(raw_response)
|
| 299 |
if not parsed_sql:
|
|
|
|
| 300 |
return None, f"無法解析SQL。原始回應:\n{raw_response}"
|
|
|
|
|
|
|
|
|
|
| 301 |
fixed_sql = " " + parsed_sql.strip() + " "
|
| 302 |
fixes_applied_fallback = []
|
|
|
|
| 303 |
dialect_corrections = {r'YEAR\s*\(([^)]+)\)': r"strftime('%Y', \1)"}
|
| 304 |
+
for p, r in dialect_corrections.items():
|
| 305 |
+
if re.search(p, fixed_sql, re.IGNORECASE):
|
| 306 |
+
fixed_sql = re.sub(p, r, fixed_sql, flags=re.IGNORECASE); fixes_applied_fallback.append(f"修正方言: {p}")
|
| 307 |
+
schema_corrections = {'TSR53Report':'TSR53SampleDescription', 'TSR53InvoiceReportNo':'JobNo', 'Status':'OverallRating'}
|
| 308 |
+
for w, c in schema_corrections.items():
|
| 309 |
+
pattern = r'\b' + re.escape(w) + r'\b'
|
|
|
|
|
|
|
| 310 |
if re.search(pattern, fixed_sql, re.IGNORECASE):
|
| 311 |
+
fixed_sql = re.sub(pattern, c, fixed_sql, flags=re.IGNORECASE); fixes_applied_fallback.append(f"映射 Schema: '{w}' -> '{c}'")
|
|
|
|
|
|
|
| 312 |
log_msg = "AI 生成並成功修正" if fixes_applied_fallback else "AI 生成且無需修正"
|
| 313 |
return self._finalize_sql(fixed_sql, log_msg)
|
| 314 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 315 |
def process_question(self, question: str) -> Tuple[str, str]:
    """Run the full pipeline for one question and return (sql_or_error, status).

    Results are memoised per question text.
    NOTE(review): query_cache grows without bound — acceptable for a demo
    Space, but worth revisiting for long-running deployments.
    """
    cached = self.query_cache.get(question)
    if cached is not None:
        self._log("⚡ 使用緩存結果")
        return cached
    # Fresh question: reset per-request log history.
    self.log_history = []
    self._log(f"⏰ 處理問題: {question}")
    examples = self.find_most_similar(question, FEW_SHOT_EXAMPLES_COUNT)
    if examples:
        self._log(f"✅ 找到 {len(examples)} 個相似範例")
    prompt = self._build_prompt(question, examples)
    self._log(f"📏 Prompt 長度: {len(prompt)} 字符")
    self._log("🧠 開始生成 AI 回應...")
    raw_response = self.huggingface_api_call(prompt)
    final_sql, status_message = self._validate_and_fix_sql(question, raw_response)
    # On failure the status message (which carries the raw response) is shown
    # in the SQL slot so the user sees what went wrong.
    result = (final_sql, status_message) if final_sql else (status_message, "生成失敗")
    self.query_cache[question] = result
    return result
|
| 330 |
|
|
|
|
| 332 |
# Build the singleton pipeline at import time (model/dataset downloads happen on startup).
text_to_sql_system = TextToSQLSystem()
|
| 333 |
|
| 334 |
def process_query(q: str):
    """Gradio callback: return (generated_sql, status, recent log text)."""
    if not q.strip():
        return "", "等待輸入", "請輸入問題"
    # If model loading failed at startup, surface the full log instead of running.
    if text_to_sql_system.llm is None:
        return "模型未能成功載入,請檢查終端日誌。", "模型載入失敗", "\n".join(text_to_sql_system.log_history)
    sql, status = text_to_sql_system.process_question(q)
    recent_logs = "\n".join(text_to_sql_system.log_history[-15:])
    return sql, status, recent_logs
|
| 341 |
|
|
|
|
| 342 |
# Example questions shown below the input box; clicking one fills it in.
examples = [
    "2024年7月買家 Gap 的 Fail 報告號碼",
    "列出2023年所有失败的报告",
    "找出总金额最高的10个工作单",
    "哪些客户的工作单数量最多?",
    "A組2024年完成了多少個測試項目?",
    "2024年每月完成多少份報告?",
]

with gr.Blocks(theme=gr.themes.Soft(), title="Text-to-SQL 智能助手") as demo:
    gr.Markdown("# ⚡ Text-to-SQL 智能助手 (终极版)")
    gr.Markdown("融合了模板引擎和 GGUF 模型的强大版本")
    with gr.Row():
        # Left column: question input and controls.
        with gr.Column(scale=2):
            inp = gr.Textbox(lines=3, label="💬 您的問題", placeholder="例如:2024年每月完成多少份報告?")
            btn = gr.Button("🚀 生成 SQL", variant="primary")
            status = gr.Textbox(label="狀態", interactive=False)
        # Right column: generated SQL output.
        with gr.Column(scale=3):
            sql_out = gr.Code(label="🤖 生成的 SQL", language="sql", lines=8)
    with gr.Accordion("📋 處理日誌", open=False):
        logs = gr.Textbox(lines=10, label="日誌", interactive=False)
    gr.Examples(examples=examples, inputs=inp, label="💡 點擊試用範例問題")
    # Button click and textbox Enter both trigger the same pipeline.
    io_spec = dict(inputs=[inp], outputs=[sql_out, status, logs])
    btn.click(process_query, **io_spec)
    inp.submit(process_query, **io_spec)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
|
|
|
|
|
|
|
|
|
|
|
|