Paul720810 commited on
Commit
5fc665e
·
verified ·
1 Parent(s): 9943c6f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +194 -174
app.py CHANGED
@@ -12,29 +12,30 @@ from typing import List, Dict, Tuple, Optional
12
  import numpy as np
13
 
14
  # ==================== 配置區 ====================
15
- HF_TOKEN = os.environ.get("HF_TOKEN", "您的_HuggingFace_Token")
16
  DATASET_REPO_ID = "Paul720810/Text-to-SQL-Softline"
17
- SIMILARITY_THRESHOLD = 0.75
18
 
19
- # 多個備用LLM模型
20
  LLM_MODELS = [
21
  "https://api-inference.huggingface.co/models/gpt2",
22
- "https://api-inference.huggingface.co/models/distilgpt2",
23
  "https://api-inference.huggingface.co/models/microsoft/DialoGPT-small"
24
  ]
25
 
26
  print("=" * 60)
27
  print("🤖 智能 Text-to-SQL 系統啟動中...")
28
- print("📊 模式: 讀取全部4276條數據(包含空白SQL)")
29
  print("=" * 60)
30
 
31
- # ==================== 增強工具函數 ====================
32
  def get_current_time():
 
33
  return datetime.now().strftime("%H:%M:%S")
34
 
35
  def validate_sql(sql_query: str) -> Dict:
36
- """驗證SQL語句的安全性"""
37
- if not sql_query or sql_query.strip() == "":
38
  return {"valid": False, "issues": ["SQL語句為空"], "is_safe": False, "empty": True}
39
 
40
  sql_clean = sql_query.strip()
@@ -47,7 +48,7 @@ def validate_sql(sql_query: str) -> Dict:
47
  # 檢查危險操作
48
  dangerous_keywords = ['DROP', 'DELETE', 'INSERT', 'UPDATE', 'ALTER', 'TRUNCATE', 'EXEC', 'EXECUTE']
49
  for keyword in dangerous_keywords:
50
- if f" {keyword} " in sql_upper:
51
  security_issues.append(f"危險操作: {keyword}")
52
 
53
  # 檢查基本語法
@@ -57,15 +58,18 @@ def validate_sql(sql_query: str) -> Dict:
57
  if "FROM" not in sql_upper:
58
  security_issues.append("缺少FROM")
59
 
 
 
 
60
  return {
61
- "valid": len(security_issues) == 0,
62
  "issues": security_issues,
63
- "is_safe": len([i for i in security_issues if '危險' in i]) == 0,
64
  "empty": False
65
  }
66
 
67
  def analyze_question_type(question: str) -> Dict:
68
- """分析問題類型"""
69
  question_lower = question.lower()
70
 
71
  analysis = {
@@ -79,26 +83,24 @@ def analyze_question_type(question: str) -> Dict:
79
 
80
  # 檢測關鍵詞
81
  keywords_sets = {
82
- "sales": ["銷售", "業績", "金額", "收入", "sale", "revenue", "金額"],
83
- "customer": ["客戶", "買家", "用戶", "customer", "client", "買家"],
84
- "product": ["產品", "商品", "項目", "product", "item", "產品"],
85
- "time": ["時間", "日期", "月份", "年", "月", "最近", "date", "month", "year", "時間"],
86
- "report": ["報告", "完成", "份", "report", "complete", "報告"],
87
- "count": ["多少", "幾個", "數量", "count", "how many", "多少"],
88
- "comparison": ["比較", "vs", " versus", "對比", "相比", "比較"]
89
  }
90
 
91
  for category, keywords in keywords_sets.items():
92
- for keyword in keywords:
93
- if keyword in question_lower:
94
- if category not in analysis["keywords"]:
95
- analysis["keywords"].append(category)
96
-
97
  # 特殊檢測
98
- analysis["has_count"] = any(kw in question_lower for kw in keywords_sets["count"])
99
- analysis["has_date"] = any(kw in question_lower for kw in keywords_sets["time"])
100
  analysis["has_group"] = any(word in question_lower for word in ["每", "各", "group", "每個"])
101
- analysis["has_comparison"] = any(kw in question_lower for kw in keywords_sets["comparison"])
102
 
103
  # 確定主要類型
104
  if analysis["keywords"]:
@@ -106,57 +108,6 @@ def analyze_question_type(question: str) -> Dict:
106
 
107
  return analysis
108
 
109
- def generate_sql_from_question(question: str, analysis: Dict) -> str:
110
- """根據問題分析生成智能SQL"""
111
- question_lower = question.lower()
112
- question_type = analysis["type"]
113
-
114
- # 針對常見問題模式的SQL生成
115
- if "每月" in question_lower and ("完成" in question_lower or "報告" in question_lower):
116
- year_match = re.search(r'(\d{4})年', question_lower)
117
- year = year_match.group(1) if year_match else "2023"
118
- return f"SELECT strftime('%Y-%m', completion_date) as month, COUNT(*) as report_count FROM reports WHERE strftime('%Y', completion_date) = '{year}' GROUP BY month ORDER BY month;"
119
-
120
- elif "銷售" in question_lower and ("最高" in question_lower or "最好" in question_lower):
121
- return "SELECT product_name, SUM(sales_amount) as total_sales FROM sales GROUP BY product_name ORDER BY total_sales DESC LIMIT 10;"
122
-
123
- elif "客戶" in question_lower and ("訂單" in question_lower or "購買" in question_lower):
124
- return "SELECT customer_name, COUNT(*) as order_count, SUM(order_amount) as total_spent FROM orders GROUP BY customer_name ORDER BY total_spent DESC;"
125
-
126
- elif "比較" in question_lower and ("年" in question_lower or "年份" in question_lower):
127
- return "SELECT strftime('%Y', order_date) as year, COUNT(*) as order_count, SUM(order_amount) as yearly_revenue FROM orders GROUP BY year ORDER BY year;"
128
-
129
- elif "庫存" in question_lower and ("不足" in question_lower or "缺少" in question_lower):
130
- return "SELECT product_name, stock_quantity FROM products WHERE stock_quantity < 10 ORDER BY stock_quantity ASC;"
131
-
132
- # 根據分析結果生成通用SQL
133
- if analysis["has_count"] and analysis["has_group"] and analysis["has_date"]:
134
- return "SELECT strftime('%Y-%m', date_column) as period, COUNT(*) as item_count FROM appropriate_table GROUP BY period ORDER BY period;"
135
-
136
- elif analysis["has_count"] and analysis["has_group"]:
137
- return "SELECT category_column, COUNT(*) as count FROM appropriate_table GROUP BY category_column ORDER BY count DESC;"
138
-
139
- elif analysis["has_count"]:
140
- return "SELECT COUNT(*) as total_count FROM appropriate_table;"
141
-
142
- elif analysis["has_group"]:
143
- return "SELECT group_column, AVG(value_column) as average_value FROM appropriate_table GROUP BY group_column;"
144
-
145
- else:
146
- return "SELECT * FROM appropriate_table LIMIT 10;"
147
-
148
- def repair_empty_sql(original_sql: str, user_question: str, similar_question: str) -> str:
149
- """修復空白SQL"""
150
- if not original_sql or original_sql.strip() == "":
151
- # 分析問題並生成合適的SQL
152
- analysis = analyze_question_type(user_question)
153
- repaired_sql = generate_sql_from_question(user_question, analysis)
154
-
155
- # 添加註釋說明這是修復的SQL
156
- return f"-- 根據類似問題 '{similar_question}' 修復生成的SQL\n{repaired_sql}"
157
-
158
- return original_sql
159
-
160
  # ==================== 完整數據加載模塊 ====================
161
  class CompleteDataLoader:
162
  def __init__(self, hf_token: str):
@@ -173,9 +124,7 @@ class CompleteDataLoader:
173
  raw_dataset = load_dataset(DATASET_REPO_ID, token=self.hf_token)['train']
174
 
175
  print("解析全部 messages 格式...")
176
- total_count = 0
177
- empty_count = 0
178
- valid_count = 0
179
 
180
  for item in raw_dataset:
181
  try:
@@ -191,7 +140,7 @@ class CompleteDataLoader:
191
  sql_match = re.search(r'SQL查詢:\s*(.*?)(?:\n|$)', assistant_content, re.DOTALL)
192
  if sql_match:
193
  sql_query = sql_match.group(1).strip()
194
- sql_query = re.sub(r'^sql\s*', '', sql_query)
195
  sql_query = re.sub(r'```sql|```', '', sql_query).strip()
196
  else:
197
  sql_query = assistant_content
@@ -210,8 +159,7 @@ class CompleteDataLoader:
210
  empty_count += 1
211
  if validation["valid"]:
212
  valid_count += 1
213
-
214
- except Exception as e:
215
  continue
216
 
217
  print(f"數據加載完成: 總數 {total_count}, 有效 {valid_count}, 空白 {empty_count}")
@@ -239,13 +187,39 @@ class CompleteDataLoader:
239
  self.schema_data = {}
240
  return False
241
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
  # ==================== 主系統 ====================
243
  class CompleteTextToSQLSystem:
244
  def __init__(self, hf_token: str):
245
  self.hf_token = hf_token
246
  self.data_loader = CompleteDataLoader(hf_token)
247
  self.retrieval_system = RetrievalSystem()
248
-
249
  self.initialize_system()
250
 
251
  def initialize_system(self):
@@ -255,20 +229,93 @@ class CompleteTextToSQLSystem:
255
  self.data_loader.load_complete_dataset()
256
  self.data_loader.load_schema()
257
 
258
- # 為所有問題計算向量(包括空白SQL的)
259
  if self.data_loader.questions:
260
  self.retrieval_system.compute_embeddings(self.data_loader.questions)
261
 
262
  print(f"系統初始化完成,載入問題總數: {len(self.data_loader.questions)}")
263
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
  def generate_sql(self, user_question: str) -> Tuple[str, str]:
265
- """生成SQL查詢 - 處理所有數據"""
266
  log_messages = [f"⏰ {get_current_time()} 開始處理"]
267
 
268
- if not user_question or user_question.strip() == "":
269
  return "請輸入您的問題。", "錯誤: 問題為空"
270
 
271
- # 1. 檢索最相似的問題(從所有4276條中)
272
  if self.data_loader.questions:
273
  hits = self.retrieval_system.retrieve_similar(user_question)
274
 
@@ -278,119 +325,92 @@ class CompleteTextToSQLSystem:
278
  corpus_id = best_hit['corpus_id']
279
  similar_question = self.data_loader.questions[corpus_id]
280
  original_sql = self.data_loader.sql_answers[corpus_id]
281
- sql_quality = self.data_loader.sql_quality[corpus_id]
282
 
283
- log_messages.append(f"🔍 檢索到: '{similar_question}'")
284
- log_messages.append(f"📊 相似度: {similarity_score:.3f}, 質量分數: {sql_quality:.1f}")
285
 
286
  if similarity_score > SIMILARITY_THRESHOLD:
287
- # 檢查並修復SQL(如果是空白的)
288
- validation = validate_sql(original_sql)
289
-
290
- if validation["empty"] or not validation["valid"]:
291
- log_messages.append(f"⚠️ 原始SQL需要修復: {', '.join(validation['issues'])}")
292
- log_messages.append("🛠️ 正在智能修復SQL...")
293
-
294
- repaired_sql = repair_empty_sql(original_sql, user_question, similar_question)
295
- log_messages.append("✅ 修復完成")
296
-
297
- return repaired_sql, "\n".join(log_messages)
298
- else:
299
- log_messages.append(f"✅ 相似度 > {SIMILARITY_THRESHOLD},使用預先SQL")
300
- return original_sql, "\n".join(log_messages)
301
  else:
302
- log_messages.append(f"ℹ️ 相似度 {similarity_score:.3f} 低於閾值 {SIMILARITY_THRESHOLD}")
303
-
304
  # 2. 如果檢索失敗或相似度不足,智能生成SQL
305
- log_messages.append("🤖 智能生成SQL...")
306
  analysis = analyze_question_type(user_question)
307
- intelligent_sql = generate_sql_from_question(user_question, analysis)
308
 
309
- log_messages.append(f"📋 問題分析: {analysis['type']}類型")
310
- log_messages.append("✅ 智能生成完成")
311
 
312
  return intelligent_sql, "\n".join(log_messages)
313
 
314
- # ==================== 其他類定義 ====================
315
- class LLMClient:
316
- def __init__(self, hf_token: str):
317
- self.hf_token = hf_token
318
-
319
- def call_llm_api(self, prompt: str) -> Optional[str]:
320
- headers = {"Authorization": f"Bearer {self.hf_token}"}
321
- payload = {"inputs": prompt, "parameters": {"max_new_tokens": 200, "temperature": 0.1}}
322
-
323
- for model_url in LLM_MODELS:
324
- try:
325
- response = requests.post(model_url, headers=headers, json=payload, timeout=15)
326
- if response.status_code == 200:
327
- result = response.json()
328
- if isinstance(result, list) and len(result) > 0:
329
- return result[0]['generated_text'].strip()
330
- except:
331
- continue
332
- return None
333
-
334
- class RetrievalSystem:
335
- def __init__(self):
336
- self.embedder = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
337
- self.question_embeddings = None
338
-
339
- def compute_embeddings(self, questions: List[str]) -> None:
340
- if questions:
341
- print(f"正在為 {len(questions)} 個問題計算向量...")
342
- self.question_embeddings = self.embedder.encode(questions, convert_to_tensor=True, show_progress_bar=False)
343
- print("向量計算完成")
344
-
345
- def retrieve_similar(self, user_question: str, top_k: int = 5) -> List[Dict]:
346
- if self.question_embeddings is None or len(self.question_embeddings) == 0:
347
- return []
348
- try:
349
- question_embedding = self.embedder.encode(user_question, convert_to_tensor=True)
350
- hits = util.semantic_search(question_embedding, self.question_embeddings, top_k=top_k)
351
- return hits[0] if hits and hits[0] else []
352
- except Exception as e:
353
- print(f"檢索錯誤: {e}")
354
- return []
355
-
356
  # ==================== 初始化系統 ====================
357
- print("正在初始化完整數據Text-to-SQL系統...")
358
- text_to_sql_system = CompleteTextToSQLSystem(HF_TOKEN)
 
 
 
 
 
 
 
 
 
359
 
360
  # ==================== Gradio界面 ====================
361
- def process_query(user_question: str) -> Tuple[str, str]:
 
 
 
 
362
  sql_result, log_message = text_to_sql_system.generate_sql(user_question)
363
  return sql_result, "✅ SQL生成完成", log_message
364
 
365
- with gr.Blocks(title="智能Text-to-SQL系統", theme=gr.themes.Soft()) as demo:
366
- gr.Markdown("# 🚀 智能 Text-to-SQL 系統")
367
- gr.Markdown("📊 完整模式: 讀取全部4276條數據")
368
 
369
  with gr.Row():
370
  question_input = gr.Textbox(
371
- label="📝 輸入問題",
372
- placeholder="例如:2023年每月完成多少份報告",
373
- lines=2,
374
  scale=4
375
  )
376
  submit_btn = gr.Button("🚀 生成SQL", variant="primary", scale=1)
377
 
378
- with gr.Row():
379
  sql_output = gr.Code(
380
- label="📊 生成的SQL",
381
  language="sql",
382
- lines=6
383
  )
384
-
385
- with gr.Row():
386
- debug_output = gr.Textbox(label="🔍 狀態", lines=2, interactive=False)
387
- log_output = gr.Textbox(label="📋 詳細日誌", lines=4, interactive=False)
388
-
 
 
 
 
 
 
 
 
 
 
389
  submit_btn.click(
390
  process_query,
391
  inputs=question_input,
392
- outputs=[sql_output, debug_output, log_output]
393
  )
394
 
395
  if __name__ == "__main__":
396
- demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
 
 
 
 
 
12
  import numpy as np
13
 
14
  # ==================== 配置區 ====================
15
# Hugging Face access token taken from the environment; None when HF_TOKEN is unset.
HF_TOKEN = os.getenv("HF_TOKEN")
DATASET_REPO_ID = "Paul720810/Text-to-SQL-Softline"
# A retrieved question is accepted only when its similarity score exceeds this value.
SIMILARITY_THRESHOLD = 0.6

# Fallback LLM inference endpoints (author's note: not used by the current logic).
LLM_MODELS = [
    "https://api-inference.huggingface.co/models/gpt2",
    "https://api-inference.huggingface.co/models/distilgpt2",
    "https://api-inference.huggingface.co/models/microsoft/DialoGPT-small"
]

# Startup banner, emitted as one write; the text is the same four lines as before.
print(
    "=" * 60,
    "🤖 智能 Text-to-SQL 系統啟動中...",
    f"📊 模式: 讀取全部數據(來自 {DATASET_REPO_ID})",
    "=" * 60,
    sep="\n",
)
30
 
31
+ # ==================== 獨立工具函數 (不依賴類別實例) ====================
32
def get_current_time():
    """Return the current local time formatted as ``HH:MM:SS``."""
    now = datetime.now()
    return f"{now:%H:%M:%S}"
35
 
36
  def validate_sql(sql_query: str) -> Dict:
37
+ """驗證SQL語句的語法和安全性"""
38
+ if not sql_query or not sql_query.strip():
39
  return {"valid": False, "issues": ["SQL語句為空"], "is_safe": False, "empty": True}
40
 
41
  sql_clean = sql_query.strip()
 
48
  # 檢查危險操作
49
  dangerous_keywords = ['DROP', 'DELETE', 'INSERT', 'UPDATE', 'ALTER', 'TRUNCATE', 'EXEC', 'EXECUTE']
50
  for keyword in dangerous_keywords:
51
+ if f" {keyword} " in f" {sql_upper} ":
52
  security_issues.append(f"危險操作: {keyword}")
53
 
54
  # 檢查基本語法
 
58
  if "FROM" not in sql_upper:
59
  security_issues.append("缺少FROM")
60
 
61
+ is_valid = not security_issues
62
+ is_safe = all('危險' not in issue for issue in security_issues)
63
+
64
  return {
65
+ "valid": is_valid,
66
  "issues": security_issues,
67
+ "is_safe": is_safe,
68
  "empty": False
69
  }
70
 
71
  def analyze_question_type(question: str) -> Dict:
72
+ """分析問題類型和關鍵詞"""
73
  question_lower = question.lower()
74
 
75
  analysis = {
 
83
 
84
  # 檢測關鍵詞
85
  keywords_sets = {
86
+ "sales": ["銷售", "業績", "金額", "收入", "sale", "revenue"],
87
+ "customer": ["客戶", "買家", "用戶", "customer", "client"],
88
+ "product": ["產品", "商品", "項目", "product", "item"],
89
+ "time": ["時間", "日期", "月份", "年", "月", "最近", "date", "month", "year"],
90
+ "report": ["報告", "完成", "份", "report", "complete"],
91
+ "count": ["多少", "幾個", "數量", "count", "how many"],
92
+ "comparison": ["比較", "vs", " versus", "對比", "相比"]
93
  }
94
 
95
  for category, keywords in keywords_sets.items():
96
+ if any(keyword in question_lower for keyword in keywords):
97
+ analysis["keywords"].append(category)
98
+
 
 
99
  # 特殊檢測
100
+ analysis["has_count"] = "count" in analysis["keywords"]
101
+ analysis["has_date"] = "time" in analysis["keywords"]
102
  analysis["has_group"] = any(word in question_lower for word in ["每", "各", "group", "每個"])
103
+ analysis["has_comparison"] = "comparison" in analysis["keywords"]
104
 
105
  # 確定主要類型
106
  if analysis["keywords"]:
 
108
 
109
  return analysis
110
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  # ==================== 完整數據加載模塊 ====================
112
  class CompleteDataLoader:
113
  def __init__(self, hf_token: str):
 
124
  raw_dataset = load_dataset(DATASET_REPO_ID, token=self.hf_token)['train']
125
 
126
  print("解析全部 messages 格式...")
127
+ total_count, empty_count, valid_count = 0, 0, 0
 
 
128
 
129
  for item in raw_dataset:
130
  try:
 
140
  sql_match = re.search(r'SQL查詢:\s*(.*?)(?:\n|$)', assistant_content, re.DOTALL)
141
  if sql_match:
142
  sql_query = sql_match.group(1).strip()
143
+ sql_query = re.sub(r'^sql\s*', '', sql_query, flags=re.IGNORECASE)
144
  sql_query = re.sub(r'```sql|```', '', sql_query).strip()
145
  else:
146
  sql_query = assistant_content
 
159
  empty_count += 1
160
  if validation["valid"]:
161
  valid_count += 1
162
+ except Exception:
 
163
  continue
164
 
165
  print(f"數據加載完成: 總數 {total_count}, 有效 {valid_count}, 空白 {empty_count}")
 
187
  self.schema_data = {}
188
  return False
189
 
190
+ # ==================== 檢索系統 ====================
191
class RetrievalSystem:
    """Embedding-based lookup of the stored questions most similar to a query.

    Attributes:
        embedder: the SentenceTransformer model, or None when loading failed.
        question_embeddings: embeddings of the indexed questions, or None
            until ``compute_embeddings`` has run successfully.
    """

    def __init__(self):
        # Define both attributes up front so they exist even when the model
        # cannot be loaded; previously ``question_embeddings`` was left
        # undefined on the failure path.
        self.embedder = None
        self.question_embeddings = None
        try:
            self.embedder = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
        except Exception as e:
            # Best effort: the system keeps running without retrieval.
            print(f"SentenceTransformer 模型加載失敗: {e}")

    def compute_embeddings(self, questions: List[str]) -> None:
        """Encode ``questions`` and store the result for later retrieval.

        Does nothing when the model is unavailable or ``questions`` is empty.
        """
        if self.embedder is None or not questions:
            return
        print(f"正在為 {len(questions)} 個問題計算向量...")
        self.question_embeddings = self.embedder.encode(
            questions, convert_to_tensor=True, show_progress_bar=True
        )
        print("向量計算完成")

    def retrieve_similar(self, user_question: str, top_k: int = 5) -> List[Dict]:
        """Return up to ``top_k`` hits for ``user_question``.

        The hits are the first (only) result list produced by
        ``util.semantic_search``. An empty list is returned when the model or
        the index is missing, when nothing is found, or when the search raises.
        """
        if self.embedder is None or self.question_embeddings is None or len(self.question_embeddings) == 0:
            return []
        try:
            question_embedding = self.embedder.encode(user_question, convert_to_tensor=True)
            hits = util.semantic_search(question_embedding, self.question_embeddings, top_k=top_k)
            return hits[0] if hits and hits[0] else []
        except Exception as e:
            # Retrieval is optional; report the problem and fall back to "no hits".
            print(f"檢索錯誤: {e}")
            return []
216
+
217
  # ==================== 主系統 ====================
218
  class CompleteTextToSQLSystem:
219
  def __init__(self, hf_token: str):
220
  self.hf_token = hf_token
221
  self.data_loader = CompleteDataLoader(hf_token)
222
  self.retrieval_system = RetrievalSystem()
 
223
  self.initialize_system()
224
 
225
  def initialize_system(self):
 
229
  self.data_loader.load_complete_dataset()
230
  self.data_loader.load_schema()
231
 
232
+ # 為所有問題計算向量
233
  if self.data_loader.questions:
234
  self.retrieval_system.compute_embeddings(self.data_loader.questions)
235
 
236
  print(f"系統初始化完成,載入問題總數: {len(self.data_loader.questions)}")
237
+
238
+ # ===== 輔助函數 (作為類別方法) =====
239
def get_available_tables(self) -> Dict:
    """Map each table in the loaded schema to the list of its column names.

    Reads ``self.data_loader.schema_data``. Tables whose value is not a list
    are skipped. Within a table, only dict entries that carry a ``"name"``
    key contribute a column; any other entry (string, number, None) is
    ignored instead of raising ``TypeError``.

    Returns:
        ``{table_name: [column_name, ...]}``, or ``{}`` when no schema is loaded.
    """
    schema = self.data_loader.schema_data
    if not schema:
        return {}

    tables = {}
    for table_name, columns_list in schema.items():
        if not isinstance(columns_list, list):
            continue
        tables[table_name] = [
            col["name"]
            for col in columns_list
            if isinstance(col, dict) and "name" in col
        ]
    return tables
251
+
252
def extract_number(self, text: str, default: int = 10) -> int:
    """Return the first run of digits in ``text`` as an int, or ``default`` if there is none."""
    first_digits = re.search(r'\d+', text)
    if first_digits is None:
        return default
    return int(first_digits.group())
256
+
257
def generate_sql_from_question(self, question: str, analysis: Dict) -> str:
    """Build a SQL query for ``question`` from an ordered set of keyword rules.

    Args:
        question: the user's question; matching is done on its lowercase form.
        analysis: result of the question analysis; only ``analysis["has_count"]``
            is read here.

    Returns:
        A SQL string. The first matching rule wins; when no rule (or no
        sub-rule) produces a query, a default listing query is returned.

    NOTE(review): the nesting below was reconstructed from a listing without
    indentation; the placement of the un-suffixed ``JobTimeline`` fallback in
    rule 1 should be confirmed against the real file.
    """
    question_lower = question.lower()
    available_tables = self.get_available_tables().keys()

    # 1. Completed-count per month / per day -> JobTimeline tables
    if any(kw in question_lower for kw in ["每月", "每日", "昨天", "完成"]) and analysis["has_count"]:
        # A group such as "a組" selects a per-group table (JobTimeline_TA, ...).
        group_match = re.search(r'([a-z]組)', question_lower)
        if group_match:
            group = group_match.group(1).replace('組', '').upper()
            group_mapping = {'A': 'TA', 'B': 'TB', 'C': 'TC', 'D': 'TD'}
            # Letters outside A-D fall back to the 'TA' suffix.
            table_suffix = group_mapping.get(group, 'TA')
            table_name = f"JobTimeline_{table_suffix}"

            if "昨天" in question_lower:
                return f"SELECT COUNT(*) as 完成數量 FROM {table_name} WHERE DATE(end_time) = DATE('now','-1 day');"
            elif "每月" in question_lower:
                # Use the 4-digit year from the question, else the current year.
                year_match = re.search(r'(\d{4})年?', question_lower)
                year = year_match.group(1) if year_match else datetime.now().strftime('%Y')
                return f"""SELECT strftime('%Y-%m', end_time) as 月份, COUNT(*) as 完成數量 FROM {table_name} WHERE strftime('%Y', end_time) = '{year}' AND end_time IS NOT NULL GROUP BY strftime('%Y-%m', end_time) ORDER BY 月份;"""
        # No usable group-specific query: monthly counts over the plain JobTimeline table.
        return "SELECT strftime('%Y-%m', jt.end_time) as 月份, COUNT(*) as 完成數量 FROM JobTimeline jt WHERE jt.end_time IS NOT NULL GROUP BY strftime('%Y-%m', jt.end_time) ORDER BY 月份;"

    # 2. Rating analysis -> TSR53SampleDescription.OverallRating
    #    (only when that table is present in the loaded schema)
    elif any(kw in question_lower for kw in ["評級", "rating", "等級"]) and "TSR53SampleDescription" in available_tables:
        if any(kw in question_lower for kw in ["分佈", "統計", "多少"]):
            return "SELECT OverallRating as 評級, COUNT(*) as 數量, ROUND(COUNT(*) * 100.0 / (SELECT COUNT(*) FROM TSR53SampleDescription), 2) as 百分比 FROM TSR53SampleDescription WHERE OverallRating IS NOT NULL GROUP BY OverallRating ORDER BY 數量 DESC;"
        elif "fail" in question_lower or "失敗" in question_lower:
            return "SELECT JobNo as 工作單號, ApplicantName as 申請方, OverallRating as 評級 FROM TSR53SampleDescription WHERE OverallRating = 'Fail' ORDER BY JobNo;"

    # 3. Amount-related "top N" queries -> TSR53Invoice
    elif any(kw in question_lower for kw in ["金額", "總額", "收入", "invoice"]) and any(kw in question_lower for kw in ["最高", "最大", "top"]):
        # N is the first number found in the question (10 when there is none).
        limit_num = self.extract_number(question_lower, default=10)
        return f"""WITH JobTotalAmount AS (SELECT JobNo, SUM(LocalAmount) AS TotalAmount FROM (SELECT DISTINCT JobNo, InvoiceCreditNoteNo, LocalAmount FROM TSR53Invoice WHERE LocalAmount IS NOT NULL) GROUP BY JobNo) SELECT jta.JobNo as 工作單號, sd.ApplicantName as 申請方, jta.TotalAmount as 總金額 FROM JobTotalAmount jta JOIN TSR53SampleDescription sd ON sd.JobNo = jta.JobNo ORDER BY jta.TotalAmount DESC LIMIT {limit_num};"""

    # 4. Company / customer related queries
    elif any(kw in question_lower for kw in ["公司", "客戶", "申請方", "付款方"]):
        if any(kw in question_lower for kw in ["最多", "top", "排名"]):
            return "SELECT ApplicantName as 申請方名稱, COUNT(*) as 工作單數量 FROM TSR53SampleDescription WHERE ApplicantName IS NOT NULL GROUP BY ApplicantName ORDER BY 工作單數量 DESC LIMIT 10;"
        return "SELECT ApplicantName as 申請方, InvoiceToName as 付款方, COUNT(*) as 工作單數量 FROM TSR53SampleDescription WHERE ApplicantName IS NOT NULL GROUP BY ApplicantName, InvoiceToName ORDER BY 工作單數量 DESC;"

    # ... further rules can be added here ...

    # Default query - basic job information (also reached when rule 2 matched
    # but neither of its sub-conditions did).
    return "SELECT JobNo as 工作單號, ApplicantName as 申請方, InvoiceToName as 付款方, OverallRating as 評級 FROM TSR53SampleDescription LIMIT 20;"
301
+
302
def repair_empty_sql(self, original_sql: str, user_question: str, similar_question: str) -> str:
    """Return ``original_sql`` if it validates, otherwise a rule-generated replacement.

    The replacement is built from ``user_question`` and is prefixed with a SQL
    comment line naming ``similar_question``, the retrieved question whose
    stored SQL was rejected.
    """
    # Guard clause: a query that passes validation is handed back untouched.
    if validate_sql(original_sql)["valid"]:
        return original_sql

    question_profile = analyze_question_type(user_question)
    fallback_sql = self.generate_sql_from_question(user_question, question_profile)
    return f"-- 根據類似問題 '{similar_question}' (原SQL無效) 自動生成的查詢\n{fallback_sql}"
310
+
311
  def generate_sql(self, user_question: str) -> Tuple[str, str]:
312
+ """主流程:生成SQL查詢"""
313
  log_messages = [f"⏰ {get_current_time()} 開始處理"]
314
 
315
+ if not user_question or not user_question.strip():
316
  return "請輸入您的問題。", "錯誤: 問題為空"
317
 
318
+ # 1. 檢索最相似的問題
319
  if self.data_loader.questions:
320
  hits = self.retrieval_system.retrieve_similar(user_question)
321
 
 
325
  corpus_id = best_hit['corpus_id']
326
  similar_question = self.data_loader.questions[corpus_id]
327
  original_sql = self.data_loader.sql_answers[corpus_id]
 
328
 
329
+ log_messages.append(f"🔍 檢索到最相似問題: '{similar_question}'")
330
+ log_messages.append(f"📊 相似度: {similarity_score:.3f}")
331
 
332
  if similarity_score > SIMILARITY_THRESHOLD:
333
+ repaired_sql = self.repair_empty_sql(original_sql, user_question, similar_question)
334
+ log_messages.append(f"✅ 相似度高於閾值 {SIMILARITY_THRESHOLD},採用檢索結果。")
335
+ return repaired_sql, "\n".join(log_messages)
 
 
 
 
 
 
 
 
 
 
 
336
  else:
337
+ log_messages.append(f"ℹ️ 相似度低於閾值 {SIMILARITY_THRESHOLD},轉為智能生成。")
338
+
339
  # 2. 如果檢索失敗或相似度不足,智能生成SQL
340
+ log_messages.append("🤖 找不到高相似度結果,啟用智能生成規則...")
341
  analysis = analyze_question_type(user_question)
342
+ intelligent_sql = self.generate_sql_from_question(user_question, analysis)
343
 
344
+ log_messages.append(f"📋 問題分析: {analysis['type']} 類型, 關鍵詞: {analysis['keywords']}")
345
+ log_messages.append("✅ 智能生成完成。")
346
 
347
  return intelligent_sql, "\n".join(log_messages)
348
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
349
  # ==================== 初始化系統 ====================
350
print("準備初始化 Text-to-SQL 系統...")
# Build the system only when a Hugging Face token is configured.
if HF_TOKEN is not None:
    text_to_sql_system = CompleteTextToSQLSystem(HF_TOKEN)
else:
    # No token: print a framed warning and leave the system uninitialised.
    print(
        "\n" + "=" * 60,
        "⚠️ 警告: Hugging Face Token 未設置。",
        "請在環境變數中設定 HF_TOKEN 才能從私人數據集下載資料。",
        "=" * 60 + "\n",
        sep="\n",
    )
    text_to_sql_system = None
361
 
362
  # ==================== Gradio界面 ====================
363
def process_query(user_question: str) -> Tuple[str, str, str]:
    """Gradio callback: turn a question into ``(sql, status, log)``.

    When the global system was not initialised, the SQL slot carries a
    placeholder and both the status and log slots carry the same error text.
    """
    if text_to_sql_system is not None:
        sql_result, log_message = text_to_sql_system.generate_sql(user_question)
        return sql_result, "✅ SQL生成完成", log_message

    error_msg = "系統因缺少 Hugging Face Token 而未成功初始化。"
    return "系統未初始化", error_msg, error_msg
370
 
371
# Gradio UI: one question box plus button, a results panel, examples, and the
# click wiring to process_query.
# NOTE(review): nesting reconstructed from a listing without indentation -
# confirm which components sit inside the Accordion.
with gr.Blocks(title="智慧Text-to-SQL系統", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🚀 智慧 Text-to-SQL 系統")
    gr.Markdown("📊 **模式**: 讀取雲端數據集並結合「檢索」與「規則生成」兩種模式。")

    # Input row: question text box and the submit button.
    with gr.Row():
        question_input = gr.Textbox(
            label="📝 請在此輸入您的問題",
            placeholder="例如:2023年每月完成多少份報告? 或 哪個客戶的訂單總金額最高?",
            lines=3,
            scale=4
        )
        submit_btn = gr.Button("🚀 生成SQL", variant="primary", scale=1)

    # Results panel: generated SQL, a status line and the detailed log.
    with gr.Accordion("🔍 結果與日誌", open=True):
        sql_output = gr.Code(
            label="📊 生成的SQL查詢",
            language="sql",
            lines=8
        )
        status_output = gr.Textbox(label="🔍 執行狀態", interactive=False)
        log_output = gr.Textbox(label="📋 詳細日誌", lines=5, interactive=False)

    # Preset example questions that fill the input box.
    gr.Examples(
        examples=[
            "昨天完成了多少個工作單?",
            "A組每月完成數量是多少?",
            "哪個申請方的失敗評級最多?",
            "找出總金額最高的10筆訂單",
            "統計所有評級的分佈"
        ],
        inputs=question_input
    )

    # The three values returned by process_query map onto the three outputs in order.
    submit_btn.click(
        process_query,
        inputs=question_input,
        outputs=[sql_output, status_output, log_output]
    )
410
 
411
if __name__ == "__main__":
    print("Gradio 介面啟動中...")
    # Serve the UI only when the system was initialised; otherwise explain why not.
    if text_to_sql_system is not None:
        demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
    else:
        print("無法啟動 Gradio,因為系統初始化失敗。")