Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
# ==============================================================================
|
| 2 |
-
# Text-to-SQL 智能助手 - Hugging Face CPU 最终版
|
| 3 |
-
# (
|
| 4 |
# ==============================================================================
|
| 5 |
import gradio as gr
|
| 6 |
import os
|
|
@@ -22,57 +22,32 @@ from transformers import AutoModel, AutoTokenizer
|
|
| 22 |
import torch.nn.functional as F
|
| 23 |
|
| 24 |
# ==================== 配置參數 ====================
|
| 25 |
-
# --- Hugging Face CPU 部署配置 ---
|
| 26 |
GGUF_FILENAME = "qwen2.5-coder-1.5b-sql-finetuned.q8_0.gguf"
|
| 27 |
-
N_GPU_LAYERS = 0 #
|
| 28 |
-
|
| 29 |
DATASET_REPO_ID = "Paul720810/Text-to-SQL-Softline"
|
| 30 |
GGUF_REPO_ID = "Paul720810/gguf-models"
|
| 31 |
FEW_SHOT_EXAMPLES_COUNT = 1
|
| 32 |
-
DEVICE = "
|
| 33 |
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
|
| 34 |
-
|
| 35 |
TEMP_DIR = tempfile.gettempdir()
|
| 36 |
os.makedirs(os.path.join(TEMP_DIR, 'text_to_sql_cache'), exist_ok=True)
|
| 37 |
|
| 38 |
print("=" * 60)
|
| 39 |
-
print("🤖 Text-to-SQL 智能助手
|
| 40 |
print(f"🚀 模型: {GGUF_FILENAME}")
|
| 41 |
print(f"💻 設備: {DEVICE} (GPU Layers: {N_GPU_LAYERS})")
|
| 42 |
print("=" * 60)
|
| 43 |
|
| 44 |
# ==================== 工具函數 ====================
|
| 45 |
-
def get_current_time():
|
| 46 |
-
return datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
| 47 |
-
|
| 48 |
def format_log(message: str, level: str = "INFO") -> str:
|
| 49 |
-
log_entry = f"[{get_current_time()}] [{level.upper()}] {message}"
|
| 50 |
-
print(log_entry)
|
| 51 |
-
return log_entry
|
| 52 |
-
|
| 53 |
def parse_sql_from_response(response_text: str) -> Optional[str]:
|
| 54 |
if not response_text: return None
|
| 55 |
response_text = response_text.strip()
|
| 56 |
match = re.search(r"```sql\s*\n(.*?)\n```", response_text, re.DOTALL | re.IGNORECASE)
|
| 57 |
if match: return match.group(1).strip()
|
| 58 |
-
|
| 59 |
-
if match:
|
| 60 |
-
sql_candidate = match.group(1).strip()
|
| 61 |
-
if sql_candidate.upper().startswith('SELECT'): return sql_candidate
|
| 62 |
-
match = re.search(r"(SELECT\s+.*?;)", response_text, re.DOTALL | re.IGNORECASE)
|
| 63 |
-
if match: return match.group(1).strip()
|
| 64 |
-
match = re.search(r"(SELECT\s+.*?)(?=\n\n|\n```|$|\n[^,\s])", response_text, re.DOTALL | re.IGNORECASE)
|
| 65 |
-
if match:
|
| 66 |
-
sql = match.group(1).strip()
|
| 67 |
-
if not sql.endswith(';'): sql += ';'
|
| 68 |
-
return sql
|
| 69 |
-
if 'SELECT' in response_text.upper():
|
| 70 |
-
for line in response_text.split('\n'):
|
| 71 |
-
line = line.strip()
|
| 72 |
-
if line.upper().startswith('SELECT'):
|
| 73 |
-
if not line.endswith(';'): line += ';'
|
| 74 |
-
return line
|
| 75 |
-
return None
|
| 76 |
|
| 77 |
# ==================== Text-to-SQL 核心類 ====================
|
| 78 |
class TextToSQLSystem:
|
|
@@ -84,21 +59,15 @@ class TextToSQLSystem:
|
|
| 84 |
self._log(f"載入嵌入模型: {embed_model_name}")
|
| 85 |
self.embed_tokenizer = AutoTokenizer.from_pretrained(embed_model_name)
|
| 86 |
self.embed_model = AutoModel.from_pretrained(embed_model_name)
|
| 87 |
-
if DEVICE == "cuda":
|
| 88 |
-
self.embed_model.to(DEVICE)
|
| 89 |
-
|
| 90 |
self.schema = self._load_schema()
|
| 91 |
self.dataset, self.faiss_index = self._load_and_index_dataset()
|
| 92 |
self._load_gguf_model()
|
| 93 |
self._log("✅ 系統初始化完成")
|
| 94 |
except Exception as e:
|
| 95 |
-
self._log(f"❌ 系統初始化過程中發生嚴重錯誤: {e}", "CRITICAL")
|
| 96 |
-
self._log(traceback.format_exc(), "DEBUG")
|
| 97 |
-
self.llm = None
|
| 98 |
-
|
| 99 |
-
def _log(self, message: str, level: str = "INFO"):
|
| 100 |
-
self.log_history.append(format_log(message, level))
|
| 101 |
|
|
|
|
|
|
|
| 102 |
def _load_gguf_model(self):
|
| 103 |
try:
|
| 104 |
model_path = hf_hub_download(repo_id=GGUF_REPO_ID, filename=GGUF_FILENAME, repo_type="dataset", cache_dir=TEMP_DIR)
|
|
@@ -107,54 +76,47 @@ class TextToSQLSystem:
|
|
| 107 |
self.llm = Llama(model_path=model_path, n_ctx=2048, n_threads=4, n_batch=512, verbose=False, n_gpu_layers=N_GPU_LAYERS)
|
| 108 |
self._log("✅ GGUF 模型成功載入")
|
| 109 |
except Exception as e:
|
| 110 |
-
self._log(f"❌ GGUF 載入失敗: {e}", "CRITICAL")
|
| 111 |
-
self.llm = None
|
| 112 |
|
| 113 |
def huggingface_api_call(self, prompt: str) -> str:
|
| 114 |
if self.llm is None: return ""
|
| 115 |
try:
|
| 116 |
output = self.llm(prompt, max_tokens=150, temperature=0.1, top_p=0.9, echo=False, stop=["```", ";", "\n\n", "</s>", "###", "Q:"], repeat_penalty=1.1)
|
| 117 |
-
generated_text = output["choices"][
|
| 118 |
self._log(f"🧠 模型原始輸出: {generated_text.strip()}", "DEBUG")
|
| 119 |
return generated_text.strip()
|
| 120 |
except Exception as e:
|
| 121 |
-
self._log(f"❌ 模型生成錯誤: {e}", "CRITICAL")
|
| 122 |
-
return ""
|
| 123 |
|
| 124 |
def _load_schema(self) -> Dict:
|
| 125 |
try:
|
| 126 |
schema_path = hf_hub_download(repo_id=DATASET_REPO_ID, filename="sqlite_schema_FULL.json", repo_type="dataset")
|
| 127 |
-
with open(schema_path, "r", encoding="utf-8") as f:
|
| 128 |
-
schema_data = json.load(f)
|
| 129 |
self._log(f"📊 Schema 載入成功,包含 {len(schema_data)} 個表格。")
|
| 130 |
return schema_data
|
| 131 |
except Exception as e:
|
| 132 |
-
self._log(f"❌ 載入 schema 失敗: {e}", "ERROR")
|
| 133 |
-
return {}
|
| 134 |
|
| 135 |
def _encode_texts(self, texts):
|
| 136 |
if isinstance(texts, str): texts = [texts]
|
| 137 |
-
inputs = self.embed_tokenizer(texts, padding=True, truncation=True, return_tensors="pt", max_length=512)
|
| 138 |
with torch.no_grad():
|
| 139 |
outputs = self.embed_model(**inputs)
|
| 140 |
-
|
| 141 |
-
return embeddings.cpu()
|
| 142 |
|
| 143 |
def _load_and_index_dataset(self):
|
| 144 |
try:
|
| 145 |
dataset = load_dataset(DATASET_REPO_ID, data_files="training_data.jsonl", split="train")
|
| 146 |
dataset = dataset.filter(lambda ex: isinstance(ex.get("messages"), list) and len(ex["messages"]) >= 2)
|
| 147 |
-
corpus = [item['messages'][
|
| 148 |
self._log(f"正在編碼 {len(corpus)} 個問題...")
|
| 149 |
all_embeddings = torch.cat([self._encode_texts(corpus[i:i+32]) for i in range(0, len(corpus), 32)], dim=0).numpy()
|
| 150 |
-
index = faiss.IndexFlatIP(all_embeddings.shape
|
| 151 |
index.add(all_embeddings.astype('float32'))
|
| 152 |
self._log("✅ 向量索引建立完成")
|
| 153 |
return dataset, index
|
| 154 |
except Exception as e:
|
| 155 |
-
self._log(f"❌ 載入數據失敗: {e}", "ERROR")
|
| 156 |
-
self._log(traceback.format_exc(), "DEBUG")
|
| 157 |
-
return None, None
|
| 158 |
|
| 159 |
def _identify_relevant_tables(self, question: str) -> List[str]:
|
| 160 |
question_lower = question.lower()
|
|
@@ -185,31 +147,27 @@ class TextToSQLSystem:
|
|
| 185 |
q_embedding = self._encode_texts([question]).numpy().astype('float32')
|
| 186 |
distances, indices = self.faiss_index.search(q_embedding, min(top_k + 2, len(self.dataset)))
|
| 187 |
results, seen_questions = [], set()
|
| 188 |
-
for i, idx in enumerate(indices
|
| 189 |
if len(results) >= top_k: break
|
| 190 |
idx = int(idx)
|
| 191 |
if idx >= len(self.dataset): continue
|
| 192 |
item = self.dataset[idx]
|
| 193 |
if not (isinstance(item.get('messages'), list) and len(item['messages']) >= 2): continue
|
| 194 |
-
q_content = (item['messages']
|
| 195 |
-
a_content = (item['messages']
|
| 196 |
if not q_content or not a_content: continue
|
| 197 |
clean_q = re.sub(r"以下是一個SQL查詢任務:\s*指令:\s*", "", q_content).strip()
|
| 198 |
if clean_q in seen_questions: continue
|
| 199 |
seen_questions.add(clean_q)
|
| 200 |
sql = parse_sql_from_response(a_content) or "無法解析範例SQL"
|
| 201 |
-
results.append({"similarity": float(distances[
|
| 202 |
return results
|
| 203 |
except Exception as e:
|
| 204 |
-
self._log(f"❌ 檢索失敗: {e}", "ERROR")
|
| 205 |
-
return []
|
| 206 |
|
| 207 |
def _build_prompt(self, user_q: str, examples: List[Dict]) -> str:
|
| 208 |
schema_str = self._format_relevant_schema(self._identify_relevant_tables(user_q))
|
| 209 |
-
example_str = ""
|
| 210 |
-
if examples:
|
| 211 |
-
example_prompts = [f"Q: {ex['question']}\nA: ```sql\n{ex['sql']}\n```" for ex in examples]
|
| 212 |
-
example_str = "\n---\n".join(example_prompts)
|
| 213 |
prompt = f"""You are an expert SQLite programmer. Your task is to generate a SQL query based on the database schema and a user's question.
|
| 214 |
|
| 215 |
## Database Schema
|
|
@@ -238,12 +196,8 @@ A: ```sql
|
|
| 238 |
entity_patterns = [
|
| 239 |
{'pattern': r"(买家|buyer)\s*(?:id|代號|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.BuyerID', 'type': '买家ID'},
|
| 240 |
{'pattern': r"(申请方|申请厂商|applicant)\s*(?:id|代號|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.ApplicantID', 'type': '申请方ID'},
|
| 241 |
-
{'pattern': r"(付款方|付款厂商|invoiceto)\s*(?:id|代號|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.InvoiceToID', 'type': '付款方ID'},
|
| 242 |
-
{'pattern': r"(代理商|agent)\s*(?:id|代號|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.AgentID', 'type': '代理商ID'},
|
| 243 |
{'pattern': r"(买家|buyer|客戶)\s*'\"?([a-zA-Z0-9&.\s-]+?)(?:\s*的|\s+|$|有)", 'column': 'sd.BuyerName', 'type': '买家'},
|
| 244 |
{'pattern': r"(申请方|申请厂商|applicant)\s*'\"?([a-zA-Z0-9&.\s-]+?)(?:\s*的|\s+|$|有)", 'column': 'sd.ApplicantName', 'type': '申请方'},
|
| 245 |
-
{'pattern': r"(付款方|付款厂商|invoiceto)\s*'\"?([a-zA-Z0-9&.\s-]+?)(?:\s*的|\s+|$|有)", 'column': 'sd.InvoiceToName', 'type': '付款方'},
|
| 246 |
-
{'pattern': r"(代理商|agent)\s*'\"?([a-zA-Z0-9&.\s-]+?)(?:\s*的|\s+|$|有)", 'column': 'sd.AgentName', 'type': '代理商'},
|
| 247 |
{'pattern': r"\b([A-Z]\d{4}[A-Z])\b", 'column': 'sd.ApplicantID', 'type': 'ID'}
|
| 248 |
]
|
| 249 |
for p in entity_patterns:
|
|
@@ -253,64 +207,38 @@ A: ```sql
|
|
| 253 |
entity_match_data = {"type": p['type'], "name": entity_value.strip().upper(), "column": p['column']}
|
| 254 |
break
|
| 255 |
|
| 256 |
-
if any(kw in q_lower for kw in ['報告號碼', '報告清單', '列出報告'
|
| 257 |
-
year_match = re.search(r'(\d{4})\s*年?', question)
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
if
|
| 264 |
-
|
| 265 |
-
if 'fail' in q_lower or '失敗' in q_lower:
|
| 266 |
-
if "JOIN TSR53SampleDescription" not in from_clause: from_clause += " JOIN TSR53SampleDescription AS sd ON jt.JobNo = sd.JobNo"
|
| 267 |
-
where_conditions.append("sd.OverallRating = 'Fail'"); log_parts.append("Fail")
|
| 268 |
-
elif 'pass' in q_lower or '通過' in q_lower:
|
| 269 |
-
if "JOIN TSR53SampleDescription" not in from_clause: from_clause += " JOIN TSR53SampleDescription AS sd ON jt.JobNo = sd.JobNo"
|
| 270 |
-
where_conditions.append("sd.OverallRating = 'Pass'"); log_parts.append("Pass")
|
| 271 |
if entity_match_data:
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
time_log = " ".join(log_parts) if log_parts else "全部"
|
| 281 |
-
self._log(f"🔄 檢測到查詢【{time_log} 報告列表】意圖,啟用智能模板。", "INFO")
|
| 282 |
-
template_sql = f"{select_clause} {from_clause} {final_where_clause} ORDER BY jt.ReportAuthorization DESC;"
|
| 283 |
-
return self._finalize_sql(template_sql, f"模板覆寫: {time_log} 報告列表查詢")
|
| 284 |
|
| 285 |
if '報告' in q_lower and any(kw in q_lower for kw in ['幾份', '多少', '數量', '總數']) and not entity_match_data:
|
| 286 |
year_match = re.search(r'(\d{4})\s*年?', question)
|
| 287 |
-
|
| 288 |
-
if year_match
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
return self._finalize_sql(template_sql, f"模板覆寫: {time_log}全局報告總數查詢")
|
| 296 |
-
|
| 297 |
-
self._log("未觸發任何模板,嘗試解析並修正 AI 輸出...", "INFO")
|
| 298 |
parsed_sql = parse_sql_from_response(raw_response)
|
| 299 |
-
if not parsed_sql:
|
| 300 |
-
|
| 301 |
-
fixed_sql = " " + parsed_sql.strip() + " "
|
| 302 |
-
fixes_applied_fallback = []
|
| 303 |
-
dialect_corrections = {r'YEAR\s*\(([^)]+)\)': r"strftime('%Y', \1)"}
|
| 304 |
-
for p, r in dialect_corrections.items():
|
| 305 |
-
if re.search(p, fixed_sql, re.IGNORECASE):
|
| 306 |
-
fixed_sql = re.sub(p, r, fixed_sql, flags=re.IGNORECASE); fixes_applied_fallback.append(f"修正方言: {p}")
|
| 307 |
-
schema_corrections = {'TSR53Report':'TSR53SampleDescription', 'TSR53InvoiceReportNo':'JobNo', 'Status':'OverallRating'}
|
| 308 |
-
for w, c in schema_corrections.items():
|
| 309 |
-
pattern = r'\b' + re.escape(w) + r'\b'
|
| 310 |
-
if re.search(pattern, fixed_sql, re.IGNORECASE):
|
| 311 |
-
fixed_sql = re.sub(pattern, c, fixed_sql, flags=re.IGNORECASE); fixes_applied_fallback.append(f"映射 Schema: '{w}' -> '{c}'")
|
| 312 |
-
log_msg = "AI 生成並成功修正" if fixes_applied_fallback else "AI 生成且無需修正"
|
| 313 |
-
return self._finalize_sql(fixed_sql, log_msg)
|
| 314 |
|
| 315 |
def process_question(self, question: str) -> Tuple[str, str]:
|
| 316 |
if question in self.query_cache: self._log("⚡ 使用緩存結果"); return self.query_cache[question]
|
|
@@ -323,42 +251,28 @@ A: ```sql
|
|
| 323 |
self._log("🧠 開始生成 AI 回應...")
|
| 324 |
response = self.huggingface_api_call(prompt)
|
| 325 |
final_sql, status_message = self._validate_and_fix_sql(question, response)
|
| 326 |
-
|
| 327 |
-
else: result = (final_sql, status_message)
|
| 328 |
self.query_cache[question] = result
|
| 329 |
return result
|
| 330 |
|
| 331 |
# ==================== Gradio 介面 ====================
|
| 332 |
text_to_sql_system = TextToSQLSystem()
|
| 333 |
-
|
| 334 |
def process_query(q: str):
|
| 335 |
if not q.strip(): return "", "等待輸入", "請輸入問題"
|
| 336 |
-
if text_to_sql_system.llm is None:
|
| 337 |
-
return "模型未能成功載入,請檢查終端日誌。", "模型載入失敗", "\n".join(text_to_sql_system.log_history)
|
| 338 |
sql, status = text_to_sql_system.process_question(q)
|
| 339 |
-
|
| 340 |
-
return sql, status, logs
|
| 341 |
|
| 342 |
-
examples = [
|
| 343 |
-
"2024年7月買家 Gap 的 Fail 報告號碼",
|
| 344 |
-
"列出2023年所有失败的报告",
|
| 345 |
-
"找出总金额最高的10个工作单",
|
| 346 |
-
"哪些客户的工作单数量最多?",
|
| 347 |
-
"A組2024年完成了多少個測試項目?",
|
| 348 |
-
"2024年每月完成多少份報告?"
|
| 349 |
-
]
|
| 350 |
with gr.Blocks(theme=gr.themes.Soft(), title="Text-to-SQL 智能助手") as demo:
|
| 351 |
gr.Markdown("# ⚡ Text-to-SQL 智能助手 (终极版)")
|
| 352 |
-
gr.Markdown("融合了模板引擎和 GGUF 模型的强大版本")
|
| 353 |
with gr.Row():
|
| 354 |
with gr.Column(scale=2):
|
| 355 |
inp = gr.Textbox(lines=3, label="💬 您的問題", placeholder="例如:2024年每月完成多少份報告?")
|
| 356 |
btn = gr.Button("🚀 生成 SQL", variant="primary")
|
| 357 |
status = gr.Textbox(label="狀態", interactive=False)
|
| 358 |
-
with gr.Column(scale=3):
|
| 359 |
-
|
| 360 |
-
with gr.Accordion("📋 處理日誌", open=False):
|
| 361 |
-
logs = gr.Textbox(lines=10, label="日誌", interactive=False)
|
| 362 |
gr.Examples(examples=examples, inputs=inp, label="💡 點擊試用範例問題")
|
| 363 |
btn.click(process_query, inputs=[inp], outputs=[sql_out, status, logs])
|
| 364 |
inp.submit(process_query, inputs=[inp], outputs=[sql_out, status, logs])
|
|
|
|
| 1 |
# ==============================================================================
|
| 2 |
+
# Text-to-SQL 智能助手 - Hugging Face CPU 最终版 v7
|
| 3 |
+
# (修复所有已知 Bug)
|
| 4 |
# ==============================================================================
|
| 5 |
import gradio as gr
|
| 6 |
import os
|
|
|
|
| 22 |
import torch.nn.functional as F
|
| 23 |
|
| 24 |
# ==================== 配置參數 ====================
|
|
|
|
| 25 |
# File name of the quantised GGUF model fetched from GGUF_REPO_ID.
GGUF_FILENAME = "qwen2.5-coder-1.5b-sql-finetuned.q8_0.gguf"
N_GPU_LAYERS = 0  # CPU environment
# Hugging Face repos holding the schema/training data and the GGUF model.
DATASET_REPO_ID = "Paul720810/Text-to-SQL-Softline"
GGUF_REPO_ID = "Paul720810/gguf-models"
FEW_SHOT_EXAMPLES_COUNT = 1
DEVICE = "cpu"
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
TEMP_DIR = tempfile.gettempdir()
# Make sure the cache sub-directory exists under the system temp directory.
os.makedirs(os.path.join(TEMP_DIR, 'text_to_sql_cache'), exist_ok=True)

# Start-up banner describing the active model and device.
print("=" * 60)
print("🤖 Text-to-SQL 智能助手 v7.0 (Hugging Face CPU 版)...")
print(f"🚀 模型: {GGUF_FILENAME}")
print(f"💻 設備: {DEVICE} (GPU Layers: {N_GPU_LAYERS})")
print("=" * 60)
|
| 40 |
|
| 41 |
# ==================== 工具函數 ====================
|
| 42 |
+
def get_current_time():
    """Return the current time from ``datetime.now()`` as ``YYYY-MM-DD HH:MM:SS``."""
    now = datetime.now()
    return now.strftime('%Y-%m-%d %H:%M:%S')
|
|
|
|
|
|
|
| 43 |
def format_log(message: str, level: str = "INFO") -> str:
    """Build a timestamped log line, print it, and return it.

    Args:
        message: Text of the log entry.
        level: Severity label; it is upper-cased in the output.

    Returns:
        The formatted line ``[<timestamp>] [<LEVEL>] <message>``.
    """
    log_entry = f"[{get_current_time()}] [{level.upper()}] {message}"
    print(log_entry)
    return log_entry
|
|
|
|
|
|
|
|
|
|
| 45 |
def parse_sql_from_response(response_text: str) -> Optional[str]:
    """Extract a SQL statement from a model reply.

    Strategies are tried in order, so inputs that were already understood
    keep producing exactly the same result:

    1. A well-formed fenced block: an opening sql fence, a newline, the
       query, and a closing fence on its own line.
    2. A reply that, once stripped, starts with ``SELECT`` is returned whole.
    3. A loosely formed sql fence (for example a closing fence that is not
       on its own line).

    Args:
        response_text: Raw text produced by the model; may be empty or None.

    Returns:
        The extracted SQL with surrounding whitespace removed, or ``None``
        when nothing that looks like SQL was found.
    """
    if not response_text:
        return None
    text = response_text.strip()
    flags = re.DOTALL | re.IGNORECASE

    strict = re.search(r"```sql\s*\n(.*?)\n```", text, flags)
    if strict:
        return strict.group(1).strip()

    if text.upper().startswith("SELECT"):
        return text

    # Last resort: tolerate fences whose delimiters are not on separate lines.
    loose = re.search(r"```sql\s*(.*?)```", text, flags)
    if loose:
        return loose.group(1).strip() or None
    return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
# ==================== Text-to-SQL 核心類 ====================
|
| 53 |
class TextToSQLSystem:
|
|
|
|
| 59 |
self._log(f"載入嵌入模型: {embed_model_name}")
|
| 60 |
self.embed_tokenizer = AutoTokenizer.from_pretrained(embed_model_name)
|
| 61 |
self.embed_model = AutoModel.from_pretrained(embed_model_name)
|
|
|
|
|
|
|
|
|
|
| 62 |
self.schema = self._load_schema()
|
| 63 |
self.dataset, self.faiss_index = self._load_and_index_dataset()
|
| 64 |
self._load_gguf_model()
|
| 65 |
self._log("✅ 系統初始化完成")
|
| 66 |
except Exception as e:
|
| 67 |
+
self._log(f"❌ 系統初始化過程中發生嚴重錯誤: {e}", "CRITICAL"); self._log(traceback.format_exc(), "DEBUG"); self.llm = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
+
def _log(self, message: str, level: str = "INFO"):
    """Format ``message`` with ``format_log`` and append the entry to ``self.log_history``."""
    entry = format_log(message, level)
    self.log_history.append(entry)
|
| 70 |
+
|
| 71 |
def _load_gguf_model(self):
|
| 72 |
try:
|
| 73 |
model_path = hf_hub_download(repo_id=GGUF_REPO_ID, filename=GGUF_FILENAME, repo_type="dataset", cache_dir=TEMP_DIR)
|
|
|
|
| 76 |
self.llm = Llama(model_path=model_path, n_ctx=2048, n_threads=4, n_batch=512, verbose=False, n_gpu_layers=N_GPU_LAYERS)
|
| 77 |
self._log("✅ GGUF 模型成功載入")
|
| 78 |
except Exception as e:
|
| 79 |
+
self._log(f"❌ GGUF 載入失敗: {e}", "CRITICAL"); self.llm = None
|
|
|
|
| 80 |
|
| 81 |
def huggingface_api_call(self, prompt: str) -> str:
    """Run the loaded GGUF model on ``prompt`` and return its completion text.

    Args:
        prompt: Full prompt string passed to ``self.llm``.

    Returns:
        The text of the first completion choice with surrounding whitespace
        removed. An empty string is returned when no model is loaded, when
        the result contains no choices, or when generation raises (the
        error is logged at CRITICAL level).
    """
    if self.llm is None:
        return ""
    try:
        output = self.llm(
            prompt,
            max_tokens=150,
            temperature=0.1,
            top_p=0.9,
            echo=False,
            stop=["```", ";", "\n\n", "</s>", "###", "Q:"],
            repeat_penalty=1.1,
        )
        # "choices" is a list of completion dicts; the text lives on the
        # first element, not on the list itself.
        choices = output.get("choices") if output else None
        generated_text = (choices[0].get("text") or "") if choices else ""
        generated_text = generated_text.strip()
        self._log(f"🧠 模型原始輸出: {generated_text}", "DEBUG")
        return generated_text
    except Exception as e:
        # Best-effort boundary: a generation failure must not crash the UI.
        self._log(f"❌ 模型生成錯誤: {e}", "CRITICAL")
        return ""
|
|
|
|
| 90 |
|
| 91 |
def _load_schema(self) -> Dict:
|
| 92 |
try:
|
| 93 |
schema_path = hf_hub_download(repo_id=DATASET_REPO_ID, filename="sqlite_schema_FULL.json", repo_type="dataset")
|
| 94 |
+
with open(schema_path, "r", encoding="utf-8") as f: schema_data = json.load(f)
|
|
|
|
| 95 |
self._log(f"📊 Schema 載入成功,包含 {len(schema_data)} 個表格。")
|
| 96 |
return schema_data
|
| 97 |
except Exception as e:
|
| 98 |
+
self._log(f"❌ 載入 schema 失敗: {e}", "ERROR"); return {}
|
|
|
|
| 99 |
|
| 100 |
def _encode_texts(self, texts):
|
| 101 |
if isinstance(texts, str): texts = [texts]
|
| 102 |
+
inputs = self.embed_tokenizer(texts, padding=True, truncation=True, return_tensors="pt", max_length=512)
|
| 103 |
with torch.no_grad():
|
| 104 |
outputs = self.embed_model(**inputs)
|
| 105 |
+
return outputs.last_hidden_state.mean(dim=1).cpu()
|
|
|
|
| 106 |
|
| 107 |
def _load_and_index_dataset(self):
|
| 108 |
try:
|
| 109 |
dataset = load_dataset(DATASET_REPO_ID, data_files="training_data.jsonl", split="train")
|
| 110 |
dataset = dataset.filter(lambda ex: isinstance(ex.get("messages"), list) and len(ex["messages"]) >= 2)
|
| 111 |
+
corpus = [item['messages']['content'] for item in dataset]
|
| 112 |
self._log(f"正在編碼 {len(corpus)} 個問題...")
|
| 113 |
all_embeddings = torch.cat([self._encode_texts(corpus[i:i+32]) for i in range(0, len(corpus), 32)], dim=0).numpy()
|
| 114 |
+
index = faiss.IndexFlatIP(all_embeddings.shape)
|
| 115 |
index.add(all_embeddings.astype('float32'))
|
| 116 |
self._log("✅ 向量索引建立完成")
|
| 117 |
return dataset, index
|
| 118 |
except Exception as e:
|
| 119 |
+
self._log(f"❌ 載入數據失敗: {e}", "ERROR"); self._log(traceback.format_exc(), "DEBUG"); return None, None
|
|
|
|
|
|
|
| 120 |
|
| 121 |
def _identify_relevant_tables(self, question: str) -> List[str]:
|
| 122 |
question_lower = question.lower()
|
|
|
|
| 147 |
q_embedding = self._encode_texts([question]).numpy().astype('float32')
|
| 148 |
distances, indices = self.faiss_index.search(q_embedding, min(top_k + 2, len(self.dataset)))
|
| 149 |
results, seen_questions = [], set()
|
| 150 |
+
for i, idx in enumerate(indices):
|
| 151 |
if len(results) >= top_k: break
|
| 152 |
idx = int(idx)
|
| 153 |
if idx >= len(self.dataset): continue
|
| 154 |
item = self.dataset[idx]
|
| 155 |
if not (isinstance(item.get('messages'), list) and len(item['messages']) >= 2): continue
|
| 156 |
+
q_content = (item['messages'].get('content') or '').strip()
|
| 157 |
+
a_content = (item['messages'].get('content') or '').strip()
|
| 158 |
if not q_content or not a_content: continue
|
| 159 |
clean_q = re.sub(r"以下是一個SQL查詢任務:\s*指令:\s*", "", q_content).strip()
|
| 160 |
if clean_q in seen_questions: continue
|
| 161 |
seen_questions.add(clean_q)
|
| 162 |
sql = parse_sql_from_response(a_content) or "無法解析範例SQL"
|
| 163 |
+
results.append({"similarity": float(distances[i]), "question": clean_q, "sql": sql})
|
| 164 |
return results
|
| 165 |
except Exception as e:
|
| 166 |
+
self._log(f"❌ 檢索失敗: {e}", "ERROR"); return []
|
|
|
|
| 167 |
|
| 168 |
def _build_prompt(self, user_q: str, examples: List[Dict]) -> str:
|
| 169 |
schema_str = self._format_relevant_schema(self._identify_relevant_tables(user_q))
|
| 170 |
+
example_str = "\n---\n".join([f"Q: {ex['question']}\nA: ```sql\n{ex['sql']}\n```" for ex in examples]) if examples else "No examples provided."
|
|
|
|
|
|
|
|
|
|
| 171 |
prompt = f"""You are an expert SQLite programmer. Your task is to generate a SQL query based on the database schema and a user's question.
|
| 172 |
|
| 173 |
## Database Schema
|
|
|
|
| 196 |
entity_patterns = [
|
| 197 |
{'pattern': r"(买家|buyer)\s*(?:id|代號|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.BuyerID', 'type': '买家ID'},
|
| 198 |
{'pattern': r"(申请方|申请厂商|applicant)\s*(?:id|代號|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.ApplicantID', 'type': '申请方ID'},
|
|
|
|
|
|
|
| 199 |
{'pattern': r"(买家|buyer|客戶)\s*'\"?([a-zA-Z0-9&.\s-]+?)(?:\s*的|\s+|$|有)", 'column': 'sd.BuyerName', 'type': '买家'},
|
| 200 |
{'pattern': r"(申请方|申请厂商|applicant)\s*'\"?([a-zA-Z0-9&.\s-]+?)(?:\s*的|\s+|$|有)", 'column': 'sd.ApplicantName', 'type': '申请方'},
|
|
|
|
|
|
|
| 201 |
{'pattern': r"\b([A-Z]\d{4}[A-Z])\b", 'column': 'sd.ApplicantID', 'type': 'ID'}
|
| 202 |
]
|
| 203 |
for p in entity_patterns:
|
|
|
|
| 207 |
entity_match_data = {"type": p['type'], "name": entity_value.strip().upper(), "column": p['column']}
|
| 208 |
break
|
| 209 |
|
| 210 |
+
if any(kw in q_lower for kw in ['報告號碼', '報告清單', '列出報告']):
|
| 211 |
+
year_match, month_match = re.search(r'(\d{4})\s*年?', question), re.search(r'(\d{1,2})\s*月', question)
|
| 212 |
+
from_clause = "FROM JobTimeline AS jt JOIN TSR53SampleDescription AS sd ON jt.JobNo = sd.JobNo"
|
| 213 |
+
select_clause = "SELECT jt.JobNo, sd.BuyerName, jt.ReportAuthorization"
|
| 214 |
+
where, log_parts = ["jt.ReportAuthorization IS NOT NULL"], []
|
| 215 |
+
if year_match: where.append(f"strftime('%Y', jt.ReportAuthorization) = '{year_match.group(1)}'"); log_parts.append(f"{year_match.group(1)}年")
|
| 216 |
+
if month_match: where.append(f"strftime('%m', jt.ReportAuthorization) = '{month_match.group(1).zfill(2)}'"); log_parts.append(f"{month_match.group(1)}月")
|
| 217 |
+
if 'fail' in q_lower or '失敗' in q_lower: where.append("sd.OverallRating = 'Fail'"); log_parts.append("Fail")
|
| 218 |
+
elif 'pass' in q_lower or '通過' in q_lower: where.append("sd.OverallRating = 'Pass'"); log_parts.append("Pass")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 219 |
if entity_match_data:
|
| 220 |
+
op = "=" if entity_match_data["column"].endswith("ID") else "LIKE"
|
| 221 |
+
val = f"'{entity_match_data['name']}'" if op == "=" else f"'%{entity_match_data['name']}%'"
|
| 222 |
+
where.append(f"{entity_match_data['column']} {op} {val}"); log_parts.append(entity_match_data["name"])
|
| 223 |
+
final_where = "WHERE " + " AND ".join(where) if where else ""
|
| 224 |
+
log_msg = " ".join(log_parts) if log_parts else "全部"
|
| 225 |
+
self._log(f"🔄 檢測到【{log_msg} 報告列表】意圖,啟用模板。", "INFO")
|
| 226 |
+
template_sql = f"{select_clause} {from_clause} {final_where} ORDER BY jt.ReportAuthorization DESC;"
|
| 227 |
+
return self._finalize_sql(template_sql, f"模板覆寫: {log_msg}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
|
| 229 |
if '報告' in q_lower and any(kw in q_lower for kw in ['幾份', '多少', '數量', '總數']) and not entity_match_data:
|
| 230 |
year_match = re.search(r'(\d{4})\s*年?', question)
|
| 231 |
+
time_cond = f"WHERE strftime('%Y', ReportAuthorization) = '{year_match.group(1)}'" if year_match else ""
|
| 232 |
+
time_log = f"{year_match.group(1)}年" if year_match else "总"
|
| 233 |
+
self._log(f"🔄 檢測到【{time_log}报告总数】意图,启用模板。", "INFO")
|
| 234 |
+
template_sql = f"SELECT COUNT(DISTINCT JobNo) AS report_count FROM JobTimeline {time_cond};"
|
| 235 |
+
return self._finalize_sql(template_sql, f"模板覆寫: {time_log}")
|
| 236 |
+
|
| 237 |
+
# Fallback to AI
|
| 238 |
+
self._log("未觸發模板,由 AI 生成...", "INFO")
|
|
|
|
|
|
|
|
|
|
| 239 |
parsed_sql = parse_sql_from_response(raw_response)
|
| 240 |
+
if not parsed_sql: return None, f"無法解析SQL: {raw_response}"
|
| 241 |
+
return self._finalize_sql(parsed_sql, "AI 生成")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 242 |
|
| 243 |
def process_question(self, question: str) -> Tuple[str, str]:
|
| 244 |
if question in self.query_cache: self._log("⚡ 使用緩存結果"); return self.query_cache[question]
|
|
|
|
| 251 |
self._log("🧠 開始生成 AI 回應...")
|
| 252 |
response = self.huggingface_api_call(prompt)
|
| 253 |
final_sql, status_message = self._validate_and_fix_sql(question, response)
|
| 254 |
+
result = (final_sql, status_message) if final_sql else (status_message, "生成失敗")
|
|
|
|
| 255 |
self.query_cache[question] = result
|
| 256 |
return result
|
| 257 |
|
| 258 |
# ==================== Gradio 介面 ====================
|
| 259 |
text_to_sql_system = TextToSQLSystem()
|
|
|
|
| 260 |
def process_query(q: str):
    """Gradio callback: turn a question into ``(sql, status, logs)``.

    Args:
        q: Question text from the input box; empty, whitespace-only, or
            ``None`` values are treated as "no input".

    Returns:
        A 3-tuple matching the ``[sql_out, status, logs]`` outputs: the
        generated SQL (or an explanatory message), a short status string,
        and log text joined with newlines.
    """
    if not q or not q.strip():
        return "", "等待輸入", "請輸入問題"
    if text_to_sql_system.llm is None:
        # The model failed to load: show the full log history to aid diagnosis.
        return "模型未能成功載入...", "模型載入失敗", "\n".join(text_to_sql_system.log_history)
    sql, status = text_to_sql_system.process_question(q)
    # Only the 15 most recent log entries are returned for display.
    return sql, status, "\n".join(text_to_sql_system.log_history[-15:])
|
|
|
|
| 265 |
|
| 266 |
+
examples = ["2024年7月買家 Gap 的 Fail 報告號碼", "列出2023年所有失败的报告", "找出总金额最高的10个工作单", "哪些客户的工作单数量最多?", "A組2024年完成了多少個測試項目?", "2024年每月完成多少份報告?"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 267 |
with gr.Blocks(theme=gr.themes.Soft(), title="Text-to-SQL 智能助手") as demo:
|
| 268 |
gr.Markdown("# ⚡ Text-to-SQL 智能助手 (终极版)")
|
|
|
|
| 269 |
with gr.Row():
|
| 270 |
with gr.Column(scale=2):
|
| 271 |
inp = gr.Textbox(lines=3, label="💬 您的問題", placeholder="例如:2024年每月完成多少份報告?")
|
| 272 |
btn = gr.Button("🚀 生成 SQL", variant="primary")
|
| 273 |
status = gr.Textbox(label="狀態", interactive=False)
|
| 274 |
+
with gr.Column(scale=3): sql_out = gr.Code(label="🤖 生成的 SQL", language="sql", lines=8)
|
| 275 |
+
with gr.Accordion("📋 處理日誌", open=False): logs = gr.Textbox(lines=10, label="日誌", interactive=False)
|
|
|
|
|
|
|
| 276 |
gr.Examples(examples=examples, inputs=inp, label="💡 點擊試用範例問題")
|
| 277 |
btn.click(process_query, inputs=[inp], outputs=[sql_out, status, logs])
|
| 278 |
inp.submit(process_query, inputs=[inp], outputs=[sql_out, status, logs])
|