Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -15,53 +15,40 @@ from typing import List, Dict, Tuple, Optional
|
|
| 15 |
# ==================== 配置區 ====================
|
| 16 |
HF_TOKEN = os.environ.get("HF_TOKEN", "您的_HuggingFace_Token")
|
| 17 |
DATASET_REPO_ID = "Paul720810/Text-to-SQL-Softline"
|
| 18 |
-
SIMILARITY_THRESHOLD = 0.75
|
| 19 |
|
| 20 |
-
# 多個備用LLM
|
| 21 |
LLM_MODELS = [
|
| 22 |
"https://api-inference.huggingface.co/models/gpt2",
|
| 23 |
"https://api-inference.huggingface.co/models/distilgpt2",
|
| 24 |
"https://api-inference.huggingface.co/models/microsoft/DialoGPT-small"
|
| 25 |
]
|
| 26 |
|
| 27 |
-
# 數據庫連接配置(可選)
|
| 28 |
-
DB_CONFIG = {
|
| 29 |
-
"enabled": False, # 設置為True啟用真實數據庫連接
|
| 30 |
-
"path": "您的數據庫路徑.db",
|
| 31 |
-
"test_queries": True # 是否啟用SQL測試功能
|
| 32 |
-
}
|
| 33 |
-
|
| 34 |
print("=" * 50)
|
| 35 |
print("🚀 智能 Text-to-SQL 系統啟動中...")
|
| 36 |
print("=" * 50)
|
| 37 |
|
| 38 |
# ==================== 工具函數 ====================
|
| 39 |
def get_current_time():
|
| 40 |
-
"""獲取當前時間字符串"""
|
| 41 |
return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 42 |
|
| 43 |
-
def safe_json_load(data, default=None):
|
| 44 |
-
"""安全的JSON解析"""
|
| 45 |
-
try:
|
| 46 |
-
return json.loads(data) if isinstance(data, str) else data
|
| 47 |
-
except (json.JSONDecodeError, TypeError):
|
| 48 |
-
return default
|
| 49 |
-
|
| 50 |
def validate_sql(sql_query: str) -> Dict:
|
| 51 |
"""驗證SQL語句的安全性"""
|
| 52 |
security_issues = []
|
| 53 |
|
| 54 |
# 檢查危險操作
|
| 55 |
dangerous_keywords = ['DROP', 'DELETE', 'INSERT', 'UPDATE', 'ALTER', 'TRUNCATE', 'EXEC', 'EXECUTE']
|
|
|
|
|
|
|
| 56 |
for keyword in dangerous_keywords:
|
| 57 |
-
if f" {keyword} " in
|
| 58 |
security_issues.append(f"發現危險操作: {keyword}")
|
| 59 |
|
| 60 |
# 檢查基本語法
|
| 61 |
-
if "SELECT" not in
|
| 62 |
security_issues.append("缺少SELECT語句")
|
| 63 |
|
| 64 |
-
if "FROM" not in
|
| 65 |
security_issues.append("缺少FROM子句")
|
| 66 |
|
| 67 |
return {
|
|
@@ -70,30 +57,28 @@ def validate_sql(sql_query: str) -> Dict:
|
|
| 70 |
"is_safe": len([i for i in security_issues if '危險' in i]) == 0
|
| 71 |
}
|
| 72 |
|
| 73 |
-
def
|
| 74 |
-
"""
|
| 75 |
-
if not
|
| 76 |
-
return
|
| 77 |
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
except Exception as e:
|
| 96 |
-
return False, f"❌ SQL執行錯誤: {str(e)}"
|
| 97 |
|
| 98 |
# ==================== 數據加載模塊 ====================
|
| 99 |
class DataLoader:
|
|
@@ -126,8 +111,14 @@ class DataLoader:
|
|
| 126 |
sql_query = sql_match.group(1).strip()
|
| 127 |
sql_query = re.sub(r'^sql\s*', '', sql_query)
|
| 128 |
sql_query = re.sub(r'```sql|```', '', sql_query).strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
else:
|
| 130 |
-
sql_query = assistant_content
|
| 131 |
|
| 132 |
self.questions.append(question)
|
| 133 |
self.sql_answers.append(sql_query)
|
|
@@ -154,7 +145,7 @@ class DataLoader:
|
|
| 154 |
token=self.hf_token
|
| 155 |
)
|
| 156 |
with open(schema_file_path, 'r', encoding='utf-8') as f:
|
| 157 |
-
self.schema_data =
|
| 158 |
print("Schema加載成功")
|
| 159 |
return True
|
| 160 |
except Exception as e:
|
|
@@ -184,7 +175,7 @@ class LLMClient:
|
|
| 184 |
self.hf_token = hf_token
|
| 185 |
|
| 186 |
def call_llm_api(self, prompt: str, model_urls: List[str] = LLM_MODELS) -> Optional[str]:
|
| 187 |
-
"""調用LLM API
|
| 188 |
headers = {"Authorization": f"Bearer {self.hf_token}"}
|
| 189 |
payload = {
|
| 190 |
"inputs": prompt,
|
|
@@ -203,12 +194,10 @@ class LLMClient:
|
|
| 203 |
result = response.json()
|
| 204 |
if isinstance(result, list) and len(result) > 0:
|
| 205 |
generated_text = result[0]['generated_text'].strip()
|
| 206 |
-
# 清理輸出
|
| 207 |
generated_text = re.sub(r'^```sql|```$', '', generated_text).strip()
|
| 208 |
return generated_text
|
| 209 |
|
| 210 |
except Exception as e:
|
| 211 |
-
print(f"模型 {model_url} 調用失敗: {e}")
|
| 212 |
continue
|
| 213 |
|
| 214 |
return None
|
|
@@ -249,25 +238,23 @@ class TextToSQLSystem:
|
|
| 249 |
self.llm_client = LLMClient(hf_token)
|
| 250 |
self.retrieval_system = RetrievalSystem()
|
| 251 |
|
| 252 |
-
# 初始化組件
|
| 253 |
self.initialize_system()
|
| 254 |
|
| 255 |
def initialize_system(self):
|
| 256 |
"""初始化系統組件"""
|
| 257 |
print("正在初始化系統組件...")
|
| 258 |
|
| 259 |
-
# 加載數據
|
| 260 |
self.data_loader.load_dataset()
|
| 261 |
self.data_loader.load_schema()
|
| 262 |
-
|
| 263 |
-
# 初始化檢索系統
|
| 264 |
self.retrieval_system.compute_embeddings(self.data_loader.questions)
|
| 265 |
-
|
| 266 |
self.schema_context = self.data_loader.build_schema_context()
|
|
|
|
| 267 |
print("系統初始化完成")
|
|
|
|
|
|
|
| 268 |
|
| 269 |
def generate_sql(self, user_question: str) -> Tuple[str, str]:
|
| 270 |
-
"""生成SQL
|
| 271 |
log_messages = [f"🕒 開始處理: {get_current_time()}"]
|
| 272 |
|
| 273 |
if not user_question or user_question.strip() == "":
|
|
@@ -281,20 +268,23 @@ class TextToSQLSystem:
|
|
| 281 |
best_hit = hits[0]
|
| 282 |
similarity_score = best_hit['score']
|
| 283 |
similar_question = self.data_loader.questions[best_hit['corpus_id']]
|
|
|
|
| 284 |
|
| 285 |
log_messages.append(f"🔍 檢索到相似問題: '{similar_question}'")
|
| 286 |
log_messages.append(f"📊 相似度: {similarity_score:.3f}")
|
| 287 |
|
| 288 |
if similarity_score > SIMILARITY_THRESHOLD:
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
log_messages.append(f"
|
| 296 |
-
|
| 297 |
-
|
|
|
|
|
|
|
| 298 |
else:
|
| 299 |
log_messages.append(f"ℹ️ 相似度低於閾值 {SIMILARITY_THRESHOLD}")
|
| 300 |
|
|
@@ -305,17 +295,14 @@ class TextToSQLSystem:
|
|
| 305 |
generated_sql = self.llm_client.call_llm_api(prompt)
|
| 306 |
|
| 307 |
if generated_sql:
|
| 308 |
-
#
|
| 309 |
-
generated_sql = re.sub(r'^```sql|```$', '', generated_sql).strip()
|
| 310 |
validation = validate_sql(generated_sql)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 311 |
|
| 312 |
-
|
| 313 |
-
log_messages.append("✅ LLM生成成功")
|
| 314 |
-
if validation["issues"]:
|
| 315 |
-
log_messages.append(f"ℹ️ 驗證提示: {', '.join(validation['issues'])}")
|
| 316 |
-
else:
|
| 317 |
-
log_messages.append("⚠️ LLM生成可能存在问题")
|
| 318 |
-
|
| 319 |
return generated_sql, "\n".join(log_messages)
|
| 320 |
else:
|
| 321 |
# 3. 備用方案
|
|
@@ -334,7 +321,7 @@ class TextToSQLSystem:
|
|
| 334 |
|
| 335 |
要求:
|
| 336 |
1. 只輸出SQL語句
|
| 337 |
-
2.
|
| 338 |
3. 使用正確的語法
|
| 339 |
|
| 340 |
SQL查詢:"""
|
|
@@ -346,11 +333,9 @@ SQL查詢:"""
|
|
| 346 |
if any(kw in user_question_lower for kw in ['銷售', '業績', '金額', '收入']):
|
| 347 |
return "SELECT product_name, SUM(sales_amount) as total_sales FROM sales GROUP BY product_name ORDER BY total_sales DESC LIMIT 10;"
|
| 348 |
elif any(kw in user_question_lower for kw in ['客戶', '買家', '用戶']):
|
| 349 |
-
return "SELECT customer_name, COUNT(*) as order_count
|
| 350 |
elif any(kw in user_question_lower for kw in ['時間', '日期', '最近', '月份']):
|
| 351 |
-
return "SELECT strftime('%Y-%m', order_date) as month, COUNT(*) as orders
|
| 352 |
-
elif any(kw in user_question_lower for kw in ['產品', '商品', '項目']):
|
| 353 |
-
return "SELECT product_name, category, stock_quantity, price FROM products WHERE stock_quantity > 0 ORDER BY price DESC;"
|
| 354 |
else:
|
| 355 |
return "SELECT '請重試或提供更詳細的問題' AS status;"
|
| 356 |
|
|
@@ -359,70 +344,52 @@ print("正在初始化Text-to-SQL系統...")
|
|
| 359 |
text_to_sql_system = TextToSQLSystem(HF_TOKEN)
|
| 360 |
|
| 361 |
# ==================== Gradio界面 ====================
|
| 362 |
-
def process_query(user_question: str
|
| 363 |
"""處理用戶查詢"""
|
| 364 |
sql_result, log_message = text_to_sql_system.generate_sql(user_question)
|
| 365 |
|
| 366 |
-
#
|
|
|
|
| 367 |
debug_info = ""
|
| 368 |
-
validation = validate_sql(sql_result)
|
| 369 |
|
| 370 |
-
if not
|
| 371 |
-
debug_info = "❌ SQL驗證失敗:\n" + "\n".join(
|
|
|
|
| 372 |
else:
|
| 373 |
-
debug_info = "✅ SQL
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
debug_info += "\nℹ️ 提示: " + ", ".join(validation["issues"])
|
| 377 |
-
|
| 378 |
-
# 如果啟用測試功能
|
| 379 |
-
if test_query and DB_CONFIG["test_queries"]:
|
| 380 |
-
success, test_result = execute_test_query(sql_result)
|
| 381 |
-
debug_info += f"\n\n🔧 測試結果:\n{test_result}"
|
| 382 |
|
| 383 |
return sql_result, debug_info, log_message
|
| 384 |
|
| 385 |
# 創建界面
|
| 386 |
-
with gr.Blocks(
|
| 387 |
-
title="智能Text-to-SQL系統",
|
| 388 |
-
theme=gr.themes.Soft(),
|
| 389 |
-
css="""
|
| 390 |
-
.gradio-container { max-width: 1000px; margin: 0 auto; }
|
| 391 |
-
.success { color: green; }
|
| 392 |
-
.warning { color: orange; }
|
| 393 |
-
.error { color: red; }
|
| 394 |
-
"""
|
| 395 |
-
) as demo:
|
| 396 |
|
| 397 |
gr.Markdown("# 🚀 智能 Text-to-SQL 系統")
|
| 398 |
gr.Markdown("輸入自然語言問題,自動生成並驗證SQL查詢")
|
| 399 |
|
| 400 |
with gr.Row():
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
submit_btn = gr.Button("🚀 生成SQL", variant="primary")
|
| 411 |
-
test_btn = gr.Button("🔧 測試SQL", variant="secondary")
|
| 412 |
-
clear_btn = gr.Button("🗑️ 清除", variant="secondary")
|
| 413 |
|
| 414 |
with gr.Row():
|
| 415 |
sql_output = gr.Code(
|
| 416 |
label="📊 生成的SQL",
|
| 417 |
language="sql",
|
| 418 |
-
lines=6
|
| 419 |
-
interactive=True
|
| 420 |
)
|
| 421 |
|
| 422 |
with gr.Row():
|
| 423 |
debug_output = gr.Textbox(
|
| 424 |
-
label="🔍 SQL
|
| 425 |
-
lines=
|
| 426 |
interactive=False
|
| 427 |
)
|
| 428 |
|
|
@@ -433,28 +400,9 @@ with gr.Blocks(
|
|
| 433 |
interactive=False
|
| 434 |
)
|
| 435 |
|
| 436 |
-
# 示例問題
|
| 437 |
-
gr.Examples(
|
| 438 |
-
examples=[
|
| 439 |
-
"2024年銷售額最高的5個產品",
|
| 440 |
-
"最近30天每個客戶的訂單數量",
|
| 441 |
-
"庫存不足的商品列表",
|
| 442 |
-
"比較2023年和2024年的月度銷售額",
|
| 443 |
-
"付款不及時的客戶統計"
|
| 444 |
-
],
|
| 445 |
-
inputs=question_input,
|
| 446 |
-
label="💡 示例問題"
|
| 447 |
-
)
|
| 448 |
-
|
| 449 |
# 事件處理
|
| 450 |
submit_btn.click(
|
| 451 |
-
fn=
|
| 452 |
-
inputs=question_input,
|
| 453 |
-
outputs=[sql_output, debug_output, log_output]
|
| 454 |
-
)
|
| 455 |
-
|
| 456 |
-
test_btn.click(
|
| 457 |
-
fn=lambda q: process_query(q, True),
|
| 458 |
inputs=question_input,
|
| 459 |
outputs=[sql_output, debug_output, log_output]
|
| 460 |
)
|
|
@@ -470,12 +418,6 @@ if __name__ == "__main__":
|
|
| 470 |
print("=" * 50)
|
| 471 |
print("🌐 啟動Gradio Web界面...")
|
| 472 |
print("📍 本地訪問: http://localhost:7860")
|
| 473 |
-
print("🔄 如果需要公網訪問,設置 share=True")
|
| 474 |
print("=" * 50)
|
| 475 |
|
| 476 |
-
demo.launch(
|
| 477 |
-
server_name="0.0.0.0",
|
| 478 |
-
server_port=7860,
|
| 479 |
-
share=False,
|
| 480 |
-
show_error=True
|
| 481 |
-
)
|
|
|
|
| 15 |
# ==================== 配置區 ====================
|
| 16 |
HF_TOKEN = os.environ.get("HF_TOKEN", "您的_HuggingFace_Token")
|
| 17 |
DATASET_REPO_ID = "Paul720810/Text-to-SQL-Softline"
|
| 18 |
+
SIMILARITY_THRESHOLD = 0.75
|
| 19 |
|
| 20 |
+
# 多個備用LLM模型
|
| 21 |
LLM_MODELS = [
|
| 22 |
"https://api-inference.huggingface.co/models/gpt2",
|
| 23 |
"https://api-inference.huggingface.co/models/distilgpt2",
|
| 24 |
"https://api-inference.huggingface.co/models/microsoft/DialoGPT-small"
|
| 25 |
]
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
print("=" * 50)
|
| 28 |
print("🚀 智能 Text-to-SQL 系統啟動中...")
|
| 29 |
print("=" * 50)
|
| 30 |
|
| 31 |
# ==================== 工具函數 ====================
|
| 32 |
def get_current_time():
|
|
|
|
| 33 |
return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
def validate_sql(sql_query: str) -> Dict:
|
| 36 |
"""驗證SQL語句的安全性"""
|
| 37 |
security_issues = []
|
| 38 |
|
| 39 |
# 檢查危險操作
|
| 40 |
dangerous_keywords = ['DROP', 'DELETE', 'INSERT', 'UPDATE', 'ALTER', 'TRUNCATE', 'EXEC', 'EXECUTE']
|
| 41 |
+
sql_upper = sql_query.upper()
|
| 42 |
+
|
| 43 |
for keyword in dangerous_keywords:
|
| 44 |
+
if f" {keyword} " in sql_upper:
|
| 45 |
security_issues.append(f"發現危險操作: {keyword}")
|
| 46 |
|
| 47 |
# 檢查基本語法
|
| 48 |
+
if "SELECT" not in sql_upper:
|
| 49 |
security_issues.append("缺少SELECT語句")
|
| 50 |
|
| 51 |
+
if "FROM" not in sql_upper:
|
| 52 |
security_issues.append("缺少FROM子句")
|
| 53 |
|
| 54 |
return {
|
|
|
|
| 57 |
"is_safe": len([i for i in security_issues if '危險' in i]) == 0
|
| 58 |
}
|
| 59 |
|
| 60 |
+
def repair_sql(sql_query: str) -> str:
|
| 61 |
+
"""修復有問題的SQL語句"""
|
| 62 |
+
if not sql_query or sql_query.strip() == "":
|
| 63 |
+
return "SELECT 'SQL語句為空' AS error;"
|
| 64 |
|
| 65 |
+
# 清理SQL
|
| 66 |
+
sql_clean = re.sub(r'^```sql|```$', '', sql_query).strip()
|
| 67 |
+
|
| 68 |
+
# 檢查是否已經是完整SQL
|
| 69 |
+
if "SELECT" in sql_clean.upper() and "FROM" in sql_clean.upper():
|
| 70 |
+
return sql_clean
|
| 71 |
+
|
| 72 |
+
# 如果只有SELECT部分
|
| 73 |
+
if "SELECT" in sql_clean.upper() and "FROM" not in sql_clean.upper():
|
| 74 |
+
return sql_clean + " FROM appropriate_table WHERE 1=1;"
|
| 75 |
+
|
| 76 |
+
# 如果只有FROM部分
|
| 77 |
+
if "FROM" in sql_clean.upper() and "SELECT" not in sql_clean.upper():
|
| 78 |
+
return "SELECT * " + sql_clean
|
| 79 |
+
|
| 80 |
+
# 如果什麼都沒有,提供默認查詢
|
| 81 |
+
return "SELECT '請檢查SQL語法' AS status;"
|
|
|
|
|
|
|
| 82 |
|
| 83 |
# ==================== 數據加載模塊 ====================
|
| 84 |
class DataLoader:
|
|
|
|
| 111 |
sql_query = sql_match.group(1).strip()
|
| 112 |
sql_query = re.sub(r'^sql\s*', '', sql_query)
|
| 113 |
sql_query = re.sub(r'```sql|```', '', sql_query).strip()
|
| 114 |
+
|
| 115 |
+
# 驗證並修復SQL
|
| 116 |
+
validation = validate_sql(sql_query)
|
| 117 |
+
if not validation["valid"]:
|
| 118 |
+
print(f"發現有問題的SQL,將進行修復: {sql_query}")
|
| 119 |
+
sql_query = repair_sql(sql_query)
|
| 120 |
else:
|
| 121 |
+
sql_query = repair_sql(assistant_content)
|
| 122 |
|
| 123 |
self.questions.append(question)
|
| 124 |
self.sql_answers.append(sql_query)
|
|
|
|
| 145 |
token=self.hf_token
|
| 146 |
)
|
| 147 |
with open(schema_file_path, 'r', encoding='utf-8') as f:
|
| 148 |
+
self.schema_data = json.load(f)
|
| 149 |
print("Schema加載成功")
|
| 150 |
return True
|
| 151 |
except Exception as e:
|
|
|
|
| 175 |
self.hf_token = hf_token
|
| 176 |
|
| 177 |
def call_llm_api(self, prompt: str, model_urls: List[str] = LLM_MODELS) -> Optional[str]:
|
| 178 |
+
"""調用LLM API"""
|
| 179 |
headers = {"Authorization": f"Bearer {self.hf_token}"}
|
| 180 |
payload = {
|
| 181 |
"inputs": prompt,
|
|
|
|
| 194 |
result = response.json()
|
| 195 |
if isinstance(result, list) and len(result) > 0:
|
| 196 |
generated_text = result[0]['generated_text'].strip()
|
|
|
|
| 197 |
generated_text = re.sub(r'^```sql|```$', '', generated_text).strip()
|
| 198 |
return generated_text
|
| 199 |
|
| 200 |
except Exception as e:
|
|
|
|
| 201 |
continue
|
| 202 |
|
| 203 |
return None
|
|
|
|
| 238 |
self.llm_client = LLMClient(hf_token)
|
| 239 |
self.retrieval_system = RetrievalSystem()
|
| 240 |
|
|
|
|
| 241 |
self.initialize_system()
|
| 242 |
|
| 243 |
def initialize_system(self):
|
| 244 |
"""初始化系統組件"""
|
| 245 |
print("正在初始化系統組件...")
|
| 246 |
|
|
|
|
| 247 |
self.data_loader.load_dataset()
|
| 248 |
self.data_loader.load_schema()
|
|
|
|
|
|
|
| 249 |
self.retrieval_system.compute_embeddings(self.data_loader.questions)
|
|
|
|
| 250 |
self.schema_context = self.data_loader.build_schema_context()
|
| 251 |
+
|
| 252 |
print("系統初始化完成")
|
| 253 |
+
print(f"可用問題數量: {len(self.data_loader.questions)}")
|
| 254 |
+
print(f"Schema表數量: {len(self.data_loader.schema_data)}")
|
| 255 |
|
| 256 |
def generate_sql(self, user_question: str) -> Tuple[str, str]:
|
| 257 |
+
"""生成SQL查詢"""
|
| 258 |
log_messages = [f"🕒 開始處理: {get_current_time()}"]
|
| 259 |
|
| 260 |
if not user_question or user_question.strip() == "":
|
|
|
|
| 268 |
best_hit = hits[0]
|
| 269 |
similarity_score = best_hit['score']
|
| 270 |
similar_question = self.data_loader.questions[best_hit['corpus_id']]
|
| 271 |
+
original_sql = self.data_loader.sql_answers[best_hit['corpus_id']]
|
| 272 |
|
| 273 |
log_messages.append(f"🔍 檢索到相似問題: '{similar_question}'")
|
| 274 |
log_messages.append(f"📊 相似度: {similarity_score:.3f}")
|
| 275 |
|
| 276 |
if similarity_score > SIMILARITY_THRESHOLD:
|
| 277 |
+
# 驗證並可能修復SQL
|
| 278 |
+
validation = validate_sql(original_sql)
|
| 279 |
+
if not validation["valid"]:
|
| 280 |
+
log_messages.append(f"⚠️ 原始SQL有問題: {', '.join(validation['issues'])}")
|
| 281 |
+
log_messages.append("🛠️ 正在修復SQL...")
|
| 282 |
+
repaired_sql = repair_sql(original_sql)
|
| 283 |
+
log_messages.append(f"✅ 修復完成")
|
| 284 |
+
return repaired_sql, "\n".join(log_messages)
|
| 285 |
+
else:
|
| 286 |
+
log_messages.append(f"✅ 相似度 > {SIMILARITY_THRESHOLD},直接返回")
|
| 287 |
+
return original_sql, "\n".join(log_messages)
|
| 288 |
else:
|
| 289 |
log_messages.append(f"ℹ️ 相似度低於閾值 {SIMILARITY_THRESHOLD}")
|
| 290 |
|
|
|
|
| 295 |
generated_sql = self.llm_client.call_llm_api(prompt)
|
| 296 |
|
| 297 |
if generated_sql:
|
| 298 |
+
# 驗證生成的SQL
|
|
|
|
| 299 |
validation = validate_sql(generated_sql)
|
| 300 |
+
if not validation["valid"]:
|
| 301 |
+
log_messages.append(f"⚠️ LLM生成的SQL有問題: {', '.join(validation['issues'])}")
|
| 302 |
+
log_messages.append("🛠️ 正在修復SQL...")
|
| 303 |
+
generated_sql = repair_sql(generated_sql)
|
| 304 |
|
| 305 |
+
log_messages.append("✅ SQL生成完成")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 306 |
return generated_sql, "\n".join(log_messages)
|
| 307 |
else:
|
| 308 |
# 3. 備用方案
|
|
|
|
| 321 |
|
| 322 |
要求:
|
| 323 |
1. 只輸出SQL語句
|
| 324 |
+
2. 必須包含SELECT和FROM
|
| 325 |
3. 使用正確的語法
|
| 326 |
|
| 327 |
SQL查詢:"""
|
|
|
|
| 333 |
if any(kw in user_question_lower for kw in ['銷售', '業績', '金額', '收入']):
|
| 334 |
return "SELECT product_name, SUM(sales_amount) as total_sales FROM sales GROUP BY product_name ORDER BY total_sales DESC LIMIT 10;"
|
| 335 |
elif any(kw in user_question_lower for kw in ['客戶', '買家', '用戶']):
|
| 336 |
+
return "SELECT customer_name, COUNT(*) as order_count FROM orders GROUP BY customer_name ORDER BY order_count DESC;"
|
| 337 |
elif any(kw in user_question_lower for kw in ['時間', '日期', '最近', '月份']):
|
| 338 |
+
return "SELECT strftime('%Y-%m', order_date) as month, COUNT(*) as orders FROM orders GROUP BY month ORDER BY month DESC;"
|
|
|
|
|
|
|
| 339 |
else:
|
| 340 |
return "SELECT '請重試或提供更詳細的問題' AS status;"
|
| 341 |
|
|
|
|
| 344 |
text_to_sql_system = TextToSQLSystem(HF_TOKEN)
|
| 345 |
|
| 346 |
# ==================== Gradio界面 ====================
|
| 347 |
+
def process_query(user_question: str) -> Tuple[str, str]:
|
| 348 |
"""處理用戶查詢"""
|
| 349 |
sql_result, log_message = text_to_sql_system.generate_sql(user_question)
|
| 350 |
|
| 351 |
+
# 最終驗證
|
| 352 |
+
final_validation = validate_sql(sql_result)
|
| 353 |
debug_info = ""
|
|
|
|
| 354 |
|
| 355 |
+
if not final_validation["valid"]:
|
| 356 |
+
debug_info = "❌ 最終SQL驗證失敗:\n" + "\n".join(final_validation["issues"])
|
| 357 |
+
debug_info += "\n🛠️ 已嘗試自動修復,但仍存在问题"
|
| 358 |
else:
|
| 359 |
+
debug_info = "✅ 最終SQL驗證通過"
|
| 360 |
+
if final_validation["issues"]:
|
| 361 |
+
debug_info += "\nℹ️ 提示: " + ", ".join(final_validation["issues"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 362 |
|
| 363 |
return sql_result, debug_info, log_message
|
| 364 |
|
| 365 |
# 創建界面
|
| 366 |
+
with gr.Blocks(title="智能Text-to-SQL系統", theme=gr.themes.Soft()) as demo:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 367 |
|
| 368 |
gr.Markdown("# 🚀 智能 Text-to-SQL 系統")
|
| 369 |
gr.Markdown("輸入自然語言問題,自動生成並驗證SQL查詢")
|
| 370 |
|
| 371 |
with gr.Row():
|
| 372 |
+
question_input = gr.Textbox(
|
| 373 |
+
label="📝 您的問題",
|
| 374 |
+
placeholder="例如:查詢2024年銷售額最高的產品",
|
| 375 |
+
lines=2
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
with gr.Row():
|
| 379 |
+
submit_btn = gr.Button("🚀 生成SQL", variant="primary")
|
| 380 |
+
clear_btn = gr.Button("🗑️ 清除", variant="secondary")
|
|
|
|
|
|
|
|
|
|
| 381 |
|
| 382 |
with gr.Row():
|
| 383 |
sql_output = gr.Code(
|
| 384 |
label="📊 生成的SQL",
|
| 385 |
language="sql",
|
| 386 |
+
lines=6
|
|
|
|
| 387 |
)
|
| 388 |
|
| 389 |
with gr.Row():
|
| 390 |
debug_output = gr.Textbox(
|
| 391 |
+
label="🔍 SQL驗證信息",
|
| 392 |
+
lines=3,
|
| 393 |
interactive=False
|
| 394 |
)
|
| 395 |
|
|
|
|
| 400 |
interactive=False
|
| 401 |
)
|
| 402 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 403 |
# 事件處理
|
| 404 |
submit_btn.click(
|
| 405 |
+
fn=process_query,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 406 |
inputs=question_input,
|
| 407 |
outputs=[sql_output, debug_output, log_output]
|
| 408 |
)
|
|
|
|
| 418 |
print("=" * 50)
|
| 419 |
print("🌐 啟動Gradio Web界面...")
|
| 420 |
print("📍 本地訪問: http://localhost:7860")
|
|
|
|
| 421 |
print("=" * 50)
|
| 422 |
|
| 423 |
+
demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|