Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -16,9 +16,15 @@ HF_TOKEN = os.environ.get("HF_TOKEN", None) # 建議從環境變數讀取
|
|
| 16 |
DATASET_REPO_ID = "Paul720810/Text-to-SQL-Softline"
|
| 17 |
SIMILARITY_THRESHOLD = 0.65 # 適度提高閾值,確保檢索到的問題意圖更一致
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
print("=" * 60)
|
| 20 |
print("🤖 智能 Text-to-SQL 系統啟動中...")
|
| 21 |
print(f"📊 模式: 讀取全部數據(來自 {DATASET_REPO_ID})")
|
|
|
|
|
|
|
| 22 |
print("=" * 60)
|
| 23 |
|
| 24 |
# ==================== 獨立工具函數 (不依賴類別實例) ====================
|
|
@@ -98,6 +104,7 @@ class CompleteDataLoader:
|
|
| 98 |
|
| 99 |
successful_loads = 0
|
| 100 |
total_items = len(raw_dataset)
|
|
|
|
| 101 |
|
| 102 |
for idx, item in enumerate(raw_dataset):
|
| 103 |
try:
|
|
@@ -105,22 +112,57 @@ class CompleteDataLoader:
|
|
| 105 |
user_content = item['messages'][0]['content']
|
| 106 |
assistant_content = item['messages'][1]['content']
|
| 107 |
|
|
|
|
| 108 |
question_match = re.search(r'指令:\s*(.*?)(?:\n|$)', user_content)
|
| 109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
|
|
|
|
| 111 |
sql_match = re.search(r'SQL查詢:\s*(.*?)(?:\n|$)', assistant_content, re.DOTALL)
|
| 112 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
sql_query = re.sub(r'```sql|```', '', sql_query).strip()
|
| 114 |
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
except Exception as e:
|
| 120 |
-
|
|
|
|
|
|
|
| 121 |
continue
|
| 122 |
|
| 123 |
print(f"數據加載完成: 成功載入 {successful_loads}/{total_items} 項")
|
|
|
|
| 124 |
return successful_loads > 0
|
| 125 |
except Exception as e:
|
| 126 |
print(f"數據集加載失敗: {e}")
|
|
@@ -141,17 +183,44 @@ class CompleteDataLoader:
|
|
| 141 |
class RetrievalSystem:
|
| 142 |
def __init__(self):
|
| 143 |
try:
|
| 144 |
-
|
|
|
|
|
|
|
|
|
|
| 145 |
self.question_embeddings = None
|
|
|
|
| 146 |
except Exception as e:
|
| 147 |
-
print(f"SentenceTransformer 模型加載失敗: {e}")
|
| 148 |
self.embedder = None
|
| 149 |
|
| 150 |
def compute_embeddings(self, questions: List[str]):
|
| 151 |
if self.embedder and questions:
|
| 152 |
print(f"正在為 {len(questions)} 個問題計算向量...")
|
| 153 |
-
|
| 154 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
|
| 156 |
def retrieve_similar(self, user_question: str, top_k: int = 1) -> List[Dict]:
|
| 157 |
if self.embedder is None or self.question_embeddings is None: return []
|
|
@@ -336,8 +405,13 @@ def process_query(user_question: str) -> Tuple[str, str, str]:
|
|
| 336 |
return sql_result, "✅ 處理完成", log_message
|
| 337 |
|
| 338 |
with gr.Blocks(title="智慧Text-to-SQL系統", theme=gr.themes.Soft()) as demo:
|
| 339 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 340 |
gr.Markdown("📊 **模式**: 結合「檢索驗證」與「意圖導向生成」,即使資料庫範本有誤也能提供準確查詢。")
|
|
|
|
| 341 |
|
| 342 |
with gr.Row():
|
| 343 |
question_input = gr.Textbox(
|
|
@@ -353,7 +427,7 @@ with gr.Blocks(title="智慧Text-to-SQL系統", theme=gr.themes.Soft()) as demo:
|
|
| 353 |
status_output = gr.Textbox(label="🔍 執行狀態", interactive=False)
|
| 354 |
log_output = gr.Textbox(label="📋 詳細日誌", lines=6, interactive=False)
|
| 355 |
|
| 356 |
-
#
|
| 357 |
gr.Examples(
|
| 358 |
examples=[
|
| 359 |
"2024年每月完成多少份報告?",
|
|
@@ -363,7 +437,8 @@ with gr.Blocks(title="智慧Text-to-SQL系統", theme=gr.themes.Soft()) as demo:
|
|
| 363 |
"A組昨天完成了多少個測試項目?",
|
| 364 |
"2024年Q1期間評級為Fail且總金額超過10000的工作單"
|
| 365 |
],
|
| 366 |
-
inputs=question_input
|
|
|
|
| 367 |
)
|
| 368 |
|
| 369 |
# 綁定事件
|
|
@@ -382,7 +457,29 @@ with gr.Blocks(title="智慧Text-to-SQL系統", theme=gr.themes.Soft()) as demo:
|
|
| 382 |
if __name__ == "__main__":
|
| 383 |
if text_to_sql_system:
|
| 384 |
print("Gradio 介面啟動中...")
|
| 385 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 386 |
else:
|
| 387 |
-
print("無法啟動 Gradio,因為系統初始化失敗。")
|
|
|
|
|
|
|
| 388 |
|
|
|
|
| 16 |
DATASET_REPO_ID = "Paul720810/Text-to-SQL-Softline"
|
| 17 |
SIMILARITY_THRESHOLD = 0.65 # 適度提高閾值,確保檢索到的問題意圖更一致
|
| 18 |
|
| 19 |
+
# 雲端環境檢測
|
| 20 |
+
IS_SPACES = os.environ.get("SPACE_ID") is not None
|
| 21 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 22 |
+
|
| 23 |
print("=" * 60)
|
| 24 |
print("🤖 智能 Text-to-SQL 系統啟動中...")
|
| 25 |
print(f"📊 模式: 讀取全部數據(來自 {DATASET_REPO_ID})")
|
| 26 |
+
print(f"🌐 環境: {'Hugging Face Spaces' if IS_SPACES else '本地環境'}")
|
| 27 |
+
print(f"💻 設備: {DEVICE}")
|
| 28 |
print("=" * 60)
|
| 29 |
|
| 30 |
# ==================== 獨立工具函數 (不依賴類別實例) ====================
|
|
|
|
| 104 |
|
| 105 |
successful_loads = 0
|
| 106 |
total_items = len(raw_dataset)
|
| 107 |
+
skipped_reasons = {"empty_question": 0, "empty_sql": 0, "parse_error": 0, "invalid_format": 0}
|
| 108 |
|
| 109 |
for idx, item in enumerate(raw_dataset):
|
| 110 |
try:
|
|
|
|
| 112 |
user_content = item['messages'][0]['content']
|
| 113 |
assistant_content = item['messages'][1]['content']
|
| 114 |
|
| 115 |
+
# 改進的問題提取邏輯
|
| 116 |
question_match = re.search(r'指令:\s*(.*?)(?:\n|$)', user_content)
|
| 117 |
+
if question_match:
|
| 118 |
+
question = question_match.group(1).strip()
|
| 119 |
+
else:
|
| 120 |
+
# 如果沒有找到「指令:」格式,嘗試直接使用內容
|
| 121 |
+
question = user_content.strip()
|
| 122 |
|
| 123 |
+
# 改進的SQL提取邏輯
|
| 124 |
sql_match = re.search(r'SQL查詢:\s*(.*?)(?:\n|$)', assistant_content, re.DOTALL)
|
| 125 |
+
if sql_match:
|
| 126 |
+
sql_query = sql_match.group(1).strip()
|
| 127 |
+
else:
|
| 128 |
+
# 如果沒有找到「SQL查詢:」格式,嘗試提取SQL代碼塊
|
| 129 |
+
sql_block_match = re.search(r'```sql\s*(.*?)\s*```', assistant_content, re.DOTALL)
|
| 130 |
+
if sql_block_match:
|
| 131 |
+
sql_query = sql_block_match.group(1).strip()
|
| 132 |
+
else:
|
| 133 |
+
sql_query = assistant_content.strip()
|
| 134 |
+
|
| 135 |
+
# 清理SQL查詢
|
| 136 |
sql_query = re.sub(r'```sql|```', '', sql_query).strip()
|
| 137 |
|
| 138 |
+
# 驗證數據質量
|
| 139 |
+
if not question or len(question.strip()) < 3:
|
| 140 |
+
skipped_reasons["empty_question"] += 1
|
| 141 |
+
continue
|
| 142 |
+
|
| 143 |
+
if not sql_query or len(sql_query.strip()) < 10:
|
| 144 |
+
skipped_reasons["empty_sql"] += 1
|
| 145 |
+
continue
|
| 146 |
+
|
| 147 |
+
# 基本SQL驗證
|
| 148 |
+
if "SELECT" not in sql_query.upper():
|
| 149 |
+
skipped_reasons["invalid_format"] += 1
|
| 150 |
+
continue
|
| 151 |
+
|
| 152 |
+
self.questions.append(question)
|
| 153 |
+
self.sql_answers.append(sql_query)
|
| 154 |
+
successful_loads += 1
|
| 155 |
+
else:
|
| 156 |
+
skipped_reasons["invalid_format"] += 1
|
| 157 |
+
|
| 158 |
except Exception as e:
|
| 159 |
+
skipped_reasons["parse_error"] += 1
|
| 160 |
+
if idx < 5: # 只顯示前5個錯誤
|
| 161 |
+
print(f"跳過第 {idx} 項資料,錯誤: {e}")
|
| 162 |
continue
|
| 163 |
|
| 164 |
print(f"數據加載完成: 成功載入 {successful_loads}/{total_items} 項")
|
| 165 |
+
print(f"跳過原因統計: 問題為空({skipped_reasons['empty_question']}) | SQL為空({skipped_reasons['empty_sql']}) | 格式錯誤({skipped_reasons['invalid_format']}) | 解析錯誤({skipped_reasons['parse_error']})")
|
| 166 |
return successful_loads > 0
|
| 167 |
except Exception as e:
|
| 168 |
print(f"數據集加載失敗: {e}")
|
|
|
|
| 183 |
class RetrievalSystem:
|
| 184 |
def __init__(self):
    """Load the sentence-embedding model used for question retrieval.

    On success ``self.embedder`` holds a
    ``SentenceTransformer('all-MiniLM-L6-v2')`` instance; if construction
    raises, the error is printed and ``self.embedder`` is ``None``.
    ``self.question_embeddings`` always starts as ``None``.
    """
    # Create both attributes before anything can fail, so a failed model
    # load still leaves a fully-formed instance behind (previously
    # ``question_embeddings`` did not exist at all on the failure path).
    self.embedder = None
    self.question_embeddings = None
    try:
        # Use the module-level DEVICE when it is defined, otherwise CPU.
        device = globals().get('DEVICE', 'cpu')
        print(f"🔧 初始化 SentenceTransformer (設備: {device})...")
        self.embedder = SentenceTransformer('all-MiniLM-L6-v2', device=device)
        print("✅ SentenceTransformer 模型加載成功")
    except Exception as e:
        # Best-effort: report the failure and leave the retriever disabled.
        print(f"❌ SentenceTransformer 模型加載失敗: {e}")
        self.embedder = None
|
| 195 |
|
| 196 |
def compute_embeddings(self, questions: List[str]):
    """Encode *questions* and store the result in ``self.question_embeddings``.

    Does nothing when no embedder is loaded or *questions* is empty.
    If the first attempt raises, one retry is made with a batch size of 16;
    if the retry raises too, ``self.question_embeddings`` is set to ``None``.
    """
    if not (self.embedder and questions):
        return

    def _encode(batch_size):
        # Single place for the encode call shared by both attempts.
        return self.embedder.encode(
            questions,
            convert_to_tensor=True,
            show_progress_bar=True,
            batch_size=batch_size,
        )

    print(f"正在為 {len(questions)} 個問題計算向量...")
    try:
        # Cloud optimisation: smaller batches on Spaces to save memory.
        self.question_embeddings = _encode(32 if IS_SPACES else 64)
        print("向量計算完成")
        return
    except Exception as first_error:
        print(f"向量計算失敗: {first_error}")

    # Fallback: try again with a smaller batch size.
    try:
        print("嘗試使用較小批次大小重新計算...")
        self.question_embeddings = _encode(16)
        print("向量計算完成(降級模式)")
    except Exception as second_error:
        print(f"向量計算徹底失敗: {second_error}")
        self.question_embeddings = None
|
| 224 |
|
| 225 |
def retrieve_similar(self, user_question: str, top_k: int = 1) -> List[Dict]:
|
| 226 |
if self.embedder is None or self.question_embeddings is None: return []
|
|
|
|
| 405 |
return sql_result, "✅ 處理完成", log_message
|
| 406 |
|
| 407 |
with gr.Blocks(title="智慧Text-to-SQL系統", theme=gr.themes.Soft()) as demo:
|
| 408 |
+
# 環境資訊顯示
|
| 409 |
+
env_info = f"🌐 運行環境: {'Hugging Face Spaces' if IS_SPACES else '本地環境'} | 💻 設備: {DEVICE}"
|
| 410 |
+
system_status = f"📊 已載入 {len(text_to_sql_system.data_loader.questions) if text_to_sql_system else 0} 個問答範例"
|
| 411 |
+
|
| 412 |
+
gr.Markdown("# 🚀 智慧 Text-to-SQL 系統 (雲端版)")
|
| 413 |
gr.Markdown("📊 **模式**: 結合「檢索驗證」與「意圖導向生成」,即使資料庫範本有誤也能提供準確查詢。")
|
| 414 |
+
gr.Markdown(f"ℹ️ {env_info} | {system_status}")
|
| 415 |
|
| 416 |
with gr.Row():
|
| 417 |
question_input = gr.Textbox(
|
|
|
|
| 427 |
status_output = gr.Textbox(label="🔍 執行狀態", interactive=False)
|
| 428 |
log_output = gr.Textbox(label="📋 詳細日誌", lines=6, interactive=False)
|
| 429 |
|
| 430 |
+
# 雲端環境優化的範例
|
| 431 |
gr.Examples(
|
| 432 |
examples=[
|
| 433 |
"2024年每月完成多少份報告?",
|
|
|
|
| 437 |
"A組昨天完成了多少個測試項目?",
|
| 438 |
"2024年Q1期間評級為Fail且總金額超過10000的工作單"
|
| 439 |
],
|
| 440 |
+
inputs=question_input,
|
| 441 |
+
label="💡 範例問題 (點擊試用)"
|
| 442 |
)
|
| 443 |
|
| 444 |
# 綁定事件
|
|
|
|
| 457 |
if __name__ == "__main__":
    # Script entry point: start the Gradio UI only when the Text-to-SQL
    # system object is truthy; otherwise report why nothing was started.
    if not text_to_sql_system:
        print("❌ 無法啟動 Gradio,因為系統初始化失敗。")
        if IS_SPACES:
            print("💡 請檢查 Hugging Face Spaces 的環境變數設置。")
    else:
        print("Gradio 介面啟動中...")

        # Choose launch parameters according to the environment.
        if IS_SPACES:
            # Hugging Face Spaces: bind to all interfaces, no share link.
            start_message = "🌐 在 Hugging Face Spaces 環境中啟動..."
            launch_options = dict(
                server_name="0.0.0.0",
                server_port=7860,
                share=False,
                show_error=True,
                quiet=False,
            )
        else:
            # Local run: bind to loopback; sharing is enabled here.
            start_message = "🏠 在本地環境中啟動..."
            launch_options = dict(
                server_name="127.0.0.1",
                server_port=7860,
                share=True,
                show_error=True,
            )

        print(start_message)
        demo.launch(**launch_options)
|
| 485 |
|