Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -12,29 +12,30 @@ from typing import List, Dict, Tuple, Optional
|
|
| 12 |
import numpy as np
|
| 13 |
|
| 14 |
# ==================== 配置區 ====================
|
| 15 |
-
HF_TOKEN = os.environ.get("HF_TOKEN",
|
| 16 |
DATASET_REPO_ID = "Paul720810/Text-to-SQL-Softline"
|
| 17 |
-
SIMILARITY_THRESHOLD = 0.
|
| 18 |
|
| 19 |
-
# 多個備用LLM模型
|
| 20 |
LLM_MODELS = [
|
| 21 |
"https://api-inference.huggingface.co/models/gpt2",
|
| 22 |
-
"https://api-inference.huggingface.co/models/distilgpt2",
|
| 23 |
"https://api-inference.huggingface.co/models/microsoft/DialoGPT-small"
|
| 24 |
]
|
| 25 |
|
| 26 |
print("=" * 60)
|
| 27 |
print("🤖 智能 Text-to-SQL 系統啟動中...")
|
| 28 |
-
print("📊 模式:
|
| 29 |
print("=" * 60)
|
| 30 |
|
| 31 |
-
# ====================
|
| 32 |
def get_current_time():
|
|
|
|
| 33 |
return datetime.now().strftime("%H:%M:%S")
|
| 34 |
|
| 35 |
def validate_sql(sql_query: str) -> Dict:
|
| 36 |
-
"""驗證SQL
|
| 37 |
-
if not sql_query or sql_query.strip()
|
| 38 |
return {"valid": False, "issues": ["SQL語句為空"], "is_safe": False, "empty": True}
|
| 39 |
|
| 40 |
sql_clean = sql_query.strip()
|
|
@@ -47,7 +48,7 @@ def validate_sql(sql_query: str) -> Dict:
|
|
| 47 |
# 檢查危險操作
|
| 48 |
dangerous_keywords = ['DROP', 'DELETE', 'INSERT', 'UPDATE', 'ALTER', 'TRUNCATE', 'EXEC', 'EXECUTE']
|
| 49 |
for keyword in dangerous_keywords:
|
| 50 |
-
if f" {keyword} " in sql_upper:
|
| 51 |
security_issues.append(f"危險操作: {keyword}")
|
| 52 |
|
| 53 |
# 檢查基本語法
|
|
@@ -57,15 +58,18 @@ def validate_sql(sql_query: str) -> Dict:
|
|
| 57 |
if "FROM" not in sql_upper:
|
| 58 |
security_issues.append("缺少FROM")
|
| 59 |
|
|
|
|
|
|
|
|
|
|
| 60 |
return {
|
| 61 |
-
"valid":
|
| 62 |
"issues": security_issues,
|
| 63 |
-
"is_safe":
|
| 64 |
"empty": False
|
| 65 |
}
|
| 66 |
|
| 67 |
def analyze_question_type(question: str) -> Dict:
|
| 68 |
-
"""
|
| 69 |
question_lower = question.lower()
|
| 70 |
|
| 71 |
analysis = {
|
|
@@ -79,26 +83,24 @@ def analyze_question_type(question: str) -> Dict:
|
|
| 79 |
|
| 80 |
# 檢測關鍵詞
|
| 81 |
keywords_sets = {
|
| 82 |
-
"sales": ["銷售", "業績", "金額", "收入", "sale", "revenue"
|
| 83 |
-
"customer": ["客戶", "買家", "用戶", "customer", "client"
|
| 84 |
-
"product": ["產品", "商品", "項目", "product", "item"
|
| 85 |
-
"time": ["時間", "日期", "月份", "年", "月", "最近", "date", "month", "year"
|
| 86 |
-
"report": ["報告", "完成", "份", "report", "complete"
|
| 87 |
-
"count": ["多少", "幾個", "數量", "count", "how many"
|
| 88 |
-
"comparison": ["比較", "vs", " versus", "對比", "相比"
|
| 89 |
}
|
| 90 |
|
| 91 |
for category, keywords in keywords_sets.items():
|
| 92 |
-
for keyword in keywords:
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
analysis["keywords"].append(category)
|
| 96 |
-
|
| 97 |
# 特殊檢測
|
| 98 |
-
analysis["has_count"] =
|
| 99 |
-
analysis["has_date"] =
|
| 100 |
analysis["has_group"] = any(word in question_lower for word in ["每", "各", "group", "每個"])
|
| 101 |
-
analysis["has_comparison"] =
|
| 102 |
|
| 103 |
# 確定主要類型
|
| 104 |
if analysis["keywords"]:
|
|
@@ -106,57 +108,6 @@ def analyze_question_type(question: str) -> Dict:
|
|
| 106 |
|
| 107 |
return analysis
|
| 108 |
|
| 109 |
-
def generate_sql_from_question(question: str, analysis: Dict) -> str:
|
| 110 |
-
"""根據問題分析生成智能SQL"""
|
| 111 |
-
question_lower = question.lower()
|
| 112 |
-
question_type = analysis["type"]
|
| 113 |
-
|
| 114 |
-
# 針對常見問題模式的SQL生成
|
| 115 |
-
if "每月" in question_lower and ("完成" in question_lower or "報告" in question_lower):
|
| 116 |
-
year_match = re.search(r'(\d{4})年', question_lower)
|
| 117 |
-
year = year_match.group(1) if year_match else "2023"
|
| 118 |
-
return f"SELECT strftime('%Y-%m', completion_date) as month, COUNT(*) as report_count FROM reports WHERE strftime('%Y', completion_date) = '{year}' GROUP BY month ORDER BY month;"
|
| 119 |
-
|
| 120 |
-
elif "銷售" in question_lower and ("最高" in question_lower or "最好" in question_lower):
|
| 121 |
-
return "SELECT product_name, SUM(sales_amount) as total_sales FROM sales GROUP BY product_name ORDER BY total_sales DESC LIMIT 10;"
|
| 122 |
-
|
| 123 |
-
elif "客戶" in question_lower and ("訂單" in question_lower or "購買" in question_lower):
|
| 124 |
-
return "SELECT customer_name, COUNT(*) as order_count, SUM(order_amount) as total_spent FROM orders GROUP BY customer_name ORDER BY total_spent DESC;"
|
| 125 |
-
|
| 126 |
-
elif "比較" in question_lower and ("年" in question_lower or "年份" in question_lower):
|
| 127 |
-
return "SELECT strftime('%Y', order_date) as year, COUNT(*) as order_count, SUM(order_amount) as yearly_revenue FROM orders GROUP BY year ORDER BY year;"
|
| 128 |
-
|
| 129 |
-
elif "庫存" in question_lower and ("不足" in question_lower or "缺少" in question_lower):
|
| 130 |
-
return "SELECT product_name, stock_quantity FROM products WHERE stock_quantity < 10 ORDER BY stock_quantity ASC;"
|
| 131 |
-
|
| 132 |
-
# 根據分析結果生成通用SQL
|
| 133 |
-
if analysis["has_count"] and analysis["has_group"] and analysis["has_date"]:
|
| 134 |
-
return "SELECT strftime('%Y-%m', date_column) as period, COUNT(*) as item_count FROM appropriate_table GROUP BY period ORDER BY period;"
|
| 135 |
-
|
| 136 |
-
elif analysis["has_count"] and analysis["has_group"]:
|
| 137 |
-
return "SELECT category_column, COUNT(*) as count FROM appropriate_table GROUP BY category_column ORDER BY count DESC;"
|
| 138 |
-
|
| 139 |
-
elif analysis["has_count"]:
|
| 140 |
-
return "SELECT COUNT(*) as total_count FROM appropriate_table;"
|
| 141 |
-
|
| 142 |
-
elif analysis["has_group"]:
|
| 143 |
-
return "SELECT group_column, AVG(value_column) as average_value FROM appropriate_table GROUP BY group_column;"
|
| 144 |
-
|
| 145 |
-
else:
|
| 146 |
-
return "SELECT * FROM appropriate_table LIMIT 10;"
|
| 147 |
-
|
| 148 |
-
def repair_empty_sql(original_sql: str, user_question: str, similar_question: str) -> str:
|
| 149 |
-
"""修復空白SQL"""
|
| 150 |
-
if not original_sql or original_sql.strip() == "":
|
| 151 |
-
# 分析問題並生成合適的SQL
|
| 152 |
-
analysis = analyze_question_type(user_question)
|
| 153 |
-
repaired_sql = generate_sql_from_question(user_question, analysis)
|
| 154 |
-
|
| 155 |
-
# 添加註釋說明這是修復的SQL
|
| 156 |
-
return f"-- 根據類似問題 '{similar_question}' 修復生成的SQL\n{repaired_sql}"
|
| 157 |
-
|
| 158 |
-
return original_sql
|
| 159 |
-
|
| 160 |
# ==================== 完整數據加載模塊 ====================
|
| 161 |
class CompleteDataLoader:
|
| 162 |
def __init__(self, hf_token: str):
|
|
@@ -173,9 +124,7 @@ class CompleteDataLoader:
|
|
| 173 |
raw_dataset = load_dataset(DATASET_REPO_ID, token=self.hf_token)['train']
|
| 174 |
|
| 175 |
print("解析全部 messages 格式...")
|
| 176 |
-
total_count = 0
|
| 177 |
-
empty_count = 0
|
| 178 |
-
valid_count = 0
|
| 179 |
|
| 180 |
for item in raw_dataset:
|
| 181 |
try:
|
|
@@ -191,7 +140,7 @@ class CompleteDataLoader:
|
|
| 191 |
sql_match = re.search(r'SQL查詢:\s*(.*?)(?:\n|$)', assistant_content, re.DOTALL)
|
| 192 |
if sql_match:
|
| 193 |
sql_query = sql_match.group(1).strip()
|
| 194 |
-
sql_query = re.sub(r'^sql\s*', '', sql_query)
|
| 195 |
sql_query = re.sub(r'```sql|```', '', sql_query).strip()
|
| 196 |
else:
|
| 197 |
sql_query = assistant_content
|
|
@@ -210,8 +159,7 @@ class CompleteDataLoader:
|
|
| 210 |
empty_count += 1
|
| 211 |
if validation["valid"]:
|
| 212 |
valid_count += 1
|
| 213 |
-
|
| 214 |
-
except Exception as e:
|
| 215 |
continue
|
| 216 |
|
| 217 |
print(f"數據加載完成: 總數 {total_count}, 有效 {valid_count}, 空白 {empty_count}")
|
|
@@ -239,13 +187,39 @@ class CompleteDataLoader:
|
|
| 239 |
self.schema_data = {}
|
| 240 |
return False
|
| 241 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 242 |
# ==================== 主系統 ====================
|
| 243 |
class CompleteTextToSQLSystem:
|
| 244 |
def __init__(self, hf_token: str):
|
| 245 |
self.hf_token = hf_token
|
| 246 |
self.data_loader = CompleteDataLoader(hf_token)
|
| 247 |
self.retrieval_system = RetrievalSystem()
|
| 248 |
-
|
| 249 |
self.initialize_system()
|
| 250 |
|
| 251 |
def initialize_system(self):
|
|
@@ -255,20 +229,93 @@ class CompleteTextToSQLSystem:
|
|
| 255 |
self.data_loader.load_complete_dataset()
|
| 256 |
self.data_loader.load_schema()
|
| 257 |
|
| 258 |
-
#
|
| 259 |
if self.data_loader.questions:
|
| 260 |
self.retrieval_system.compute_embeddings(self.data_loader.questions)
|
| 261 |
|
| 262 |
print(f"系統初始化完成,載入問題總數: {len(self.data_loader.questions)}")
|
| 263 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
def generate_sql(self, user_question: str) -> Tuple[str, str]:
|
| 265 |
-
"""
|
| 266 |
log_messages = [f"⏰ {get_current_time()} 開始處理"]
|
| 267 |
|
| 268 |
-
if not user_question or user_question.strip()
|
| 269 |
return "請輸入您的問題。", "錯誤: 問題為空"
|
| 270 |
|
| 271 |
-
# 1.
|
| 272 |
if self.data_loader.questions:
|
| 273 |
hits = self.retrieval_system.retrieve_similar(user_question)
|
| 274 |
|
|
@@ -278,119 +325,92 @@ class CompleteTextToSQLSystem:
|
|
| 278 |
corpus_id = best_hit['corpus_id']
|
| 279 |
similar_question = self.data_loader.questions[corpus_id]
|
| 280 |
original_sql = self.data_loader.sql_answers[corpus_id]
|
| 281 |
-
sql_quality = self.data_loader.sql_quality[corpus_id]
|
| 282 |
|
| 283 |
-
log_messages.append(f"🔍
|
| 284 |
-
log_messages.append(f"📊 相似度: {similarity_score:.3f}
|
| 285 |
|
| 286 |
if similarity_score > SIMILARITY_THRESHOLD:
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
if validation["empty"] or not validation["valid"]:
|
| 291 |
-
log_messages.append(f"⚠️ 原始SQL需要修復: {', '.join(validation['issues'])}")
|
| 292 |
-
log_messages.append("🛠️ 正在智能修復SQL...")
|
| 293 |
-
|
| 294 |
-
repaired_sql = repair_empty_sql(original_sql, user_question, similar_question)
|
| 295 |
-
log_messages.append("✅ 修復完成")
|
| 296 |
-
|
| 297 |
-
return repaired_sql, "\n".join(log_messages)
|
| 298 |
-
else:
|
| 299 |
-
log_messages.append(f"✅ 相似度 > {SIMILARITY_THRESHOLD},使用預先SQL")
|
| 300 |
-
return original_sql, "\n".join(log_messages)
|
| 301 |
else:
|
| 302 |
-
log_messages.append(f"ℹ️
|
| 303 |
-
|
| 304 |
# 2. 如果檢索失敗或相似度不足,智能生成SQL
|
| 305 |
-
log_messages.append("🤖
|
| 306 |
analysis = analyze_question_type(user_question)
|
| 307 |
-
intelligent_sql = generate_sql_from_question(user_question, analysis)
|
| 308 |
|
| 309 |
-
log_messages.append(f"📋 問題分析: {analysis['type']}
|
| 310 |
-
log_messages.append("✅
|
| 311 |
|
| 312 |
return intelligent_sql, "\n".join(log_messages)
|
| 313 |
|
| 314 |
-
# ==================== 其他類定義 ====================
|
| 315 |
-
class LLMClient:
|
| 316 |
-
def __init__(self, hf_token: str):
|
| 317 |
-
self.hf_token = hf_token
|
| 318 |
-
|
| 319 |
-
def call_llm_api(self, prompt: str) -> Optional[str]:
|
| 320 |
-
headers = {"Authorization": f"Bearer {self.hf_token}"}
|
| 321 |
-
payload = {"inputs": prompt, "parameters": {"max_new_tokens": 200, "temperature": 0.1}}
|
| 322 |
-
|
| 323 |
-
for model_url in LLM_MODELS:
|
| 324 |
-
try:
|
| 325 |
-
response = requests.post(model_url, headers=headers, json=payload, timeout=15)
|
| 326 |
-
if response.status_code == 200:
|
| 327 |
-
result = response.json()
|
| 328 |
-
if isinstance(result, list) and len(result) > 0:
|
| 329 |
-
return result[0]['generated_text'].strip()
|
| 330 |
-
except:
|
| 331 |
-
continue
|
| 332 |
-
return None
|
| 333 |
-
|
| 334 |
-
class RetrievalSystem:
|
| 335 |
-
def __init__(self):
|
| 336 |
-
self.embedder = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
|
| 337 |
-
self.question_embeddings = None
|
| 338 |
-
|
| 339 |
-
def compute_embeddings(self, questions: List[str]) -> None:
|
| 340 |
-
if questions:
|
| 341 |
-
print(f"正在為 {len(questions)} 個問題計算向量...")
|
| 342 |
-
self.question_embeddings = self.embedder.encode(questions, convert_to_tensor=True, show_progress_bar=False)
|
| 343 |
-
print("向量計算完成")
|
| 344 |
-
|
| 345 |
-
def retrieve_similar(self, user_question: str, top_k: int = 5) -> List[Dict]:
|
| 346 |
-
if self.question_embeddings is None or len(self.question_embeddings) == 0:
|
| 347 |
-
return []
|
| 348 |
-
try:
|
| 349 |
-
question_embedding = self.embedder.encode(user_question, convert_to_tensor=True)
|
| 350 |
-
hits = util.semantic_search(question_embedding, self.question_embeddings, top_k=top_k)
|
| 351 |
-
return hits[0] if hits and hits[0] else []
|
| 352 |
-
except Exception as e:
|
| 353 |
-
print(f"檢索錯誤: {e}")
|
| 354 |
-
return []
|
| 355 |
-
|
| 356 |
# ==================== 初始化系統 ====================
|
| 357 |
-
print("
|
| 358 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 359 |
|
| 360 |
# ==================== Gradio界面 ====================
|
| 361 |
-
def process_query(user_question: str) -> Tuple[str, str]:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 362 |
sql_result, log_message = text_to_sql_system.generate_sql(user_question)
|
| 363 |
return sql_result, "✅ SQL生成完成", log_message
|
| 364 |
|
| 365 |
-
with gr.Blocks(title="
|
| 366 |
-
gr.Markdown("# 🚀
|
| 367 |
-
gr.Markdown("📊
|
| 368 |
|
| 369 |
with gr.Row():
|
| 370 |
question_input = gr.Textbox(
|
| 371 |
-
label="📝
|
| 372 |
-
placeholder="例如:2023
|
| 373 |
-
lines=
|
| 374 |
scale=4
|
| 375 |
)
|
| 376 |
submit_btn = gr.Button("🚀 生成SQL", variant="primary", scale=1)
|
| 377 |
|
| 378 |
-
with gr.
|
| 379 |
sql_output = gr.Code(
|
| 380 |
-
label="📊 生成的SQL",
|
| 381 |
language="sql",
|
| 382 |
-
lines=
|
| 383 |
)
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 389 |
submit_btn.click(
|
| 390 |
process_query,
|
| 391 |
inputs=question_input,
|
| 392 |
-
outputs=[sql_output,
|
| 393 |
)
|
| 394 |
|
| 395 |
if __name__ == "__main__":
|
| 396 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
import numpy as np
|
| 13 |
|
| 14 |
# ==================== 配置區 ====================
|
| 15 |
+
HF_TOKEN = os.environ.get("HF_TOKEN", None) # 建議從環境變數讀取
|
| 16 |
DATASET_REPO_ID = "Paul720810/Text-to-SQL-Softline"
|
| 17 |
+
SIMILARITY_THRESHOLD = 0.6
|
| 18 |
|
| 19 |
+
# 多個備用LLM模型 (注意:在當前邏輯中並未使用)
|
| 20 |
LLM_MODELS = [
|
| 21 |
"https://api-inference.huggingface.co/models/gpt2",
|
| 22 |
+
"https://api-inference.huggingface.co/models/distilgpt2",
|
| 23 |
"https://api-inference.huggingface.co/models/microsoft/DialoGPT-small"
|
| 24 |
]
|
| 25 |
|
| 26 |
print("=" * 60)
|
| 27 |
print("🤖 智能 Text-to-SQL 系統啟動中...")
|
| 28 |
+
print(f"📊 模式: 讀取全部數據(來自 {DATASET_REPO_ID})")
|
| 29 |
print("=" * 60)
|
| 30 |
|
| 31 |
+
# ==================== 獨立工具函數 (不依賴類別實例) ====================
|
| 32 |
def get_current_time():
|
| 33 |
+
"""獲取當前時間字串"""
|
| 34 |
return datetime.now().strftime("%H:%M:%S")
|
| 35 |
|
| 36 |
def validate_sql(sql_query: str) -> Dict:
|
| 37 |
+
"""驗證SQL語句的語法和安全性"""
|
| 38 |
+
if not sql_query or not sql_query.strip():
|
| 39 |
return {"valid": False, "issues": ["SQL語句為空"], "is_safe": False, "empty": True}
|
| 40 |
|
| 41 |
sql_clean = sql_query.strip()
|
|
|
|
| 48 |
# 檢查危險操作
|
| 49 |
dangerous_keywords = ['DROP', 'DELETE', 'INSERT', 'UPDATE', 'ALTER', 'TRUNCATE', 'EXEC', 'EXECUTE']
|
| 50 |
for keyword in dangerous_keywords:
|
| 51 |
+
if f" {keyword} " in f" {sql_upper} ":
|
| 52 |
security_issues.append(f"危險操作: {keyword}")
|
| 53 |
|
| 54 |
# 檢查基本語法
|
|
|
|
| 58 |
if "FROM" not in sql_upper:
|
| 59 |
security_issues.append("缺少FROM")
|
| 60 |
|
| 61 |
+
is_valid = not security_issues
|
| 62 |
+
is_safe = all('危險' not in issue for issue in security_issues)
|
| 63 |
+
|
| 64 |
return {
|
| 65 |
+
"valid": is_valid,
|
| 66 |
"issues": security_issues,
|
| 67 |
+
"is_safe": is_safe,
|
| 68 |
"empty": False
|
| 69 |
}
|
| 70 |
|
| 71 |
def analyze_question_type(question: str) -> Dict:
|
| 72 |
+
"""分析問題類型和關鍵詞"""
|
| 73 |
question_lower = question.lower()
|
| 74 |
|
| 75 |
analysis = {
|
|
|
|
| 83 |
|
| 84 |
# 檢測關鍵詞
|
| 85 |
keywords_sets = {
|
| 86 |
+
"sales": ["銷售", "業績", "金額", "收入", "sale", "revenue"],
|
| 87 |
+
"customer": ["客戶", "買家", "用戶", "customer", "client"],
|
| 88 |
+
"product": ["產品", "商品", "項目", "product", "item"],
|
| 89 |
+
"time": ["時間", "日期", "月份", "年", "月", "最近", "date", "month", "year"],
|
| 90 |
+
"report": ["報告", "完成", "份", "report", "complete"],
|
| 91 |
+
"count": ["多少", "幾個", "數量", "count", "how many"],
|
| 92 |
+
"comparison": ["比較", "vs", " versus", "對比", "相比"]
|
| 93 |
}
|
| 94 |
|
| 95 |
for category, keywords in keywords_sets.items():
|
| 96 |
+
if any(keyword in question_lower for keyword in keywords):
|
| 97 |
+
analysis["keywords"].append(category)
|
| 98 |
+
|
|
|
|
|
|
|
| 99 |
# 特殊檢測
|
| 100 |
+
analysis["has_count"] = "count" in analysis["keywords"]
|
| 101 |
+
analysis["has_date"] = "time" in analysis["keywords"]
|
| 102 |
analysis["has_group"] = any(word in question_lower for word in ["每", "各", "group", "每個"])
|
| 103 |
+
analysis["has_comparison"] = "comparison" in analysis["keywords"]
|
| 104 |
|
| 105 |
# 確定主要類型
|
| 106 |
if analysis["keywords"]:
|
|
|
|
| 108 |
|
| 109 |
return analysis
|
| 110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
# ==================== 完整數據加載模塊 ====================
|
| 112 |
class CompleteDataLoader:
|
| 113 |
def __init__(self, hf_token: str):
|
|
|
|
| 124 |
raw_dataset = load_dataset(DATASET_REPO_ID, token=self.hf_token)['train']
|
| 125 |
|
| 126 |
print("解析全部 messages 格式...")
|
| 127 |
+
total_count, empty_count, valid_count = 0, 0, 0
|
|
|
|
|
|
|
| 128 |
|
| 129 |
for item in raw_dataset:
|
| 130 |
try:
|
|
|
|
| 140 |
sql_match = re.search(r'SQL查詢:\s*(.*?)(?:\n|$)', assistant_content, re.DOTALL)
|
| 141 |
if sql_match:
|
| 142 |
sql_query = sql_match.group(1).strip()
|
| 143 |
+
sql_query = re.sub(r'^sql\s*', '', sql_query, flags=re.IGNORECASE)
|
| 144 |
sql_query = re.sub(r'```sql|```', '', sql_query).strip()
|
| 145 |
else:
|
| 146 |
sql_query = assistant_content
|
|
|
|
| 159 |
empty_count += 1
|
| 160 |
if validation["valid"]:
|
| 161 |
valid_count += 1
|
| 162 |
+
except Exception:
|
|
|
|
| 163 |
continue
|
| 164 |
|
| 165 |
print(f"數據加載完成: 總數 {total_count}, 有效 {valid_count}, 空白 {empty_count}")
|
|
|
|
| 187 |
self.schema_data = {}
|
| 188 |
return False
|
| 189 |
|
| 190 |
+
# ==================== 檢索系統 ====================
|
| 191 |
+
class RetrievalSystem:
|
| 192 |
+
def __init__(self):
|
| 193 |
+
try:
|
| 194 |
+
self.embedder = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
|
| 195 |
+
self.question_embeddings = None
|
| 196 |
+
except Exception as e:
|
| 197 |
+
print(f"SentenceTransformer 模型加載���敗: {e}")
|
| 198 |
+
self.embedder = None
|
| 199 |
+
|
| 200 |
+
def compute_embeddings(self, questions: List[str]) -> None:
|
| 201 |
+
if self.embedder and questions:
|
| 202 |
+
print(f"正在為 {len(questions)} 個問題計算向量...")
|
| 203 |
+
self.question_embeddings = self.embedder.encode(questions, convert_to_tensor=True, show_progress_bar=True)
|
| 204 |
+
print("向量計算完成")
|
| 205 |
+
|
| 206 |
+
def retrieve_similar(self, user_question: str, top_k: int = 5) -> List[Dict]:
|
| 207 |
+
if self.embedder is None or self.question_embeddings is None or len(self.question_embeddings) == 0:
|
| 208 |
+
return []
|
| 209 |
+
try:
|
| 210 |
+
question_embedding = self.embedder.encode(user_question, convert_to_tensor=True)
|
| 211 |
+
hits = util.semantic_search(question_embedding, self.question_embeddings, top_k=top_k)
|
| 212 |
+
return hits[0] if hits and hits[0] else []
|
| 213 |
+
except Exception as e:
|
| 214 |
+
print(f"檢索錯誤: {e}")
|
| 215 |
+
return []
|
| 216 |
+
|
| 217 |
# ==================== 主系統 ====================
|
| 218 |
class CompleteTextToSQLSystem:
|
| 219 |
def __init__(self, hf_token: str):
|
| 220 |
self.hf_token = hf_token
|
| 221 |
self.data_loader = CompleteDataLoader(hf_token)
|
| 222 |
self.retrieval_system = RetrievalSystem()
|
|
|
|
| 223 |
self.initialize_system()
|
| 224 |
|
| 225 |
def initialize_system(self):
|
|
|
|
| 229 |
self.data_loader.load_complete_dataset()
|
| 230 |
self.data_loader.load_schema()
|
| 231 |
|
| 232 |
+
# 為所有問題計算向量
|
| 233 |
if self.data_loader.questions:
|
| 234 |
self.retrieval_system.compute_embeddings(self.data_loader.questions)
|
| 235 |
|
| 236 |
print(f"系統初始化完成,載入問題總數: {len(self.data_loader.questions)}")
|
| 237 |
+
|
| 238 |
+
# ===== 輔助函數 (作為類別方法) =====
|
| 239 |
+
def get_available_tables(self) -> Dict:
|
| 240 |
+
"""從schema中獲取所有可用的表和欄位"""
|
| 241 |
+
if not self.data_loader.schema_data:
|
| 242 |
+
return {}
|
| 243 |
+
|
| 244 |
+
tables = {}
|
| 245 |
+
for table_name, columns_list in self.data_loader.schema_data.items():
|
| 246 |
+
if isinstance(columns_list, list):
|
| 247 |
+
column_names = [col["name"] for col in columns_list if "name" in col]
|
| 248 |
+
tables[table_name] = column_names
|
| 249 |
+
|
| 250 |
+
return tables
|
| 251 |
+
|
| 252 |
+
def extract_number(self, text: str, default: int = 10) -> int:
|
| 253 |
+
"""從文字中提取數字"""
|
| 254 |
+
numbers = re.findall(r'\d+', text)
|
| 255 |
+
return int(numbers[0]) if numbers else default
|
| 256 |
+
|
| 257 |
+
def generate_sql_from_question(self, question: str, analysis: Dict) -> str:
|
| 258 |
+
"""根據問題分析和真實Schema生成智能SQL"""
|
| 259 |
+
question_lower = question.lower()
|
| 260 |
+
available_tables = self.get_available_tables().keys()
|
| 261 |
+
|
| 262 |
+
# 1. 每月/每日完成數量 - 使用 JobTimeline 相關表
|
| 263 |
+
if any(kw in question_lower for kw in ["每月", "每日", "昨天", "完成"]) and analysis["has_count"]:
|
| 264 |
+
group_match = re.search(r'([a-z]組)', question_lower)
|
| 265 |
+
if group_match:
|
| 266 |
+
group = group_match.group(1).replace('組', '').upper()
|
| 267 |
+
group_mapping = {'A': 'TA', 'B': 'TB', 'C': 'TC', 'D': 'TD'}
|
| 268 |
+
table_suffix = group_mapping.get(group, 'TA')
|
| 269 |
+
table_name = f"JobTimeline_{table_suffix}"
|
| 270 |
+
|
| 271 |
+
if "昨天" in question_lower:
|
| 272 |
+
return f"SELECT COUNT(*) as 完成數量 FROM {table_name} WHERE DATE(end_time) = DATE('now','-1 day');"
|
| 273 |
+
elif "每月" in question_lower:
|
| 274 |
+
year_match = re.search(r'(\d{4})年?', question_lower)
|
| 275 |
+
year = year_match.group(1) if year_match else datetime.now().strftime('%Y')
|
| 276 |
+
return f"""SELECT strftime('%Y-%m', end_time) as 月份, COUNT(*) as 完成數量 FROM {table_name} WHERE strftime('%Y', end_time) = '{year}' AND end_time IS NOT NULL GROUP BY strftime('%Y-%m', end_time) ORDER BY 月份;"""
|
| 277 |
+
return "SELECT strftime('%Y-%m', jt.end_time) as 月份, COUNT(*) as 完成數量 FROM JobTimeline jt WHERE jt.end_time IS NOT NULL GROUP BY strftime('%Y-%m', jt.end_time) ORDER BY 月份;"
|
| 278 |
+
|
| 279 |
+
# 2. 評級分析 - 使用 TSR53SampleDescription.OverallRating
|
| 280 |
+
elif any(kw in question_lower for kw in ["評級", "rating", "等級"]) and "TSR53SampleDescription" in available_tables:
|
| 281 |
+
if any(kw in question_lower for kw in ["分佈", "統計", "多少"]):
|
| 282 |
+
return "SELECT OverallRating as 評級, COUNT(*) as 數量, ROUND(COUNT(*) * 100.0 / (SELECT COUNT(*) FROM TSR53SampleDescription), 2) as 百分比 FROM TSR53SampleDescription WHERE OverallRating IS NOT NULL GROUP BY OverallRating ORDER BY 數量 DESC;"
|
| 283 |
+
elif "fail" in question_lower or "失敗" in question_lower:
|
| 284 |
+
return "SELECT JobNo as 工作單號, ApplicantName as 申請方, OverallRating as 評級 FROM TSR53SampleDescription WHERE OverallRating = 'Fail' ORDER BY JobNo;"
|
| 285 |
+
|
| 286 |
+
# 3. 金額相���查詢 - 使用 TSR53Invoice
|
| 287 |
+
elif any(kw in question_lower for kw in ["金額", "總額", "收入", "invoice"]) and any(kw in question_lower for kw in ["最高", "最大", "top"]):
|
| 288 |
+
limit_num = self.extract_number(question_lower, default=10)
|
| 289 |
+
return f"""WITH JobTotalAmount AS (SELECT JobNo, SUM(LocalAmount) AS TotalAmount FROM (SELECT DISTINCT JobNo, InvoiceCreditNoteNo, LocalAmount FROM TSR53Invoice WHERE LocalAmount IS NOT NULL) GROUP BY JobNo) SELECT jta.JobNo as 工作單號, sd.ApplicantName as 申請方, jta.TotalAmount as 總金額 FROM JobTotalAmount jta JOIN TSR53SampleDescription sd ON sd.JobNo = jta.JobNo ORDER BY jta.TotalAmount DESC LIMIT {limit_num};"""
|
| 290 |
+
|
| 291 |
+
# 4. 公司/客戶相關查詢
|
| 292 |
+
elif any(kw in question_lower for kw in ["公司", "客戶", "申請方", "付款方"]):
|
| 293 |
+
if any(kw in question_lower for kw in ["最多", "top", "排名"]):
|
| 294 |
+
return "SELECT ApplicantName as 申請方名稱, COUNT(*) as 工作單數量 FROM TSR53SampleDescription WHERE ApplicantName IS NOT NULL GROUP BY ApplicantName ORDER BY 工作單數量 DESC LIMIT 10;"
|
| 295 |
+
return "SELECT ApplicantName as 申請方, InvoiceToName as 付款方, COUNT(*) as 工作單數量 FROM TSR53SampleDescription WHERE ApplicantName IS NOT NULL GROUP BY ApplicantName, InvoiceToName ORDER BY 工作單數量 DESC;"
|
| 296 |
+
|
| 297 |
+
# ... 其他規則可以繼續添加 ...
|
| 298 |
+
|
| 299 |
+
# 預設查詢 - 顯示基本工作單資訊
|
| 300 |
+
return "SELECT JobNo as 工作單號, ApplicantName as 申請方, InvoiceToName as 付款方, OverallRating as 評級 FROM TSR53SampleDescription LIMIT 20;"
|
| 301 |
+
|
| 302 |
+
def repair_empty_sql(self, original_sql: str, user_question: str, similar_question: str) -> str:
|
| 303 |
+
"""修復空白或無效的SQL"""
|
| 304 |
+
validation = validate_sql(original_sql)
|
| 305 |
+
if not validation["valid"]:
|
| 306 |
+
analysis = analyze_question_type(user_question)
|
| 307 |
+
repaired_sql = self.generate_sql_from_question(user_question, analysis)
|
| 308 |
+
return f"-- 根據類似問題 '{similar_question}' (原SQL無效) 自動生成的查詢\n{repaired_sql}"
|
| 309 |
+
return original_sql
|
| 310 |
+
|
| 311 |
def generate_sql(self, user_question: str) -> Tuple[str, str]:
|
| 312 |
+
"""主流程:生成SQL查詢"""
|
| 313 |
log_messages = [f"⏰ {get_current_time()} 開始處理"]
|
| 314 |
|
| 315 |
+
if not user_question or not user_question.strip():
|
| 316 |
return "請輸入您的問題。", "錯誤: 問題為空"
|
| 317 |
|
| 318 |
+
# 1. 檢索最相似的問題
|
| 319 |
if self.data_loader.questions:
|
| 320 |
hits = self.retrieval_system.retrieve_similar(user_question)
|
| 321 |
|
|
|
|
| 325 |
corpus_id = best_hit['corpus_id']
|
| 326 |
similar_question = self.data_loader.questions[corpus_id]
|
| 327 |
original_sql = self.data_loader.sql_answers[corpus_id]
|
|
|
|
| 328 |
|
| 329 |
+
log_messages.append(f"🔍 檢索到最相似問題: '{similar_question}'")
|
| 330 |
+
log_messages.append(f"📊 相似度: {similarity_score:.3f}")
|
| 331 |
|
| 332 |
if similarity_score > SIMILARITY_THRESHOLD:
|
| 333 |
+
repaired_sql = self.repair_empty_sql(original_sql, user_question, similar_question)
|
| 334 |
+
log_messages.append(f"✅ 相似度高於閾值 {SIMILARITY_THRESHOLD},採用檢索結果。")
|
| 335 |
+
return repaired_sql, "\n".join(log_messages)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 336 |
else:
|
| 337 |
+
log_messages.append(f"ℹ️ 相似度低於閾值 {SIMILARITY_THRESHOLD},轉為智能生成���")
|
| 338 |
+
|
| 339 |
# 2. 如果檢索失敗或相似度不足,智能生成SQL
|
| 340 |
+
log_messages.append("🤖 找不到高相似度結果,啟用智能生成規則...")
|
| 341 |
analysis = analyze_question_type(user_question)
|
| 342 |
+
intelligent_sql = self.generate_sql_from_question(user_question, analysis)
|
| 343 |
|
| 344 |
+
log_messages.append(f"📋 問題分析: {analysis['type']} 類型, 關鍵詞: {analysis['keywords']}")
|
| 345 |
+
log_messages.append("✅ 智能生成完成。")
|
| 346 |
|
| 347 |
return intelligent_sql, "\n".join(log_messages)
|
| 348 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 349 |
# ==================== 初始化系統 ====================
|
| 350 |
+
print("準備初始化 Text-to-SQL 系統...")
|
| 351 |
+
# 檢查 HF_TOKEN 是否存在
|
| 352 |
+
if HF_TOKEN is None:
|
| 353 |
+
print("\n" + "="*60)
|
| 354 |
+
print("⚠️ 警告: Hugging Face Token 未設置。")
|
| 355 |
+
print("請在環境變數中設定 HF_TOKEN 才能從私人數據集下載資料。")
|
| 356 |
+
print("="*60 + "\n")
|
| 357 |
+
# 這裡可以選擇退出或繼續,但下載會失敗
|
| 358 |
+
text_to_sql_system = None
|
| 359 |
+
else:
|
| 360 |
+
text_to_sql_system = CompleteTextToSQLSystem(HF_TOKEN)
|
| 361 |
|
| 362 |
# ==================== Gradio界面 ====================
|
| 363 |
+
def process_query(user_question: str) -> Tuple[str, str, str]:
|
| 364 |
+
if text_to_sql_system is None:
|
| 365 |
+
error_msg = "系統因缺少 Hugging Face Token 而未成功初始化。"
|
| 366 |
+
return "系統未初始化", error_msg, error_msg
|
| 367 |
+
|
| 368 |
sql_result, log_message = text_to_sql_system.generate_sql(user_question)
|
| 369 |
return sql_result, "✅ SQL生成完成", log_message
|
| 370 |
|
| 371 |
+
with gr.Blocks(title="智慧Text-to-SQL系統", theme=gr.themes.Soft()) as demo:
|
| 372 |
+
gr.Markdown("# 🚀 智慧 Text-to-SQL 系統")
|
| 373 |
+
gr.Markdown("📊 **模式**: 讀取雲端數據集並結合「檢索」與「規則生成」兩種模式。")
|
| 374 |
|
| 375 |
with gr.Row():
|
| 376 |
question_input = gr.Textbox(
|
| 377 |
+
label="📝 請在此輸入您的問題",
|
| 378 |
+
placeholder="例如:2023年每月完成多少份報告? 或 哪個客戶的訂單總金額最高?",
|
| 379 |
+
lines=3,
|
| 380 |
scale=4
|
| 381 |
)
|
| 382 |
submit_btn = gr.Button("🚀 生成SQL", variant="primary", scale=1)
|
| 383 |
|
| 384 |
+
with gr.Accordion("🔍 結果與日誌", open=True):
|
| 385 |
sql_output = gr.Code(
|
| 386 |
+
label="📊 生成的SQL查詢",
|
| 387 |
language="sql",
|
| 388 |
+
lines=8
|
| 389 |
)
|
| 390 |
+
status_output = gr.Textbox(label="🔍 執行狀態", interactive=False)
|
| 391 |
+
log_output = gr.Textbox(label="📋 詳細日誌", lines=5, interactive=False)
|
| 392 |
+
|
| 393 |
+
# 預設範例
|
| 394 |
+
gr.Examples(
|
| 395 |
+
examples=[
|
| 396 |
+
"昨天完成了多少個工作單?",
|
| 397 |
+
"A組每月完成數量是多少?",
|
| 398 |
+
"哪個申請方的失敗評級最多?",
|
| 399 |
+
"找出總金額最高的10筆訂單",
|
| 400 |
+
"統計所有評級的分佈"
|
| 401 |
+
],
|
| 402 |
+
inputs=question_input
|
| 403 |
+
)
|
| 404 |
+
|
| 405 |
submit_btn.click(
|
| 406 |
process_query,
|
| 407 |
inputs=question_input,
|
| 408 |
+
outputs=[sql_output, status_output, log_output]
|
| 409 |
)
|
| 410 |
|
| 411 |
if __name__ == "__main__":
|
| 412 |
+
print("Gradio 介面啟動中...")
|
| 413 |
+
if text_to_sql_system is None:
|
| 414 |
+
print("無法啟動 Gradio,因為系統初始化失敗。")
|
| 415 |
+
else:
|
| 416 |
+
demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
|