Paul720810 committed on
Commit
afb724a
·
verified ·
1 Parent(s): 352a657

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +113 -16
app.py CHANGED
@@ -16,9 +16,15 @@ HF_TOKEN = os.environ.get("HF_TOKEN", None) # 建議從環境變數讀取
16
  DATASET_REPO_ID = "Paul720810/Text-to-SQL-Softline"
17
  SIMILARITY_THRESHOLD = 0.65 # 適度提高閾值,確保檢索到的問題意圖更一致
18
 
 
 
 
 
19
  print("=" * 60)
20
  print("🤖 智能 Text-to-SQL 系統啟動中...")
21
  print(f"📊 模式: 讀取全部數據(來自 {DATASET_REPO_ID})")
 
 
22
  print("=" * 60)
23
 
24
  # ==================== 獨立工具函數 (不依賴類別實例) ====================
@@ -98,6 +104,7 @@ class CompleteDataLoader:
98
 
99
  successful_loads = 0
100
  total_items = len(raw_dataset)
 
101
 
102
  for idx, item in enumerate(raw_dataset):
103
  try:
@@ -105,22 +112,57 @@ class CompleteDataLoader:
105
  user_content = item['messages'][0]['content']
106
  assistant_content = item['messages'][1]['content']
107
 
 
108
  question_match = re.search(r'指令:\s*(.*?)(?:\n|$)', user_content)
109
- question = question_match.group(1).strip() if question_match else user_content
 
 
 
 
110
 
 
111
  sql_match = re.search(r'SQL查詢:\s*(.*?)(?:\n|$)', assistant_content, re.DOTALL)
112
- sql_query = sql_match.group(1).strip() if sql_match else assistant_content
 
 
 
 
 
 
 
 
 
 
113
  sql_query = re.sub(r'```sql|```', '', sql_query).strip()
114
 
115
- if question and sql_query: # 只加載有效的問答對
116
- self.questions.append(question)
117
- self.sql_answers.append(sql_query)
118
- successful_loads += 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  except Exception as e:
120
- print(f"跳過第 {idx} 項資料,錯誤: {e}")
 
 
121
  continue
122
 
123
  print(f"數據加載完成: 成功載入 {successful_loads}/{total_items} 項")
 
124
  return successful_loads > 0
125
  except Exception as e:
126
  print(f"數據集加載失敗: {e}")
@@ -141,17 +183,44 @@ class CompleteDataLoader:
141
  class RetrievalSystem:
142
  def __init__(self):
143
  try:
144
- self.embedder = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
 
 
 
145
  self.question_embeddings = None
 
146
  except Exception as e:
147
- print(f"SentenceTransformer 模型加載失敗: {e}")
148
  self.embedder = None
149
 
150
  def compute_embeddings(self, questions: List[str]):
151
  if self.embedder and questions:
152
  print(f"正在為 {len(questions)} 個問題計算向量...")
153
- self.question_embeddings = self.embedder.encode(questions, convert_to_tensor=True, show_progress_bar=True)
154
- print("向量計算完成")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
  def retrieve_similar(self, user_question: str, top_k: int = 1) -> List[Dict]:
157
  if self.embedder is None or self.question_embeddings is None: return []
@@ -336,8 +405,13 @@ def process_query(user_question: str) -> Tuple[str, str, str]:
336
  return sql_result, "✅ 處理完成", log_message
337
 
338
  with gr.Blocks(title="智慧Text-to-SQL系統", theme=gr.themes.Soft()) as demo:
339
- gr.Markdown("# 🚀 智慧 Text-to-SQL 系統 (進階修復版)")
 
 
 
 
340
  gr.Markdown("📊 **模式**: 結合「檢索驗證」與「意圖導向生成」,即使資料庫範本有誤也能提供準確查詢。")
 
341
 
342
  with gr.Row():
343
  question_input = gr.Textbox(
@@ -353,7 +427,7 @@ with gr.Blocks(title="智慧Text-to-SQL系統", theme=gr.themes.Soft()) as demo:
353
  status_output = gr.Textbox(label="🔍 執行狀態", interactive=False)
354
  log_output = gr.Textbox(label="📋 詳細日誌", lines=6, interactive=False)
355
 
356
- # 改進的範例
357
  gr.Examples(
358
  examples=[
359
  "2024年每月完成多少份報告?",
@@ -363,7 +437,8 @@ with gr.Blocks(title="智慧Text-to-SQL系統", theme=gr.themes.Soft()) as demo:
363
  "A組昨天完成了多少個測試項目?",
364
  "2024年Q1期間評級為Fail且總金額超過10000的工作單"
365
  ],
366
- inputs=question_input
 
367
  )
368
 
369
  # 綁定事件
@@ -382,7 +457,29 @@ with gr.Blocks(title="智慧Text-to-SQL系統", theme=gr.themes.Soft()) as demo:
382
  if __name__ == "__main__":
383
  if text_to_sql_system:
384
  print("Gradio 介面啟動中...")
385
- demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
386
  else:
387
- print("無法啟動 Gradio,因為系統初始化失敗。")
 
 
388
 
 
16
  DATASET_REPO_ID = "Paul720810/Text-to-SQL-Softline"
17
  SIMILARITY_THRESHOLD = 0.65 # 適度提高閾值,確保檢索到的問題意圖更一致
18
 
19
+ # 雲端環境檢測
20
+ IS_SPACES = os.environ.get("SPACE_ID") is not None
21
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
22
+
23
  print("=" * 60)
24
  print("🤖 智能 Text-to-SQL 系統啟動中...")
25
  print(f"📊 模式: 讀取全部數據(來自 {DATASET_REPO_ID})")
26
+ print(f"🌐 環境: {'Hugging Face Spaces' if IS_SPACES else '本地環境'}")
27
+ print(f"💻 設備: {DEVICE}")
28
  print("=" * 60)
29
 
30
  # ==================== 獨立工具函數 (不依賴類別實例) ====================
 
104
 
105
  successful_loads = 0
106
  total_items = len(raw_dataset)
107
+ skipped_reasons = {"empty_question": 0, "empty_sql": 0, "parse_error": 0, "invalid_format": 0}
108
 
109
  for idx, item in enumerate(raw_dataset):
110
  try:
 
112
  user_content = item['messages'][0]['content']
113
  assistant_content = item['messages'][1]['content']
114
 
115
+ # 改進的問題提取邏輯
116
  question_match = re.search(r'指令:\s*(.*?)(?:\n|$)', user_content)
117
+ if question_match:
118
+ question = question_match.group(1).strip()
119
+ else:
120
+ # 如果沒有找到「指令:」格式,嘗試直接使用內容
121
+ question = user_content.strip()
122
 
123
+ # 改進的SQL提取邏輯
124
  sql_match = re.search(r'SQL查詢:\s*(.*?)(?:\n|$)', assistant_content, re.DOTALL)
125
+ if sql_match:
126
+ sql_query = sql_match.group(1).strip()
127
+ else:
128
+ # 如果沒有找到「SQL查詢:」格式,嘗試提取SQL代碼塊
129
+ sql_block_match = re.search(r'```sql\s*(.*?)\s*```', assistant_content, re.DOTALL)
130
+ if sql_block_match:
131
+ sql_query = sql_block_match.group(1).strip()
132
+ else:
133
+ sql_query = assistant_content.strip()
134
+
135
+ # 清理SQL查詢
136
  sql_query = re.sub(r'```sql|```', '', sql_query).strip()
137
 
138
+ # 驗證數據質量
139
+ if not question or len(question.strip()) < 3:
140
+ skipped_reasons["empty_question"] += 1
141
+ continue
142
+
143
+ if not sql_query or len(sql_query.strip()) < 10:
144
+ skipped_reasons["empty_sql"] += 1
145
+ continue
146
+
147
+ # 基本SQL驗證
148
+ if "SELECT" not in sql_query.upper():
149
+ skipped_reasons["invalid_format"] += 1
150
+ continue
151
+
152
+ self.questions.append(question)
153
+ self.sql_answers.append(sql_query)
154
+ successful_loads += 1
155
+ else:
156
+ skipped_reasons["invalid_format"] += 1
157
+
158
  except Exception as e:
159
+ skipped_reasons["parse_error"] += 1
160
+ if idx < 5: # 只顯示前5個錯誤
161
+ print(f"跳過第 {idx} 項資料,錯誤: {e}")
162
  continue
163
 
164
  print(f"數據加載完成: 成功載入 {successful_loads}/{total_items} 項")
165
+ print(f"跳過原因統計: 問題為空({skipped_reasons['empty_question']}) | SQL為空({skipped_reasons['empty_sql']}) | 格式錯誤({skipped_reasons['invalid_format']}) | 解析錯誤({skipped_reasons['parse_error']})")
166
  return successful_loads > 0
167
  except Exception as e:
168
  print(f"數據集加載失敗: {e}")
 
183
  class RetrievalSystem:
184
  def __init__(self):
185
  try:
186
+ # 根據環境選擇設備
187
+ device = DEVICE if 'DEVICE' in globals() else 'cpu'
188
+ print(f"🔧 初始化 SentenceTransformer (設備: {device})...")
189
+ self.embedder = SentenceTransformer('all-MiniLM-L6-v2', device=device)
190
  self.question_embeddings = None
191
+ print("✅ SentenceTransformer 模型加載成功")
192
  except Exception as e:
193
+ print(f"SentenceTransformer 模型加載失敗: {e}")
194
  self.embedder = None
195
 
196
  def compute_embeddings(self, questions: List[str]):
197
  if self.embedder and questions:
198
  print(f"正在為 {len(questions)} 個問題計算向量...")
199
+ try:
200
+ # 雲端環境優化:分批處理以節省記憶體
201
+ batch_size = 32 if IS_SPACES else 64
202
+ self.question_embeddings = self.embedder.encode(
203
+ questions,
204
+ convert_to_tensor=True,
205
+ show_progress_bar=True,
206
+ batch_size=batch_size
207
+ )
208
+ print("向量計算完成")
209
+ except Exception as e:
210
+ print(f"向量計算失敗: {e}")
211
+ # 降級處理:使用更小的批次大小
212
+ try:
213
+ print("嘗試使用較小批次大小重新計算...")
214
+ self.question_embeddings = self.embedder.encode(
215
+ questions,
216
+ convert_to_tensor=True,
217
+ show_progress_bar=True,
218
+ batch_size=16
219
+ )
220
+ print("向量計算完成(降級模式)")
221
+ except Exception as e2:
222
+ print(f"向量計算徹底失敗: {e2}")
223
+ self.question_embeddings = None
224
 
225
  def retrieve_similar(self, user_question: str, top_k: int = 1) -> List[Dict]:
226
  if self.embedder is None or self.question_embeddings is None: return []
 
405
  return sql_result, "✅ 處理完成", log_message
406
 
407
  with gr.Blocks(title="智慧Text-to-SQL系統", theme=gr.themes.Soft()) as demo:
408
+ # 環境資訊顯示
409
+ env_info = f"🌐 運行環境: {'Hugging Face Spaces' if IS_SPACES else '本地環境'} | 💻 設備: {DEVICE}"
410
+ system_status = f"📊 已載入 {len(text_to_sql_system.data_loader.questions) if text_to_sql_system else 0} 個問答範例"
411
+
412
+ gr.Markdown("# 🚀 智慧 Text-to-SQL 系統 (雲端版)")
413
  gr.Markdown("📊 **模式**: 結合「檢索驗證」與「意圖導向生成」,即使資料庫範本有誤也能提供準確查詢。")
414
+ gr.Markdown(f"ℹ️ {env_info} | {system_status}")
415
 
416
  with gr.Row():
417
  question_input = gr.Textbox(
 
427
  status_output = gr.Textbox(label="🔍 執行狀態", interactive=False)
428
  log_output = gr.Textbox(label="📋 詳細日誌", lines=6, interactive=False)
429
 
430
+ # 雲端環境優化的範例
431
  gr.Examples(
432
  examples=[
433
  "2024年每月完成多少份報告?",
 
437
  "A組昨天完成了多少個測試項目?",
438
  "2024年Q1期間評級為Fail且總金額超過10000的工作單"
439
  ],
440
+ inputs=question_input,
441
+ label="💡 範例問題 (點擊試用)"
442
  )
443
 
444
  # 綁定事件
 
457
  if __name__ == "__main__":
458
  if text_to_sql_system:
459
  print("Gradio 介面啟動中...")
460
+
461
+ # 根據環境選擇啟動參數
462
+ if IS_SPACES:
463
+ # Hugging Face Spaces 環境
464
+ print("🌐 在 Hugging Face Spaces 環境中啟動...")
465
+ demo.launch(
466
+ server_name="0.0.0.0",
467
+ server_port=7860,
468
+ share=False,
469
+ show_error=True,
470
+ quiet=False
471
+ )
472
+ else:
473
+ # 本地環境
474
+ print("🏠 在本地環境中啟動...")
475
+ demo.launch(
476
+ server_name="127.0.0.1",
477
+ server_port=7860,
478
+ share=True, # 本地環境可以選擇分享
479
+ show_error=True
480
+ )
481
  else:
482
+ print("無法啟動 Gradio,因為系統初始化失敗。")
483
+ if IS_SPACES:
484
+ print("💡 請檢查 Hugging Face Spaces 的環境變數設置。")
485