Paul720810 committed on
Commit
9943c6f
·
verified ·
1 Parent(s): cc745df

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +130 -125
app.py CHANGED
@@ -14,7 +14,7 @@ import numpy as np
14
  # ==================== 配置區 ====================
15
  HF_TOKEN = os.environ.get("HF_TOKEN", "您的_HuggingFace_Token")
16
  DATASET_REPO_ID = "Paul720810/Text-to-SQL-Softline"
17
- SIMILARITY_THRESHOLD = 0.70 # 降低閾值,因為很多數據有問題
18
 
19
  # 多個備用LLM模型
20
  LLM_MODELS = [
@@ -25,7 +25,7 @@ LLM_MODELS = [
25
 
26
  print("=" * 60)
27
  print("🤖 智能 Text-to-SQL 系統啟動中...")
28
- print("⚠️ 檢測到大量無效數據,啟用增強修復模式")
29
  print("=" * 60)
30
 
31
  # ==================== 增強工具函數 ====================
@@ -38,8 +38,8 @@ def validate_sql(sql_query: str) -> Dict:
38
  return {"valid": False, "issues": ["SQL語句為空"], "is_safe": False, "empty": True}
39
 
40
  sql_clean = sql_query.strip()
41
- if len(sql_clean) < 10: # 非常短的SQL可能無效
42
- return {"valid": False, "issues": ["SQL過短"], "is_safe": False, "empty": False}
43
 
44
  security_issues = []
45
  sql_upper = sql_clean.upper()
@@ -73,30 +73,32 @@ def analyze_question_type(question: str) -> Dict:
73
  "keywords": [],
74
  "has_count": False,
75
  "has_date": False,
76
- "has_group": False
 
77
  }
78
 
79
  # 檢測關鍵詞
80
  keywords_sets = {
81
- "sales": ["銷售", "業績", "金額", "收入", "sale", "revenue"],
82
- "customer": ["客戶", "買家", "用戶", "customer", "client"],
83
- "product": ["產品", "商品", "項目", "product", "item"],
84
- "time": ["時間", "日期", "月份", "年", "月", "最近", "date", "month", "year"],
85
- "report": ["報告", "完成", "份", "report", "complete"],
86
- "count": ["多少", "幾個", "數量", "count", "how many"]
 
87
  }
88
 
89
  for category, keywords in keywords_sets.items():
90
  for keyword in keywords:
91
  if keyword in question_lower:
92
- analysis["keywords"].append(category)
93
  if category not in analysis["keywords"]:
94
  analysis["keywords"].append(category)
95
 
96
  # 特殊檢測
97
  analysis["has_count"] = any(kw in question_lower for kw in keywords_sets["count"])
98
  analysis["has_date"] = any(kw in question_lower for kw in keywords_sets["time"])
99
- analysis["has_group"] = "每" in question_lower or "各" in question_lower or "group" in question_lower
 
100
 
101
  # 確定主要類型
102
  if analysis["keywords"]:
@@ -104,71 +106,78 @@ def analyze_question_type(question: str) -> Dict:
104
 
105
  return analysis
106
 
107
- def generate_intelligent_sql(question: str, analysis: Dict) -> str:
108
  """根據問題分析生成智能SQL"""
 
109
  question_type = analysis["type"]
110
- has_count = analysis["has_count"]
111
- has_date = analysis["has_date"]
112
- has_group = analysis["has_group"]
113
-
114
- # 根據問題類型生成相應的SQL
115
- if question_type == "sales":
116
- if has_count and has_group and has_date:
117
- return "SELECT strftime('%Y-%m', sale_date) as month, COUNT(*) as sales_count, SUM(amount) as total_sales FROM sales GROUP BY month ORDER BY month;"
118
- elif has_count:
119
- return "SELECT product_name, COUNT(*) as sale_count FROM sales GROUP BY product_name ORDER BY sale_count DESC LIMIT 10;"
120
- else:
121
- return "SELECT product_name, SUM(amount) as total_sales FROM sales GROUP BY product_name ORDER BY total_sales DESC LIMIT 10;"
122
-
123
- elif question_type == "customer":
124
- if has_count and has_group:
125
- return "SELECT customer_name, COUNT(*) as order_count, SUM(amount) as total_spent FROM orders GROUP BY customer_name ORDER BY total_spent DESC;"
126
- else:
127
- return "SELECT customer_name, email, join_date FROM customers ORDER BY join_date DESC LIMIT 10;"
128
-
129
- elif question_type == "product":
130
- if has_count:
131
- return "SELECT category, COUNT(*) as product_count FROM products GROUP BY category ORDER BY product_count DESC;"
132
- else:
133
- return "SELECT product_name, price, stock_quantity FROM products WHERE stock_quantity > 0 ORDER BY price DESC LIMIT 10;"
134
-
135
- elif question_type == "report" or question_type == "time":
136
- if has_count and has_group and has_date:
137
- return "SELECT strftime('%Y-%m', report_date) as month, COUNT(*) as report_count FROM reports GROUP BY month ORDER BY month;"
138
- elif has_date:
139
- return "SELECT report_id, report_name, report_date FROM reports ORDER BY report_date DESC LIMIT 10;"
140
- else:
141
- return "SELECT report_type, COUNT(*) as count FROM reports GROUP BY report_type ORDER BY count DESC;"
142
-
143
- # 默認SQL
144
- if has_count and has_group:
145
- return "SELECT category, COUNT(*) as item_count FROM items GROUP BY category ORDER BY item_count DESC;"
146
- elif has_count:
147
- return "SELECT COUNT(*) as total_count FROM records;"
148
  else:
149
- return "SELECT * FROM data_table LIMIT 10;"
 
 
 
 
 
 
 
 
 
 
 
 
150
 
151
- # ==================== 智能數據加載模塊 ====================
152
- class SmartDataLoader:
153
  def __init__(self, hf_token: str):
154
  self.hf_token = hf_token
155
  self.questions = []
156
  self.sql_answers = []
157
- self.valid_indices = [] # 記錄有效數據的索引
158
  self.schema_data = {}
159
 
160
- def load_and_clean_dataset(self) -> bool:
161
- """加載並清理數據集"""
162
  try:
163
- print(f"[{get_current_time()}] 加載數據集 '{DATASET_REPO_ID}'...")
164
  raw_dataset = load_dataset(DATASET_REPO_ID, token=self.hf_token)['train']
165
 
166
- print("解析 messages 格式並過濾無效數據...")
167
- valid_count = 0
168
  empty_count = 0
169
- invalid_count = 0
170
 
171
- for i, item in enumerate(raw_dataset):
172
  try:
173
  if 'messages' in item and len(item['messages']) >= 2:
174
  user_content = item['messages'][0]['content']
@@ -187,50 +196,31 @@ class SmartDataLoader:
187
  else:
188
  sql_query = assistant_content
189
 
190
- # 驗證SQL - 只保留真正有效的數據
 
 
 
 
191
  validation = validate_sql(sql_query)
 
 
192
 
 
 
 
193
  if validation["valid"]:
194
- self.questions.append(question)
195
- self.sql_answers.append(sql_query)
196
- self.valid_indices.append(i)
197
  valid_count += 1
198
- elif validation["empty"]:
199
- empty_count += 1
200
- else:
201
- invalid_count += 1
202
 
203
  except Exception as e:
204
  continue
205
 
206
- print(f"數據清理完成: {valid_count} 有效, {empty_count} 空, {invalid_count} 無效")
207
-
208
- # 如果有效數據太少,添加一些備用問題
209
- if valid_count < 100:
210
- print("有效數據過少,添加備用問題...")
211
- self.add_backup_examples()
212
-
213
  return True
214
 
215
  except Exception as e:
216
  print(f"數據集加載失敗: {e}")
217
- self.add_backup_examples()
218
  return False
219
 
220
- def add_backup_examples(self):
221
- """添加備用範例"""
222
- backup_data = [
223
- {"question": "查詢銷售額最高的產品", "sql": "SELECT product_name, SUM(sales_amount) as total_sales FROM sales GROUP BY product_name ORDER BY total_sales DESC LIMIT 10;"},
224
- {"question": "顯示最近30天的訂單", "sql": "SELECT * FROM orders WHERE order_date >= date('now', '-30 days') ORDER BY order_date DESC;"},
225
- {"question": "統計每個客戶的訂單數量", "sql": "SELECT customer_name, COUNT(*) as order_count FROM orders GROUP BY customer_name ORDER BY order_count DESC;"},
226
- {"question": "2023年每月銷售額", "sql": "SELECT strftime('%Y-%m', sale_date) as month, SUM(amount) as monthly_sales FROM sales WHERE strftime('%Y', sale_date) = '2023' GROUP BY month ORDER BY month;"},
227
- {"question": "庫存不足的商品", "sql": "SELECT product_name, stock_quantity FROM products WHERE stock_quantity < 10 ORDER BY stock_quantity ASC;"}
228
- ]
229
-
230
- for data in backup_data:
231
- self.questions.append(data["question"])
232
- self.sql_answers.append(data["sql"])
233
-
234
  def load_schema(self) -> bool:
235
  """加載數據庫Schema"""
236
  try:
@@ -250,60 +240,73 @@ class SmartDataLoader:
250
  return False
251
 
252
  # ==================== 主系統 ====================
253
- class EnhancedTextToSQLSystem:
254
  def __init__(self, hf_token: str):
255
  self.hf_token = hf_token
256
- self.data_loader = SmartDataLoader(hf_token)
257
  self.retrieval_system = RetrievalSystem()
258
 
259
  self.initialize_system()
260
 
261
  def initialize_system(self):
262
  """初始化系統組件"""
263
- print("初始化系統組件...")
264
 
265
- self.data_loader.load_and_clean_dataset()
266
  self.data_loader.load_schema()
267
 
268
- # 只為有效數據計算向量
269
  if self.data_loader.questions:
270
  self.retrieval_system.compute_embeddings(self.data_loader.questions)
271
 
272
- print(f"系統初始化完成,可用有效問題: {len(self.data_loader.questions)}")
273
 
274
  def generate_sql(self, user_question: str) -> Tuple[str, str]:
275
- """生成SQL查詢"""
276
  log_messages = [f"⏰ {get_current_time()} 開始處理"]
277
 
278
  if not user_question or user_question.strip() == "":
279
  return "請輸入您的問題。", "錯誤: 問題為空"
280
 
281
- # 分析問題
282
- question_analysis = analyze_question_type(user_question)
283
- log_messages.append(f"🔍 問題分析: {question_analysis['type']}類型")
284
-
285
- # 1. 嘗試檢索相似問題(只在有有效數據時)
286
  if self.data_loader.questions:
287
  hits = self.retrieval_system.retrieve_similar(user_question)
288
 
289
  if hits:
290
  best_hit = hits[0]
291
  similarity_score = best_hit['score']
292
- similar_question = self.data_loader.questions[best_hit['corpus_id']]
293
- original_sql = self.data_loader.sql_answers[best_hit['corpus_id']]
 
 
294
 
295
- log_messages.append(f"📋 檢索到: '{similar_question}'")
296
- log_messages.append(f"📊 相似度: {similarity_score:.3f}")
297
 
298
  if similarity_score > SIMILARITY_THRESHOLD:
299
- log_messages.append(f"✅ 相似度 > {SIMILARITY_THRESHOLD},使用預先SQL")
300
- return original_sql, "\n".join(log_messages)
 
 
 
 
 
 
 
 
 
 
 
 
301
  else:
302
- log_messages.append(f"ℹ️ 相似度不足,嘗試其他方法")
303
 
304
- # 2. 智能生成SQL
305
  log_messages.append("🤖 智能生成SQL...")
306
- intelligent_sql = generate_intelligent_sql(user_question, question_analysis)
 
 
 
307
  log_messages.append("✅ 智能生成完成")
308
 
309
  return intelligent_sql, "\n".join(log_messages)
@@ -335,21 +338,24 @@ class RetrievalSystem:
335
 
336
  def compute_embeddings(self, questions: List[str]) -> None:
337
  if questions:
338
- self.question_embeddings = self.embedder.encode(questions, convert_to_tensor=True)
 
 
339
 
340
- def retrieve_similar(self, user_question: str, top_k: int = 3) -> List[Dict]:
341
  if self.question_embeddings is None or len(self.question_embeddings) == 0:
342
  return []
343
  try:
344
  question_embedding = self.embedder.encode(user_question, convert_to_tensor=True)
345
  hits = util.semantic_search(question_embedding, self.question_embeddings, top_k=top_k)
346
  return hits[0] if hits and hits[0] else []
347
- except:
 
348
  return []
349
 
350
  # ==================== 初始化系統 ====================
351
- print("正在初始化增強版Text-to-SQL系統...")
352
- text_to_sql_system = EnhancedTextToSQLSystem(HF_TOKEN)
353
 
354
  # ==================== Gradio界面 ====================
355
  def process_query(user_question: str) -> Tuple[str, str]:
@@ -358,17 +364,16 @@ def process_query(user_question: str) -> Tuple[str, str]:
358
 
359
  with gr.Blocks(title="智能Text-to-SQL系統", theme=gr.themes.Soft()) as demo:
360
  gr.Markdown("# 🚀 智能 Text-to-SQL 系統")
361
- gr.Markdown("💡 針對大量無效數據優化的增強版本")
362
 
363
  with gr.Row():
364
  question_input = gr.Textbox(
365
  label="📝 輸入問題",
366
- placeholder="例如:查詢2023年每月報告數量",
367
- lines=2
 
368
  )
369
-
370
- with gr.Row():
371
- submit_btn = gr.Button("🚀 生成SQL", variant="primary")
372
 
373
  with gr.Row():
374
  sql_output = gr.Code(
 
14
  # ==================== 配置區 ====================
15
  HF_TOKEN = os.environ.get("HF_TOKEN", "您的_HuggingFace_Token")
16
  DATASET_REPO_ID = "Paul720810/Text-to-SQL-Softline"
17
+ SIMILARITY_THRESHOLD = 0.75
18
 
19
  # 多個備用LLM模型
20
  LLM_MODELS = [
 
25
 
26
  print("=" * 60)
27
  print("🤖 智能 Text-to-SQL 系統啟動中...")
28
+ print("📊 模式: 讀取全部4276條數據(包含空白SQL)")
29
  print("=" * 60)
30
 
31
  # ==================== 增強工具函數 ====================
 
38
  return {"valid": False, "issues": ["SQL語句為空"], "is_safe": False, "empty": True}
39
 
40
  sql_clean = sql_query.strip()
41
+ if len(sql_clean) < 5:
42
+ return {"valid": False, "issues": ["SQL過短"], "is_safe": False, "empty": True}
43
 
44
  security_issues = []
45
  sql_upper = sql_clean.upper()
 
73
  "keywords": [],
74
  "has_count": False,
75
  "has_date": False,
76
+ "has_group": False,
77
+ "has_comparison": False
78
  }
79
 
80
  # 檢測關鍵詞
81
  keywords_sets = {
82
+ "sales": ["銷售", "業績", "金額", "收入", "sale", "revenue", "金額"],
83
+ "customer": ["客戶", "買家", "用戶", "customer", "client", "買家"],
84
+ "product": ["產品", "商品", "項目", "product", "item", "產品"],
85
+ "time": ["時間", "日期", "月份", "年", "月", "最近", "date", "month", "year", "時間"],
86
+ "report": ["報告", "完成", "份", "report", "complete", "報告"],
87
+ "count": ["多少", "幾個", "數量", "count", "how many", "多少"],
88
+ "comparison": ["比較", "vs", " versus", "對比", "相比", "比較"]
89
  }
90
 
91
  for category, keywords in keywords_sets.items():
92
  for keyword in keywords:
93
  if keyword in question_lower:
 
94
  if category not in analysis["keywords"]:
95
  analysis["keywords"].append(category)
96
 
97
  # 特殊檢測
98
  analysis["has_count"] = any(kw in question_lower for kw in keywords_sets["count"])
99
  analysis["has_date"] = any(kw in question_lower for kw in keywords_sets["time"])
100
+ analysis["has_group"] = any(word in question_lower for word in ["每", "各", "group", "每個"])
101
+ analysis["has_comparison"] = any(kw in question_lower for kw in keywords_sets["comparison"])
102
 
103
  # 確定主要類型
104
  if analysis["keywords"]:
 
106
 
107
  return analysis
108
 
109
def generate_sql_from_question(question: str, analysis: Dict) -> str:
    """Pick a template SQL statement for *question* using keyword heuristics.

    The question is first matched against a handful of hard-coded phrase
    patterns (monthly reports, best-selling products, customer orders,
    year-over-year comparison, low stock).  If none of them applies, a
    generic template is chosen from the boolean flags in *analysis*.

    Args:
        question: The user's natural-language question.
        analysis: Flags describing the question; this function reads the
            ``has_count``, ``has_group`` and ``has_date`` keys.

    Returns:
        A single SQL statement as a string.  The table and column names in
        the generic templates are placeholders, not names taken from a real
        schema.

    Raises:
        KeyError: If *analysis* lacks one of the flags listed above.
    """
    question_lower = question.lower()

    # --- Phrase-specific templates -------------------------------------
    if "每月" in question_lower and ("完成" in question_lower or "報告" in question_lower):
        year_match = re.search(r'(\d{4})年', question_lower)
        # ``\d`` also matches non-ASCII digits (e.g. full-width "２０２３");
        # int() normalises them so the literal compares equal to the ASCII
        # text produced by strftime('%Y', ...).
        year = f"{int(year_match.group(1)):04d}" if year_match else "2023"
        return f"SELECT strftime('%Y-%m', completion_date) as month, COUNT(*) as report_count FROM reports WHERE strftime('%Y', completion_date) = '{year}' GROUP BY month ORDER BY month;"

    if "銷售" in question_lower and ("最高" in question_lower or "最好" in question_lower):
        return "SELECT product_name, SUM(sales_amount) as total_sales FROM sales GROUP BY product_name ORDER BY total_sales DESC LIMIT 10;"

    if "客戶" in question_lower and ("訂單" in question_lower or "購買" in question_lower):
        return "SELECT customer_name, COUNT(*) as order_count, SUM(order_amount) as total_spent FROM orders GROUP BY customer_name ORDER BY total_spent DESC;"

    # The previous condition tested ``"" in question_lower``, which is true
    # for every string, so any question containing "比較" received the
    # per-year template.  Require an actual year-related word instead
    # ("年" also covers "年份" and "年度").
    if "比較" in question_lower and "年" in question_lower:
        return "SELECT strftime('%Y', order_date) as year, COUNT(*) as order_count, SUM(order_amount) as yearly_revenue FROM orders GROUP BY year ORDER BY year;"

    if "庫存" in question_lower and ("不足" in question_lower or "缺少" in question_lower):
        return "SELECT product_name, stock_quantity FROM products WHERE stock_quantity < 10 ORDER BY stock_quantity ASC;"

    # --- Generic templates driven by the analysis flags ----------------
    wants_count = analysis["has_count"]
    wants_group = analysis["has_group"]
    wants_date = analysis["has_date"]

    if wants_count and wants_group and wants_date:
        return "SELECT strftime('%Y-%m', date_column) as period, COUNT(*) as item_count FROM appropriate_table GROUP BY period ORDER BY period;"

    if wants_count and wants_group:
        return "SELECT category_column, COUNT(*) as count FROM appropriate_table GROUP BY category_column ORDER BY count DESC;"

    if wants_count:
        return "SELECT COUNT(*) as total_count FROM appropriate_table;"

    if wants_group:
        return "SELECT group_column, AVG(value_column) as average_value FROM appropriate_table GROUP BY group_column;"

    # Nothing recognised: fall back to a small sample of rows.
    return "SELECT * FROM appropriate_table LIMIT 10;"
147
+
148
def repair_empty_sql(original_sql: str, user_question: str, similar_question: str) -> str:
    """Return usable SQL, regenerating it when the stored SQL is unusable.

    The stored SQL is returned unchanged when ``validate_sql`` accepts it.
    Otherwise (blank, or rejected by ``validate_sql``) a replacement is
    generated from *user_question* and prefixed with a ``--`` comment naming
    the retrieved question it was matched against.

    Previously only blank input was regenerated, although the caller routes
    every invalid statement here; rejected non-blank SQL was handed back
    untouched while being reported as repaired.

    Args:
        original_sql: SQL stored alongside the retrieved question; may be
            empty or ``None``.
        user_question: The question the user actually asked; drives the
            regenerated SQL.
        similar_question: The retrieved dataset question, used only in the
            explanatory comment.

    Returns:
        Either *original_sql* unchanged, or a two-line string consisting of
        the explanatory comment followed by the generated statement.
    """
    is_blank = not original_sql or original_sql.strip() == ""
    if not is_blank and validate_sql(original_sql)["valid"]:
        return original_sql

    # Derive a replacement statement from the user's own wording.
    analysis = analyze_question_type(user_question)
    repaired_sql = generate_sql_from_question(user_question, analysis)

    # A ``--`` comment ends at the first line break, so flatten the quoted
    # question; otherwise its remaining lines would become part of the SQL.
    comment_subject = " ".join((similar_question or "").splitlines())

    return f"-- 根據類似問題 '{comment_subject}' 修復生成的SQL\n{repaired_sql}"
159
 
160
+ # ==================== 完整數據加載模塊 ====================
161
+ class CompleteDataLoader:
162
  def __init__(self, hf_token: str):
163
  self.hf_token = hf_token
164
  self.questions = []
165
  self.sql_answers = []
166
+ self.sql_quality = [] # 記錄每個SQL的質量評分
167
  self.schema_data = {}
168
 
169
+ def load_complete_dataset(self) -> bool:
170
+ """加載完整數據集(包括空白SQL)"""
171
  try:
172
+ print(f"[{get_current_time()}] 正在加載完整數據集 '{DATASET_REPO_ID}'...")
173
  raw_dataset = load_dataset(DATASET_REPO_ID, token=self.hf_token)['train']
174
 
175
+ print("解析全部 messages 格式...")
176
+ total_count = 0
177
  empty_count = 0
178
+ valid_count = 0
179
 
180
+ for item in raw_dataset:
181
  try:
182
  if 'messages' in item and len(item['messages']) >= 2:
183
  user_content = item['messages'][0]['content']
 
196
  else:
197
  sql_query = assistant_content
198
 
199
+ # 保存所有數據
200
+ self.questions.append(question)
201
+ self.sql_answers.append(sql_query)
202
+
203
+ # 評估SQL質量
204
  validation = validate_sql(sql_query)
205
+ quality_score = 1.0 if validation["valid"] else 0.3
206
+ self.sql_quality.append(quality_score)
207
 
208
+ total_count += 1
209
+ if validation["empty"]:
210
+ empty_count += 1
211
  if validation["valid"]:
 
 
 
212
  valid_count += 1
 
 
 
 
213
 
214
  except Exception as e:
215
  continue
216
 
217
+ print(f"數據加載完成: 總數 {total_count}, 有效 {valid_count}, 空白 {empty_count}")
 
 
 
 
 
 
218
  return True
219
 
220
  except Exception as e:
221
  print(f"數據集加載失敗: {e}")
 
222
  return False
223
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  def load_schema(self) -> bool:
225
  """加載數據庫Schema"""
226
  try:
 
240
  return False
241
 
242
  # ==================== 主系統 ====================
243
+ class CompleteTextToSQLSystem:
244
  def __init__(self, hf_token: str):
245
  self.hf_token = hf_token
246
+ self.data_loader = CompleteDataLoader(hf_token)
247
  self.retrieval_system = RetrievalSystem()
248
 
249
  self.initialize_system()
250
 
251
  def initialize_system(self):
252
  """初始化系統組件"""
253
+ print("正在初始化完整數據系統...")
254
 
255
+ self.data_loader.load_complete_dataset()
256
  self.data_loader.load_schema()
257
 
258
+ # 為所有問題計算向量(包括空白SQL的)
259
  if self.data_loader.questions:
260
  self.retrieval_system.compute_embeddings(self.data_loader.questions)
261
 
262
+ print(f"系統初始化完成,載入問題總數: {len(self.data_loader.questions)}")
263
 
264
  def generate_sql(self, user_question: str) -> Tuple[str, str]:
265
+ """生成SQL查詢 - 處理所有數據"""
266
  log_messages = [f"⏰ {get_current_time()} 開始處理"]
267
 
268
  if not user_question or user_question.strip() == "":
269
  return "請輸入您的問題。", "錯誤: 問題為空"
270
 
271
+ # 1. 檢索最相似的問題(從所有4276條中)
 
 
 
 
272
  if self.data_loader.questions:
273
  hits = self.retrieval_system.retrieve_similar(user_question)
274
 
275
  if hits:
276
  best_hit = hits[0]
277
  similarity_score = best_hit['score']
278
+ corpus_id = best_hit['corpus_id']
279
+ similar_question = self.data_loader.questions[corpus_id]
280
+ original_sql = self.data_loader.sql_answers[corpus_id]
281
+ sql_quality = self.data_loader.sql_quality[corpus_id]
282
 
283
+ log_messages.append(f"🔍 檢索到: '{similar_question}'")
284
+ log_messages.append(f"📊 相似度: {similarity_score:.3f}, 質量分數: {sql_quality:.1f}")
285
 
286
  if similarity_score > SIMILARITY_THRESHOLD:
287
+ # 檢查並修復SQL(如果是空白的)
288
+ validation = validate_sql(original_sql)
289
+
290
+ if validation["empty"] or not validation["valid"]:
291
+ log_messages.append(f"⚠️ 原始SQL需要修復: {', '.join(validation['issues'])}")
292
+ log_messages.append("🛠️ 正在智能修復SQL...")
293
+
294
+ repaired_sql = repair_empty_sql(original_sql, user_question, similar_question)
295
+ log_messages.append("✅ 修復完成")
296
+
297
+ return repaired_sql, "\n".join(log_messages)
298
+ else:
299
+ log_messages.append(f"✅ 相似度 > {SIMILARITY_THRESHOLD},使用預先SQL")
300
+ return original_sql, "\n".join(log_messages)
301
  else:
302
+ log_messages.append(f"ℹ️ 相似度 {similarity_score:.3f} 低於閾值 {SIMILARITY_THRESHOLD}")
303
 
304
+ # 2. 如果檢索失敗或相似度不足,智能生成SQL
305
  log_messages.append("🤖 智能生成SQL...")
306
+ analysis = analyze_question_type(user_question)
307
+ intelligent_sql = generate_sql_from_question(user_question, analysis)
308
+
309
+ log_messages.append(f"📋 問題分析: {analysis['type']}類型")
310
  log_messages.append("✅ 智能生成完成")
311
 
312
  return intelligent_sql, "\n".join(log_messages)
 
338
 
339
  def compute_embeddings(self, questions: List[str]) -> None:
340
  if questions:
341
+ print(f"正在為 {len(questions)} 個問題計算向量...")
342
+ self.question_embeddings = self.embedder.encode(questions, convert_to_tensor=True, show_progress_bar=False)
343
+ print("向量計算完成")
344
 
345
+ def retrieve_similar(self, user_question: str, top_k: int = 5) -> List[Dict]:
346
  if self.question_embeddings is None or len(self.question_embeddings) == 0:
347
  return []
348
  try:
349
  question_embedding = self.embedder.encode(user_question, convert_to_tensor=True)
350
  hits = util.semantic_search(question_embedding, self.question_embeddings, top_k=top_k)
351
  return hits[0] if hits and hits[0] else []
352
+ except Exception as e:
353
+ print(f"檢索錯誤: {e}")
354
  return []
355
 
356
  # ==================== 初始化系統 ====================
357
+ print("正在初始化完整數據Text-to-SQL系統...")
358
+ text_to_sql_system = CompleteTextToSQLSystem(HF_TOKEN)
359
 
360
  # ==================== Gradio界面 ====================
361
  def process_query(user_question: str) -> Tuple[str, str]:
 
364
 
365
  with gr.Blocks(title="智能Text-to-SQL系統", theme=gr.themes.Soft()) as demo:
366
  gr.Markdown("# 🚀 智能 Text-to-SQL 系統")
367
+ gr.Markdown("📊 完整模式: 讀取全部4276條數據")
368
 
369
  with gr.Row():
370
  question_input = gr.Textbox(
371
  label="📝 輸入問題",
372
+ placeholder="例如:2023年每月完成多少份報告",
373
+ lines=2,
374
+ scale=4
375
  )
376
+ submit_btn = gr.Button("🚀 生成SQL", variant="primary", scale=1)
 
 
377
 
378
  with gr.Row():
379
  sql_output = gr.Code(