Paul720810 commited on
Commit
b5ff516
·
verified ·
1 Parent(s): d3ec0ef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +407 -187
app.py CHANGED
@@ -2,260 +2,480 @@ import gradio as gr
2
  import requests
3
  import json
4
  import os
 
 
 
 
5
  from datasets import load_dataset
6
  from sentence_transformers import SentenceTransformer, util
7
  import torch
8
  from huggingface_hub import hf_hub_download
9
- import re
10
 
11
- # --- 配置區 ---
12
- HF_TOKEN = os.environ.get("HF_TOKEN")
13
  DATASET_REPO_ID = "Paul720810/Text-to-SQL-Softline"
14
- SIMILARITY_THRESHOLD = 0.75 # 進一步降低閾值
15
 
16
- # 多個備用模型(保證至少有一個可用)
17
  LLM_MODELS = [
18
- "https://api-inference.huggingface.co/models/gpt2", # 最基礎的模型,保證可用
19
- "https://api-inference.huggingface.co/models/distilgpt2",
20
  "https://api-inference.huggingface.co/models/microsoft/DialoGPT-small"
21
  ]
22
 
23
- print("--- [1/5] 開始初始化應用 ---")
 
 
 
 
 
 
 
 
 
24
 
25
- # --- 1. 載入知識庫 ---
26
- questions = []
27
- sql_answers = []
28
- schema_data = {}
29
 
30
- try:
31
- print(f"--- [2/5] 正在從 '{DATASET_REPO_ID}' 載入知識庫... ---")
32
- raw_dataset = load_dataset(DATASET_REPO_ID, token=HF_TOKEN)['train']
 
 
 
 
 
 
 
33
 
34
- # 解析 messages 格式
35
- print("--- > 解析 messages 格式...")
 
 
 
36
 
37
- for item in raw_dataset:
38
- try:
39
- if 'messages' in item and len(item['messages']) >= 2:
40
- user_content = item['messages'][0]['content']
41
- assistant_content = item['messages'][1]['content']
42
-
43
- # 提取問題
44
- question_match = re.search(r'指令:\s*(.*?)(?:\n|$)', user_content)
45
- question = question_match.group(1).strip() if question_match else user_content
46
-
47
- # 提取SQL
48
- sql_match = re.search(r'SQL查詢:\s*(.*?)(?:\n|$)', assistant_content, re.DOTALL)
49
- if sql_match:
50
- sql_query = sql_match.group(1).strip()
51
- sql_query = re.sub(r'^sql\s*', '', sql_query)
52
- sql_query = re.sub(r'```sql|```', '', sql_query).strip()
53
- else:
54
- sql_query = assistant_content
55
-
56
- questions.append(question)
57
- sql_answers.append(sql_query)
58
-
59
- except Exception as e:
60
- continue
61
 
62
- print(f"--- > 成功解析 {len(questions)} 條問答範例 ---")
 
63
 
64
- # 載入Schema
65
- try:
66
- schema_file_path = hf_hub_download(
67
- repo_id=DATASET_REPO_ID,
68
- filename="sqlite_schema_FULL.json",
69
- repo_type='dataset',
70
- token=HF_TOKEN
71
- )
72
- with open(schema_file_path, 'r', encoding='utf-8') as f:
73
- schema_data = json.load(f)
74
- except Exception as e:
75
- print(f"警告: 無法載入Schema文件: {e}")
76
-
77
- except Exception as e:
78
- print(f"錯誤: 載入數據集失敗: {e}")
79
- questions = ["示例問題"]
80
- sql_answers = ["SELECT '系統就緒' AS status;"]
81
-
82
- # --- 2. 初始化檢索模型 ---
83
- print("--- [3/5] 正在載入句向量模型... ---")
84
- embedder = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
85
-
86
- # 計算問題向量
87
- if questions:
88
- print(f"--- [4/5] 正在為 {len(questions)} 個問題計算向量... ---")
89
- question_embeddings = embedder.encode(questions, convert_to_tensor=True, show_progress_bar=False) # 關閉進度條
90
- print("--- > 向量計算完���! ---")
91
- else:
92
- print("--- [4/5] 警告:沒有可用的問題 ---")
93
- question_embeddings = torch.Tensor([])
94
 
95
- # --- 3. 構建DDL ---
96
- def build_schema_context(schema_dict):
97
- if not schema_dict:
98
- return "/* 無Schema信息 */"
99
 
100
- context = "/* 數據庫表結構 */\n"
101
- for table_name, columns in schema_dict.items():
102
- if isinstance(columns, list):
103
- context += f"\n-- 表: {table_name}\n"
104
- for col in columns:
105
- col_name = col.get('name', 'unknown')
106
- col_type = col.get('type', 'TEXT')
107
- col_desc = col.get('description', '')
108
- context += f"-- {col_name} ({col_type}) - {col_desc}\n"
109
- return context
110
-
111
- SCHEMA_CONTEXT = build_schema_context(schema_data)
 
 
 
 
 
 
 
112
 
113
- # --- 4. LLM調用函數(多模型備用)---
114
- def call_llm_api(prompt, model_urls=LLM_MODELS):
115
- headers = {"Authorization": f"Bearer {HF_TOKEN}"}
116
- payload = {
117
- "inputs": prompt,
118
- "parameters": {
119
- "max_new_tokens": 150,
120
- "temperature": 0.1,
121
- "do_sample": False
122
- }
123
- }
124
-
125
- # 嘗試所有備用模型
126
- for model_url in model_urls:
127
  try:
128
- response = requests.post(model_url, headers=headers, json=payload, timeout=15)
 
129
 
130
- if response.status_code == 200:
131
- result = response.json()
132
- if isinstance(result, list) and len(result) > 0:
133
- return result[0]['generated_text'].strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  except Exception as e:
136
- print(f"模型 {model_url} 失敗: {e}")
137
- continue
 
138
 
139
- return None # 所有模型都失敗
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
- # --- 5. 核心邏輯 ---
142
- def get_sql_query(user_question: str):
143
- if not user_question:
144
- return "請輸入您的問題。", "日誌:用戶未輸入問題。"
145
 
146
- log_messages = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
- # 1. 首先嘗試檢索
149
- if len(questions) > 0 and len(question_embeddings) > 0:
 
 
 
 
 
 
 
 
 
 
 
 
150
  try:
151
- question_embedding = embedder.encode(user_question, convert_to_tensor=True)
152
- hits = util.semantic_search(question_embedding, question_embeddings, top_k=3)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
- if hits and hits[0]:
155
- best_hit = hits[0][0]
156
  similarity_score = best_hit['score']
157
- similar_question = questions[best_hit['corpus_id']]
158
 
159
- log_messages.append(f"檢索到相似問題: '{similar_question}' (相似度: {similarity_score:.3f})")
 
160
 
161
  if similarity_score > SIMILARITY_THRESHOLD:
162
- sql_result = sql_answers[best_hit['corpus_id']]
163
- log_messages.append(f"相似度 > {SIMILARITY_THRESHOLD},直接返回")
 
 
 
 
 
 
164
  return sql_result, "\n".join(log_messages)
165
  else:
166
- log_messages.append(f"相似度低於閾值 {SIMILARITY_THRESHOLD}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  else:
168
- log_messages.append("檢索失敗:找不到相似問題")
169
 
170
- except Exception as e:
171
- log_messages.append(f"檢索過程出錯: {e}")
172
-
173
- # 2. 檢索失敗或相似度低,嘗試LLM
174
- log_messages.append("嘗試LLM生成...")
 
175
 
176
- # 構建簡單提示詞
177
- prompt = f"""請為這個問題生成SQL查詢:
 
 
 
178
 
179
- 數據庫結構:
180
- {SCHEMA_CONTEXT}
181
 
182
- 問題:{user_question}
 
 
 
183
 
184
- SQL"""
185
 
186
- generated_sql = call_llm_api(prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
 
188
- if generated_sql:
189
- # 清理輸出
190
- generated_sql = re.sub(r'^```sql|```$', '', generated_sql).strip()
191
- log_messages.append("LLM生成成功!")
192
- return generated_sql, "\n".join(log_messages)
193
  else:
194
- # 3. LLM也失敗,提供智能備用答案
195
- log_messages.append("所有LLM模型都失敗,提供備用答案")
196
 
197
- # 基於問題內容提供有意義的備用SQL
198
- if any(keyword in user_question.lower() for keyword in ['銷售', '業績', '金額']):
199
- backup_sql = "SELECT product_name, SUM(sales_amount) as total_sales FROM sales GROUP BY product_name ORDER BY total_sales DESC LIMIT 10;"
200
- elif any(keyword in user_question.lower() for keyword in ['客戶', '買家', '用戶']):
201
- backup_sql = "SELECT customer_name, COUNT(*) as order_count FROM orders GROUP BY customer_name ORDER BY order_count DESC;"
202
- elif any(keyword in user_question.lower() for keyword in ['時間', '日期', '最近']):
203
- backup_sql = "SELECT DATE(order_date) as day, COUNT(*) as orders FROM orders WHERE order_date >= DATE('now', '-7 days') GROUP BY day ORDER BY day DESC;"
204
- else:
205
- backup_sql = "SELECT '請重試或聯繫管理員' AS status;"
206
 
207
- return backup_sql, "\n".join(log_messages)
208
-
209
- # --- 6. 創建界面 ---
210
- print("--- [5/5] 正在創建 Web 界面... ---")
 
 
211
 
212
- with gr.Blocks(title="智能Text-to-SQL系統") as demo:
213
- gr.Markdown("# 🚀 智能 Text-to-SQL 系統")
214
- gr.Markdown("輸入自然語言問題,自動生成SQL查詢")
 
 
 
 
 
 
 
 
215
 
216
- with gr.Row():
217
- question_input = gr.Textbox(
218
- label="您的問題",
219
- placeholder="例如:查詢2024年的銷售數據",
220
- lines=2
221
- )
222
 
223
  with gr.Row():
224
- submit_btn = gr.Button("生成SQL", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
225
 
226
  with gr.Row():
227
  sql_output = gr.Code(
228
- label="生成的SQL",
229
  language="sql",
230
- lines=6
 
 
 
 
 
 
 
 
231
  )
232
 
233
  with gr.Row():
234
  log_output = gr.Textbox(
235
- label="執行日誌",
236
  lines=4,
237
  interactive=False
238
  )
239
 
240
- # 綁定事件
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  submit_btn.click(
242
- fn=get_sql_query,
243
  inputs=question_input,
244
- outputs=[sql_output, log_output]
245
  )
246
 
247
- # 示例
248
- gr.Examples(
249
- examples=[
250
- "2024年銷售額最高的產品",
251
- "最近30天的訂單統計",
252
- "每個客戶的訂單數量",
253
- "庫存不足的商品列表"
254
- ],
255
- inputs=question_input
 
256
  )
257
 
258
- print("--- 應用啟動完成 ---")
259
- print("--- 訪問地址: http://localhost:7860 ---")
260
  if __name__ == "__main__":
261
- demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
 
 
 
 
 
 
 
 
 
 
 
 
2
  import requests
3
  import json
4
  import os
5
+ import re
6
+ import sqlite3
7
+ import pandas as pd
8
+ from datetime import datetime
9
  from datasets import load_dataset
10
  from sentence_transformers import SentenceTransformer, util
11
  import torch
12
  from huggingface_hub import hf_hub_download
13
+ from typing import List, Dict, Tuple, Optional
14
 
15
+ # ==================== 配置區 ====================
16
+ HF_TOKEN = os.environ.get("HF_TOKEN", "您的_HuggingFace_Token")
17
  DATASET_REPO_ID = "Paul720810/Text-to-SQL-Softline"
18
+ SIMILARITY_THRESHOLD = 0.75 # 相似度閾值
19
 
20
+ # 多個備用LLM模型(保證可用性)
21
  LLM_MODELS = [
22
+ "https://api-inference.huggingface.co/models/gpt2",
23
+ "https://api-inference.huggingface.co/models/distilgpt2",
24
  "https://api-inference.huggingface.co/models/microsoft/DialoGPT-small"
25
  ]
26
 
27
+ # 數據庫連接配置(可選)
28
+ DB_CONFIG = {
29
+ "enabled": False, # 設置為True啟用真實數據庫連接
30
+ "path": "您的數據庫路徑.db",
31
+ "test_queries": True # 是否啟用SQL測試功能
32
+ }
33
+
34
+ print("=" * 50)
35
+ print("🚀 智能 Text-to-SQL 系統啟動中...")
36
+ print("=" * 50)
37
 
38
+ # ==================== 工具函數 ====================
39
+ def get_current_time():
40
+ """獲取當前時間字符串"""
41
+ return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
42
 
43
+ def safe_json_load(data, default=None):
44
+ """安全的JSON解析"""
45
+ try:
46
+ return json.loads(data) if isinstance(data, str) else data
47
+ except (json.JSONDecodeError, TypeError):
48
+ return default
49
+
50
+ def validate_sql(sql_query: str) -> Dict:
51
+ """驗證SQL語句的安全性"""
52
+ security_issues = []
53
 
54
+ # 檢查危險操作
55
+ dangerous_keywords = ['DROP', 'DELETE', 'INSERT', 'UPDATE', 'ALTER', 'TRUNCATE', 'EXEC', 'EXECUTE']
56
+ for keyword in dangerous_keywords:
57
+ if f" {keyword} " in sql_query.upper():
58
+ security_issues.append(f"發現危險操作: {keyword}")
59
 
60
+ # 檢查基本語法
61
+ if "SELECT" not in sql_query.upper():
62
+ security_issues.append("缺少SELECT語句")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
+ if "FROM" not in sql_query.upper():
65
+ security_issues.append("缺少FROM子句")
66
 
67
+ return {
68
+ "valid": len(security_issues) == 0,
69
+ "issues": security_issues,
70
+ "is_safe": len([i for i in security_issues if '危險' in i]) == 0
71
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
+ def execute_test_query(sql_query: str) -> Tuple[bool, str]:
74
+ """執行測試查詢(可選功能)"""
75
+ if not DB_CONFIG["enabled"]:
76
+ return False, "數據庫連接未啟用"
77
 
78
+ try:
79
+ validation = validate_sql(sql_query)
80
+ if not validation["is_safe"]:
81
+ return False, f"SQL安全檢查失敗: {', '.join(validation['issues'])}"
82
+
83
+ # 連接數據庫並執行
84
+ conn = sqlite3.connect(DB_CONFIG["path"])
85
+ df = pd.read_sql_query(sql_query, conn)
86
+ conn.close()
87
+
88
+ if len(df) == 0:
89
+ return True, "✅ SQL執行成功,但返回0條數據\n💡 可能原因: 條件太嚴格或數據不存在"
90
+ else:
91
+ sample_info = f"✅ SQL執行成功,返回 {len(df)} 條數據\n"
92
+ sample_info += f"📊 前3條數據:\n{df.head(3).to_string()}"
93
+ return True, sample_info
94
+
95
+ except Exception as e:
96
+ return False, f"❌ SQL執行錯誤: {str(e)}"
97
 
98
+ # ==================== 數據加載模塊 ====================
99
+ class DataLoader:
100
+ def __init__(self, hf_token: str):
101
+ self.hf_token = hf_token
102
+ self.questions = []
103
+ self.sql_answers = []
104
+ self.schema_data = {}
105
+
106
+ def load_dataset(self) -> bool:
107
+ """加載問答數據集"""
 
 
 
 
108
  try:
109
+ print(f"[{get_current_time()}] 正在加載數據集 '{DATASET_REPO_ID}'...")
110
+ raw_dataset = load_dataset(DATASET_REPO_ID, token=self.hf_token)['train']
111
 
112
+ print("正在解析 messages 格式...")
113
+ for item in raw_dataset:
114
+ try:
115
+ if 'messages' in item and len(item['messages']) >= 2:
116
+ user_content = item['messages'][0]['content']
117
+ assistant_content = item['messages'][1]['content']
118
+
119
+ # 提取問題
120
+ question_match = re.search(r'指令:\s*(.*?)(?:\n|$)', user_content)
121
+ question = question_match.group(1).strip() if question_match else user_content
122
+
123
+ # 提取SQL
124
+ sql_match = re.search(r'SQL查詢:\s*(.*?)(?:\n|$)', assistant_content, re.DOTALL)
125
+ if sql_match:
126
+ sql_query = sql_match.group(1).strip()
127
+ sql_query = re.sub(r'^sql\s*', '', sql_query)
128
+ sql_query = re.sub(r'```sql|```', '', sql_query).strip()
129
+ else:
130
+ sql_query = assistant_content
131
+
132
+ self.questions.append(question)
133
+ self.sql_answers.append(sql_query)
134
+
135
+ except Exception as e:
136
+ continue
137
 
138
+ print(f"成功解析 {len(self.questions)} 條問答範例")
139
+ return True
140
+
141
+ except Exception as e:
142
+ print(f"數據集加載失敗: {e}")
143
+ self.questions = ["系統初始化問題"]
144
+ self.sql_answers = ["SELECT '數據庫連接就緒' AS status;"]
145
+ return False
146
+
147
+ def load_schema(self) -> bool:
148
+ """加載數據庫Schema"""
149
+ try:
150
+ schema_file_path = hf_hub_download(
151
+ repo_id=DATASET_REPO_ID,
152
+ filename="sqlite_schema_FULL.json",
153
+ repo_type='dataset',
154
+ token=self.hf_token
155
+ )
156
+ with open(schema_file_path, 'r', encoding='utf-8') as f:
157
+ self.schema_data = safe_json_load(f.read(), {})
158
+ print("Schema加載成功")
159
+ return True
160
  except Exception as e:
161
+ print(f"Schema加載失敗: {e}")
162
+ self.schema_data = {}
163
+ return False
164
 
165
+ def build_schema_context(self) -> str:
166
+ """構建Schema上下文"""
167
+ if not self.schema_data:
168
+ return "/* 無Schema信息 */"
169
+
170
+ context = "/* 數據庫表結構 */\n"
171
+ for table_name, columns in self.schema_data.items():
172
+ if isinstance(columns, list):
173
+ context += f"\n-- 表: {table_name}\n"
174
+ for col in columns:
175
+ col_name = col.get('name', 'unknown')
176
+ col_type = col.get('type', 'TEXT')
177
+ col_desc = col.get('description', '')
178
+ context += f"-- {col_name} ({col_type}) - {col_desc}\n"
179
+ return context
180
 
181
+ # ==================== LLM模塊 ====================
182
+ class LLMClient:
183
+ def __init__(self, hf_token: str):
184
+ self.hf_token = hf_token
185
 
186
+ def call_llm_api(self, prompt: str, model_urls: List[str] = LLM_MODELS) -> Optional[str]:
187
+ """調用LLM API(多模型備用)"""
188
+ headers = {"Authorization": f"Bearer {self.hf_token}"}
189
+ payload = {
190
+ "inputs": prompt,
191
+ "parameters": {
192
+ "max_new_tokens": 200,
193
+ "temperature": 0.1,
194
+ "do_sample": False
195
+ }
196
+ }
197
+
198
+ for model_url in model_urls:
199
+ try:
200
+ response = requests.post(model_url, headers=headers, json=payload, timeout=20)
201
+
202
+ if response.status_code == 200:
203
+ result = response.json()
204
+ if isinstance(result, list) and len(result) > 0:
205
+ generated_text = result[0]['generated_text'].strip()
206
+ # 清理輸出
207
+ generated_text = re.sub(r'^```sql|```$', '', generated_text).strip()
208
+ return generated_text
209
+
210
+ except Exception as e:
211
+ print(f"模型 {model_url} 調用失敗: {e}")
212
+ continue
213
+
214
+ return None
215
+
216
+ # ==================== 檢索模塊 ====================
217
+ class RetrievalSystem:
218
+ def __init__(self):
219
+ self.embedder = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
220
+ self.question_embeddings = None
221
 
222
+ def compute_embeddings(self, questions: List[str]) -> None:
223
+ """計算問題向量"""
224
+ if questions:
225
+ print(f"正在為 {len(questions)} 個問題計算向量...")
226
+ self.question_embeddings = self.embedder.encode(questions, convert_to_tensor=True, show_progress_bar=False)
227
+ print("向量計算完成")
228
+ else:
229
+ self.question_embeddings = torch.Tensor([])
230
+
231
+ def retrieve_similar(self, user_question: str, top_k: int = 3) -> List[Dict]:
232
+ """檢索相似問題"""
233
+ if self.question_embeddings is None or len(self.question_embeddings) == 0:
234
+ return []
235
+
236
  try:
237
+ question_embedding = self.embedder.encode(user_question, convert_to_tensor=True)
238
+ hits = util.semantic_search(question_embedding, self.question_embeddings, top_k=top_k)
239
+ return hits[0] if hits and hits[0] else []
240
+ except Exception as e:
241
+ print(f"檢索失敗: {e}")
242
+ return []
243
+
244
+ # ==================== 主系統 ====================
245
+ class TextToSQLSystem:
246
+ def __init__(self, hf_token: str):
247
+ self.hf_token = hf_token
248
+ self.data_loader = DataLoader(hf_token)
249
+ self.llm_client = LLMClient(hf_token)
250
+ self.retrieval_system = RetrievalSystem()
251
+
252
+ # 初始化組件
253
+ self.initialize_system()
254
+
255
+ def initialize_system(self):
256
+ """初始化系統組件"""
257
+ print("正在初始化系統組件...")
258
+
259
+ # 加載數據
260
+ self.data_loader.load_dataset()
261
+ self.data_loader.load_schema()
262
+
263
+ # 初始化檢索系統
264
+ self.retrieval_system.compute_embeddings(self.data_loader.questions)
265
+
266
+ self.schema_context = self.data_loader.build_schema_context()
267
+ print("系統初始化完成")
268
+
269
+ def generate_sql(self, user_question: str) -> Tuple[str, str]:
270
+ """生成SQL查詢(主函數)"""
271
+ log_messages = [f"🕒 開始處理: {get_current_time()}"]
272
+
273
+ if not user_question or user_question.strip() == "":
274
+ return "請輸入您的問題。", "錯誤: 問題為空"
275
+
276
+ # 1. 嘗試檢索相似問題
277
+ if len(self.data_loader.questions) > 0:
278
+ hits = self.retrieval_system.retrieve_similar(user_question)
279
 
280
+ if hits:
281
+ best_hit = hits[0]
282
  similarity_score = best_hit['score']
283
+ similar_question = self.data_loader.questions[best_hit['corpus_id']]
284
 
285
+ log_messages.append(f"🔍 檢索到相似問題: '{similar_question}'")
286
+ log_messages.append(f"📊 相似度: {similarity_score:.3f}")
287
 
288
  if similarity_score > SIMILARITY_THRESHOLD:
289
+ sql_result = self.data_loader.sql_answers[best_hit['corpus_id']]
290
+ log_messages.append(f"相似度 > {SIMILARITY_THRESHOLD},直接返回預先SQL")
291
+
292
+ # 驗證SQL安全性
293
+ validation = validate_sql(sql_result)
294
+ if not validation["is_safe"]:
295
+ log_messages.append(f"⚠️ 安全警告: {', '.join(validation['issues'])}")
296
+
297
  return sql_result, "\n".join(log_messages)
298
  else:
299
+ log_messages.append(f"ℹ️ 相似度低於閾值 {SIMILARITY_THRESHOLD}")
300
+
301
+ # 2. LLM生成模式
302
+ log_messages.append("🤖 進入LLM生成模式...")
303
+
304
+ prompt = self.build_llm_prompt(user_question)
305
+ generated_sql = self.llm_client.call_llm_api(prompt)
306
+
307
+ if generated_sql:
308
+ # 清理和驗證生成的SQL
309
+ generated_sql = re.sub(r'^```sql|```$', '', generated_sql).strip()
310
+ validation = validate_sql(generated_sql)
311
+
312
+ if validation["valid"]:
313
+ log_messages.append("✅ LLM生成成功")
314
+ if validation["issues"]:
315
+ log_messages.append(f"ℹ️ 驗證提示: {', '.join(validation['issues'])}")
316
  else:
317
+ log_messages.append("⚠️ LLM生成可能存在问题")
318
 
319
+ return generated_sql, "\n".join(log_messages)
320
+ else:
321
+ # 3. 備用方案
322
+ log_messages.append("❌ 所有LLM模型都失敗,啟用備用方案")
323
+ backup_sql = self.generate_backup_sql(user_question)
324
+ return backup_sql, "\n".join(log_messages)
325
 
326
+ def build_llm_prompt(self, user_question: str) -> str:
327
+ """構建LLM提示詞"""
328
+ return f"""你是一個SQL專家。請根據以下數據庫結構生成SQL查詢。
329
+
330
+ {self.schema_context}
331
 
332
+ 請為以下問題生成準確的SQL查詢:
333
+ {user_question}
334
 
335
+ 要求:
336
+ 1. 只輸出SQL語句
337
+ 2. 不要任何解釋
338
+ 3. 使用正確的語法
339
 
340
+ SQL查詢:"""
341
 
342
+ def generate_backup_sql(self, user_question: str) -> str:
343
+ """生成備用SQL"""
344
+ user_question_lower = user_question.lower()
345
+
346
+ if any(kw in user_question_lower for kw in ['銷售', '業績', '金額', '收入']):
347
+ return "SELECT product_name, SUM(sales_amount) as total_sales FROM sales GROUP BY product_name ORDER BY total_sales DESC LIMIT 10;"
348
+ elif any(kw in user_question_lower for kw in ['客戶', '買家', '用戶']):
349
+ return "SELECT customer_name, COUNT(*) as order_count, SUM(order_amount) as total_spent FROM orders GROUP BY customer_name ORDER BY total_spent DESC;"
350
+ elif any(kw in user_question_lower for kw in ['時間', '日期', '最近', '月份']):
351
+ return "SELECT strftime('%Y-%m', order_date) as month, COUNT(*) as orders, SUM(order_amount) as revenue FROM orders WHERE order_date >= date('now', '-6 months') GROUP BY month ORDER BY month DESC;"
352
+ elif any(kw in user_question_lower for kw in ['產品', '商品', '項目']):
353
+ return "SELECT product_name, category, stock_quantity, price FROM products WHERE stock_quantity > 0 ORDER BY price DESC;"
354
+ else:
355
+ return "SELECT '請重試或提供更詳細的問題' AS status;"
356
+
357
+ # ==================== 初始化系統 ====================
358
+ print("正在初始化Text-to-SQL系統...")
359
+ text_to_sql_system = TextToSQLSystem(HF_TOKEN)
360
+
361
+ # ==================== Gradio界面 ====================
362
+ def process_query(user_question: str, test_query: bool = False) -> Tuple[str, str, str]:
363
+ """處理用戶查詢"""
364
+ sql_result, log_message = text_to_sql_system.generate_sql(user_question)
365
+
366
+ # SQL調試信息
367
+ debug_info = ""
368
+ validation = validate_sql(sql_result)
369
 
370
+ if not validation["valid"]:
371
+ debug_info = "❌ SQL驗證失敗:\n" + "\n".join(validation["issues"])
 
 
 
372
  else:
373
+ debug_info = "✅ SQL語法驗證通過"
 
374
 
375
+ if validation["issues"]:
376
+ debug_info += "\nℹ️ 提示: " + ", ".join(validation["issues"])
 
 
 
 
 
 
 
377
 
378
+ # 如果啟用測試功能
379
+ if test_query and DB_CONFIG["test_queries"]:
380
+ success, test_result = execute_test_query(sql_result)
381
+ debug_info += f"\n\n🔧 測試結果:\n{test_result}"
382
+
383
+ return sql_result, debug_info, log_message
384
 
385
+ # 創建界面
386
+ with gr.Blocks(
387
+ title="智能Text-to-SQL系統",
388
+ theme=gr.themes.Soft(),
389
+ css="""
390
+ .gradio-container { max-width: 1000px; margin: 0 auto; }
391
+ .success { color: green; }
392
+ .warning { color: orange; }
393
+ .error { color: red; }
394
+ """
395
+ ) as demo:
396
 
397
+ gr.Markdown("# 🚀 智能 Text-to-SQL 系統")
398
+ gr.Markdown("輸入自然語言問題,自動生成並驗證SQL查詢")
 
 
 
 
399
 
400
  with gr.Row():
401
+ with gr.Column(scale=3):
402
+ question_input = gr.Textbox(
403
+ label="📝 您的問題",
404
+ placeholder="例如:查詢2024年銷售額最高的產品",
405
+ lines=2,
406
+ max_lines=4
407
+ )
408
+
409
+ with gr.Row():
410
+ submit_btn = gr.Button("🚀 生成SQL", variant="primary")
411
+ test_btn = gr.Button("🔧 測試SQL", variant="secondary")
412
+ clear_btn = gr.Button("🗑️ 清除", variant="secondary")
413
 
414
  with gr.Row():
415
  sql_output = gr.Code(
416
+ label="📊 生成的SQL",
417
  language="sql",
418
+ lines=6,
419
+ interactive=True
420
+ )
421
+
422
+ with gr.Row():
423
+ debug_output = gr.Textbox(
424
+ label="🔍 SQL調試信息",
425
+ lines=4,
426
+ interactive=False
427
  )
428
 
429
  with gr.Row():
430
  log_output = gr.Textbox(
431
+ label="📋 執行日誌",
432
  lines=4,
433
  interactive=False
434
  )
435
 
436
+ # 示例問題
437
+ gr.Examples(
438
+ examples=[
439
+ "2024年銷售額最高的5個產品",
440
+ "最近30天每個客戶的訂單數量",
441
+ "庫存不足的商品列表",
442
+ "比較2023年和2024年的月度銷售額",
443
+ "付款不及時的客戶統計"
444
+ ],
445
+ inputs=question_input,
446
+ label="💡 示例問題"
447
+ )
448
+
449
+ # 事件處理
450
  submit_btn.click(
451
+ fn=lambda q: process_query(q, False),
452
  inputs=question_input,
453
+ outputs=[sql_output, debug_output, log_output]
454
  )
455
 
456
+ test_btn.click(
457
+ fn=lambda q: process_query(q, True),
458
+ inputs=question_input,
459
+ outputs=[sql_output, debug_output, log_output]
460
+ )
461
+
462
+ clear_btn.click(
463
+ fn=lambda: ["", "", ""],
464
+ inputs=[],
465
+ outputs=[sql_output, debug_output, log_output]
466
  )
467
 
468
+ # ==================== 啟動應用 ====================
 
469
  if __name__ == "__main__":
470
+ print("=" * 50)
471
+ print("🌐 啟動Gradio Web界面...")
472
+ print("📍 本地訪問: http://localhost:7860")
473
+ print("🔄 如果需要公網訪問,設置 share=True")
474
+ print("=" * 50)
475
+
476
+ demo.launch(
477
+ server_name="0.0.0.0",
478
+ server_port=7860,
479
+ share=False,
480
+ show_error=True
481
+ )