Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -14,14 +14,7 @@ import numpy as np
|
|
| 14 |
# ==================== 配置區 ====================
|
| 15 |
HF_TOKEN = os.environ.get("HF_TOKEN", None) # 建議從環境變數讀取
|
| 16 |
DATASET_REPO_ID = "Paul720810/Text-to-SQL-Softline"
|
| 17 |
-
SIMILARITY_THRESHOLD = 0.
|
| 18 |
-
|
| 19 |
-
# 多個備用LLM模型 (注意:在當前邏輯中並未使用)
|
| 20 |
-
LLM_MODELS = [
|
| 21 |
-
"https://api-inference.huggingface.co/models/gpt2",
|
| 22 |
-
"https://api-inference.huggingface.co/models/distilgpt2",
|
| 23 |
-
"https://api-inference.huggingface.co/models/microsoft/DialoGPT-small"
|
| 24 |
-
]
|
| 25 |
|
| 26 |
print("=" * 60)
|
| 27 |
print("🤖 智能 Text-to-SQL 系統啟動中...")
|
|
@@ -45,67 +38,48 @@ def validate_sql(sql_query: str) -> Dict:
|
|
| 45 |
security_issues = []
|
| 46 |
sql_upper = sql_clean.upper()
|
| 47 |
|
| 48 |
-
# 檢查危險操作
|
| 49 |
dangerous_keywords = ['DROP', 'DELETE', 'INSERT', 'UPDATE', 'ALTER', 'TRUNCATE', 'EXEC', 'EXECUTE']
|
| 50 |
for keyword in dangerous_keywords:
|
| 51 |
if f" {keyword} " in f" {sql_upper} ":
|
| 52 |
security_issues.append(f"危險操作: {keyword}")
|
| 53 |
|
| 54 |
-
# 檢查基本語法
|
| 55 |
if "SELECT" not in sql_upper:
|
| 56 |
security_issues.append("缺少SELECT")
|
| 57 |
-
|
| 58 |
if "FROM" not in sql_upper:
|
| 59 |
security_issues.append("缺少FROM")
|
| 60 |
|
| 61 |
is_valid = not security_issues
|
| 62 |
is_safe = all('危險' not in issue for issue in security_issues)
|
| 63 |
|
| 64 |
-
return {
|
| 65 |
-
"valid": is_valid,
|
| 66 |
-
"issues": security_issues,
|
| 67 |
-
"is_safe": is_safe,
|
| 68 |
-
"empty": False
|
| 69 |
-
}
|
| 70 |
|
| 71 |
def analyze_question_type(question: str) -> Dict:
|
| 72 |
-
"""
|
| 73 |
question_lower = question.lower()
|
| 74 |
|
| 75 |
analysis = {
|
| 76 |
"type": "unknown",
|
| 77 |
"keywords": [],
|
| 78 |
-
"has_count":
|
| 79 |
-
"has_date":
|
| 80 |
-
"has_group":
|
| 81 |
-
"
|
| 82 |
}
|
| 83 |
|
| 84 |
-
#
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
# 特殊檢測
|
| 100 |
-
analysis["has_count"] = "count" in analysis["keywords"]
|
| 101 |
-
analysis["has_date"] = "time" in analysis["keywords"]
|
| 102 |
-
analysis["has_group"] = any(word in question_lower for word in ["每", "各", "group", "每個"])
|
| 103 |
-
analysis["has_comparison"] = "comparison" in analysis["keywords"]
|
| 104 |
-
|
| 105 |
-
# 確定主要類型
|
| 106 |
-
if analysis["keywords"]:
|
| 107 |
-
analysis["type"] = analysis["keywords"][0]
|
| 108 |
-
|
| 109 |
return analysis
|
| 110 |
|
| 111 |
# ==================== 完整數據加載模塊 ====================
|
|
@@ -114,77 +88,47 @@ class CompleteDataLoader:
|
|
| 114 |
self.hf_token = hf_token
|
| 115 |
self.questions = []
|
| 116 |
self.sql_answers = []
|
| 117 |
-
self.sql_quality = []
|
| 118 |
self.schema_data = {}
|
| 119 |
|
| 120 |
def load_complete_dataset(self) -> bool:
|
| 121 |
-
"""加載完整數據集(包括空白SQL)"""
|
| 122 |
try:
|
| 123 |
print(f"[{get_current_time()}] 正在加載完整數據集 '{DATASET_REPO_ID}'...")
|
| 124 |
raw_dataset = load_dataset(DATASET_REPO_ID, token=self.hf_token)['train']
|
| 125 |
|
| 126 |
-
print("解析全部 messages 格式...")
|
| 127 |
-
total_count, empty_count, valid_count = 0, 0, 0
|
| 128 |
-
|
| 129 |
for item in raw_dataset:
|
| 130 |
try:
|
| 131 |
if 'messages' in item and len(item['messages']) >= 2:
|
| 132 |
user_content = item['messages'][0]['content']
|
| 133 |
assistant_content = item['messages'][1]['content']
|
| 134 |
|
| 135 |
-
# 提取問題
|
| 136 |
question_match = re.search(r'指令:\s*(.*?)(?:\n|$)', user_content)
|
| 137 |
question = question_match.group(1).strip() if question_match else user_content
|
| 138 |
|
| 139 |
-
# 提取SQL
|
| 140 |
sql_match = re.search(r'SQL查詢:\s*(.*?)(?:\n|$)', assistant_content, re.DOTALL)
|
| 141 |
-
if sql_match
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
sql_query = re.sub(r'```sql|```', '', sql_query).strip()
|
| 145 |
-
else:
|
| 146 |
-
sql_query = assistant_content
|
| 147 |
-
|
| 148 |
-
# 保存所有數據
|
| 149 |
self.questions.append(question)
|
| 150 |
self.sql_answers.append(sql_query)
|
| 151 |
-
|
| 152 |
-
# 評估SQL質量
|
| 153 |
-
validation = validate_sql(sql_query)
|
| 154 |
-
quality_score = 1.0 if validation["valid"] else 0.3
|
| 155 |
-
self.sql_quality.append(quality_score)
|
| 156 |
-
|
| 157 |
-
total_count += 1
|
| 158 |
-
if validation["empty"]:
|
| 159 |
-
empty_count += 1
|
| 160 |
-
if validation["valid"]:
|
| 161 |
-
valid_count += 1
|
| 162 |
except Exception:
|
| 163 |
continue
|
| 164 |
|
| 165 |
-
print(f"數據加載完成: 總數 {
|
| 166 |
return True
|
| 167 |
-
|
| 168 |
except Exception as e:
|
| 169 |
print(f"數據集加載失敗: {e}")
|
| 170 |
return False
|
| 171 |
|
| 172 |
def load_schema(self) -> bool:
|
| 173 |
-
"""加載數據庫Schema"""
|
| 174 |
try:
|
| 175 |
-
schema_file_path = hf_hub_download(
|
| 176 |
-
repo_id=DATASET_REPO_ID,
|
| 177 |
-
filename="sqlite_schema_FULL.json",
|
| 178 |
-
repo_type='dataset',
|
| 179 |
-
token=self.hf_token
|
| 180 |
-
)
|
| 181 |
with open(schema_file_path, 'r', encoding='utf-8') as f:
|
| 182 |
self.schema_data = json.load(f)
|
| 183 |
print("Schema加載成功")
|
| 184 |
return True
|
| 185 |
except Exception as e:
|
| 186 |
print(f"Schema加載失敗: {e}")
|
| 187 |
-
self.schema_data = {}
|
| 188 |
return False
|
| 189 |
|
| 190 |
# ==================== 檢索系統 ====================
|
|
@@ -197,19 +141,18 @@ class RetrievalSystem:
|
|
| 197 |
print(f"SentenceTransformer 模型加載失敗: {e}")
|
| 198 |
self.embedder = None
|
| 199 |
|
| 200 |
-
def compute_embeddings(self, questions: List[str])
|
| 201 |
if self.embedder and questions:
|
| 202 |
print(f"正在為 {len(questions)} 個問題計算向量...")
|
| 203 |
self.question_embeddings = self.embedder.encode(questions, convert_to_tensor=True, show_progress_bar=True)
|
| 204 |
print("向量計算完成")
|
| 205 |
|
| 206 |
-
def retrieve_similar(self, user_question: str, top_k: int =
|
| 207 |
-
if self.embedder is None or self.question_embeddings is None
|
| 208 |
-
return []
|
| 209 |
try:
|
| 210 |
question_embedding = self.embedder.encode(user_question, convert_to_tensor=True)
|
| 211 |
hits = util.semantic_search(question_embedding, self.question_embeddings, top_k=top_k)
|
| 212 |
-
return hits[0] if hits
|
| 213 |
except Exception as e:
|
| 214 |
print(f"檢索錯誤: {e}")
|
| 215 |
return []
|
|
@@ -223,138 +166,131 @@ class CompleteTextToSQLSystem:
|
|
| 223 |
self.initialize_system()
|
| 224 |
|
| 225 |
def initialize_system(self):
|
| 226 |
-
"""初始化系統組件"""
|
| 227 |
print("正在初始化完整數據系統...")
|
| 228 |
-
|
| 229 |
self.data_loader.load_complete_dataset()
|
| 230 |
self.data_loader.load_schema()
|
| 231 |
-
|
| 232 |
-
# 為所有問題計算向量
|
| 233 |
if self.data_loader.questions:
|
| 234 |
self.retrieval_system.compute_embeddings(self.data_loader.questions)
|
| 235 |
-
|
| 236 |
print(f"系統初始化完成,載入問題總數: {len(self.data_loader.questions)}")
|
| 237 |
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
if
|
| 242 |
-
return {}
|
| 243 |
-
|
| 244 |
-
tables = {}
|
| 245 |
-
for table_name, columns_list in self.data_loader.schema_data.items():
|
| 246 |
-
if isinstance(columns_list, list):
|
| 247 |
-
column_names = [col["name"] for col in columns_list if "name" in col]
|
| 248 |
-
tables[table_name] = column_names
|
| 249 |
-
|
| 250 |
-
return tables
|
| 251 |
-
|
| 252 |
-
def extract_number(self, text: str, default: int = 10) -> int:
|
| 253 |
-
"""從文字中提取數字"""
|
| 254 |
-
numbers = re.findall(r'\d+', text)
|
| 255 |
-
return int(numbers[0]) if numbers else default
|
| 256 |
|
| 257 |
def generate_sql_from_question(self, question: str, analysis: Dict) -> str:
|
| 258 |
-
"""
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
group_mapping = {'A': 'TA', 'B': 'TB', 'C': 'TC', 'D': 'TD'}
|
| 268 |
-
table_suffix = group_mapping.get(group, 'TA')
|
| 269 |
-
table_name = f"JobTimeline_{table_suffix}"
|
| 270 |
-
|
| 271 |
-
if "昨天" in question_lower:
|
| 272 |
-
return f"SELECT COUNT(*) as 完成數量 FROM {table_name} WHERE DATE(end_time) = DATE('now','-1 day');"
|
| 273 |
-
elif "每月" in question_lower:
|
| 274 |
-
year_match = re.search(r'(\d{4})年?', question_lower)
|
| 275 |
-
year = year_match.group(1) if year_match else datetime.now().strftime('%Y')
|
| 276 |
-
return f"""SELECT strftime('%Y-%m', end_time) as 月份, COUNT(*) as 完成數量 FROM {table_name} WHERE strftime('%Y', end_time) = '{year}' AND end_time IS NOT NULL GROUP BY strftime('%Y-%m', end_time) ORDER BY 月份;"""
|
| 277 |
-
return "SELECT strftime('%Y-%m', jt.end_time) as 月份, COUNT(*) as 完成數量 FROM JobTimeline jt WHERE jt.end_time IS NOT NULL GROUP BY strftime('%Y-%m', jt.end_time) ORDER BY 月份;"
|
| 278 |
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
|
| 297 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 298 |
|
| 299 |
-
|
| 300 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 301 |
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 310 |
|
| 311 |
def generate_sql(self, user_question: str) -> Tuple[str, str]:
|
| 312 |
-
"""主流程:生成SQL查詢"""
|
| 313 |
log_messages = [f"⏰ {get_current_time()} 開始處理"]
|
| 314 |
|
| 315 |
if not user_question or not user_question.strip():
|
| 316 |
return "請輸入您的問題。", "錯誤: 問題為空"
|
| 317 |
|
| 318 |
# 1. 檢索最相似的問題
|
| 319 |
-
|
| 320 |
-
|
|
|
|
|
|
|
|
|
|
| 321 |
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
corpus_id = best_hit['corpus_id']
|
| 326 |
similar_question = self.data_loader.questions[corpus_id]
|
| 327 |
original_sql = self.data_loader.sql_answers[corpus_id]
|
| 328 |
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
repaired_sql = self.repair_empty_sql(original_sql, user_question, similar_question)
|
| 334 |
-
log_messages.append(f"✅ 相似度高於閾值 {SIMILARITY_THRESHOLD},採用檢索結果。")
|
| 335 |
-
return repaired_sql, "\n".join(log_messages)
|
| 336 |
else:
|
| 337 |
-
log_messages.append(f"
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
|
|
|
|
|
|
| 341 |
analysis = analyze_question_type(user_question)
|
| 342 |
-
|
|
|
|
| 343 |
|
| 344 |
-
log_messages.append(f"📋
|
| 345 |
log_messages.append("✅ 智能生成完成。")
|
| 346 |
-
|
| 347 |
return intelligent_sql, "\n".join(log_messages)
|
| 348 |
|
| 349 |
# ==================== 初始化系統 ====================
|
| 350 |
-
print("準備初始化 Text-to-SQL 系統...")
|
| 351 |
-
# 檢查 HF_TOKEN 是否存在
|
| 352 |
if HF_TOKEN is None:
|
| 353 |
-
print("\n" + "="*60)
|
| 354 |
-
print("⚠️ 警告: Hugging Face Token 未設置。")
|
| 355 |
-
print("請在環境變數中設定 HF_TOKEN 才能從私人數據集下載資料。")
|
| 356 |
-
print("="*60 + "\n")
|
| 357 |
-
# 這裡可以選擇退出或繼續,但下載會失敗
|
| 358 |
text_to_sql_system = None
|
| 359 |
else:
|
| 360 |
text_to_sql_system = CompleteTextToSQLSystem(HF_TOKEN)
|
|
@@ -366,51 +302,35 @@ def process_query(user_question: str) -> Tuple[str, str, str]:
|
|
| 366 |
return "系統未初始化", error_msg, error_msg
|
| 367 |
|
| 368 |
sql_result, log_message = text_to_sql_system.generate_sql(user_question)
|
| 369 |
-
return sql_result, "✅
|
| 370 |
|
| 371 |
with gr.Blocks(title="智慧Text-to-SQL系統", theme=gr.themes.Soft()) as demo:
|
| 372 |
-
gr.Markdown("# 🚀 智慧 Text-to-SQL 系統")
|
| 373 |
-
gr.Markdown("📊 **模式**:
|
| 374 |
|
| 375 |
with gr.Row():
|
| 376 |
-
question_input = gr.Textbox(
|
| 377 |
-
label="📝 請在此輸入您的問題",
|
| 378 |
-
placeholder="例如:2023年每月完成多少份報告? 或 哪個客戶的訂單總金額最高?",
|
| 379 |
-
lines=3,
|
| 380 |
-
scale=4
|
| 381 |
-
)
|
| 382 |
submit_btn = gr.Button("🚀 生成SQL", variant="primary", scale=1)
|
| 383 |
|
| 384 |
with gr.Accordion("🔍 結果與日誌", open=True):
|
| 385 |
-
sql_output = gr.Code(
|
| 386 |
-
label="📊 生成的SQL查詢",
|
| 387 |
-
language="sql",
|
| 388 |
-
lines=8
|
| 389 |
-
)
|
| 390 |
status_output = gr.Textbox(label="🔍 執行狀態", interactive=False)
|
| 391 |
-
log_output = gr.Textbox(label="📋 詳細日誌", lines=
|
| 392 |
|
| 393 |
-
# 預設範例
|
| 394 |
gr.Examples(
|
| 395 |
examples=[
|
| 396 |
-
"
|
| 397 |
-
"
|
| 398 |
-
"
|
| 399 |
-
"
|
| 400 |
-
"統計所有評級的分佈"
|
| 401 |
],
|
| 402 |
inputs=question_input
|
| 403 |
)
|
| 404 |
|
| 405 |
-
submit_btn.click(
|
| 406 |
-
process_query,
|
| 407 |
-
inputs=question_input,
|
| 408 |
-
outputs=[sql_output, status_output, log_output]
|
| 409 |
-
)
|
| 410 |
-
|
| 411 |
if __name__ == "__main__":
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
print("無法啟動 Gradio,因為系統初始化失敗。")
|
| 415 |
-
else:
|
| 416 |
demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
# ==================== Configuration ====================
# Hugging Face access token, read from the environment (None when unset).
HF_TOKEN = os.getenv("HF_TOKEN")
# Dataset repository used for the question/SQL pairs and the schema file.
DATASET_REPO_ID = "Paul720810/Text-to-SQL-Softline"
# Raised moderately so that a retrieved question is more likely to share
# the user's intent.
SIMILARITY_THRESHOLD = 0.65
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
print("=" * 60)
|
| 20 |
print("🤖 智能 Text-to-SQL 系統啟動中...")
|
|
|
|
| 38 |
security_issues = []
|
| 39 |
sql_upper = sql_clean.upper()
|
| 40 |
|
|
|
|
| 41 |
dangerous_keywords = ['DROP', 'DELETE', 'INSERT', 'UPDATE', 'ALTER', 'TRUNCATE', 'EXEC', 'EXECUTE']
|
| 42 |
for keyword in dangerous_keywords:
|
| 43 |
if f" {keyword} " in f" {sql_upper} ":
|
| 44 |
security_issues.append(f"危險操作: {keyword}")
|
| 45 |
|
|
|
|
| 46 |
if "SELECT" not in sql_upper:
|
| 47 |
security_issues.append("缺少SELECT")
|
|
|
|
| 48 |
if "FROM" not in sql_upper:
|
| 49 |
security_issues.append("缺少FROM")
|
| 50 |
|
| 51 |
is_valid = not security_issues
|
| 52 |
is_safe = all('危險' not in issue for issue in security_issues)
|
| 53 |
|
| 54 |
+
return {"valid": is_valid, "issues": security_issues, "is_safe": is_safe, "empty": False}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
def analyze_question_type(question: str) -> Dict:
    """Classify a question by keyword matching.

    Args:
        question: The user's question. ``None`` (or any falsy value) is
            treated as an empty string instead of raising.

    Returns:
        A dict with the keys ``"type"``, ``"keywords"`` (always an empty
        list; nothing in this function fills it), ``"has_count"``,
        ``"has_date"``, ``"has_group"`` and ``"specific_intent"``.
        ``"specific_intent"`` defaults to ``"general_query"`` and ``"type"``
        to ``"unknown"`` when no rule matches.
    """
    question_lower = (question or "").lower()

    def mentions(*words: str) -> bool:
        """Return True when any of *words* occurs in the question."""
        return any(word in question_lower for word in words)

    analysis = {
        "type": "unknown",
        "keywords": [],
        "has_count": mentions("多少", "幾個", "數量"),
        "has_date": mentions("時間", "日期", "月份", "年"),
        "has_group": mentions("每", "各", "分組"),
        "specific_intent": "general_query",
    }

    # Ordered rules: (intent, type, subject words, qualifier words).
    # A rule fires when at least one subject word AND at least one qualifier
    # word are present; the first matching rule wins, so order matters.
    intent_rules = (
        ("monthly_completion_count", "time_series",
         ("每月",), ("完成", "報告", "工作單")),
        ("rating_distribution", "statistics",
         ("評級", "pass", "fail"), ("統計", "分佈", "多少")),
        ("amount_ranking", "ranking",
         ("金額",), ("最高", "top", "排名")),
        ("company_statistics", "statistics",
         ("公司", "客戶", "申請方"), ("統計", "數量", "排名")),
    )
    for intent, question_type, subjects, qualifiers in intent_rules:
        if mentions(*subjects) and mentions(*qualifiers):
            analysis["specific_intent"] = intent
            analysis["type"] = question_type
            break

    return analysis
|
| 84 |
|
| 85 |
# ==================== 完整數據加載模塊 ====================
|
|
|
|
| 88 |
self.hf_token = hf_token
|
| 89 |
self.questions = []
|
| 90 |
self.sql_answers = []
|
| 91 |
+
self.sql_quality = []
|
| 92 |
self.schema_data = {}
|
| 93 |
|
| 94 |
def load_complete_dataset(self) -> bool:
    """Load question/SQL pairs from the 'train' split of the dataset.

    Each record is expected to carry a ``messages`` list whose first entry
    holds the user text and whose second entry holds the assistant text.
    The question is the text after "指令:" on its line (or the whole user
    text when that marker is absent). Parsed pairs are appended to
    ``self.questions`` and ``self.sql_answers`` in lockstep.

    Returns:
        True when the dataset was loaded and iterated, False when loading
        or iteration raised.
    """
    try:
        print(f"[{get_current_time()}] 正在加載完整數據集 '{DATASET_REPO_ID}'...")
        raw_dataset = load_dataset(DATASET_REPO_ID, token=self.hf_token)['train']

        skipped = 0
        for item in raw_dataset:
            try:
                # Records without at least two messages are ignored.
                if 'messages' not in item or len(item['messages']) < 2:
                    continue

                user_content = item['messages'][0]['content']
                assistant_content = item['messages'][1]['content']

                question_match = re.search(r'指令:\s*(.*?)(?:\n|$)', user_content)
                question = question_match.group(1).strip() if question_match else user_content

                # Prefer the contents of a fenced code block. The marker
                # regex below stops at the first newline, so for
                # "SQL查詢:\n```sql\nSELECT ..." it would capture only the
                # opening fence and yield an empty SQL string.
                # NOTE(review): assumes at most one fenced block per answer
                # — confirm against the dataset.
                fenced = re.search(r'```(?:sql)?\s*(.*?)```', assistant_content, re.DOTALL)
                if fenced:
                    sql_query = fenced.group(1).strip()
                else:
                    sql_match = re.search(r'SQL查詢:\s*(.*?)(?:\n|$)', assistant_content, re.DOTALL)
                    sql_query = sql_match.group(1).strip() if sql_match else assistant_content
                    sql_query = re.sub(r'```sql|```', '', sql_query).strip()

                self.questions.append(question)
                self.sql_answers.append(sql_query)
            except Exception:
                # Best effort: a malformed record must not abort the load,
                # but it is counted so the loss is visible in the log.
                skipped += 1
                continue

        print(f"數據加載完成: 總數 {len(self.questions)}")
        if skipped:
            print(f"已略過 {skipped} 筆無法解析的資料")
        return True
    except Exception as e:
        print(f"數據集加載失敗: {e}")
        return False
|
| 122 |
|
| 123 |
def load_schema(self) -> bool:
    """Download ``sqlite_schema_FULL.json`` from the dataset repo and parse it.

    On success the parsed JSON is stored in ``self.schema_data``.

    Returns:
        True on success, False when the download, file read or JSON parsing
        raised (the error is printed).
    """
    try:
        local_schema_path = hf_hub_download(
            repo_id=DATASET_REPO_ID,
            filename="sqlite_schema_FULL.json",
            repo_type='dataset',
            token=self.hf_token,
        )
        with open(local_schema_path, 'r', encoding='utf-8') as schema_file:
            parsed_schema = json.load(schema_file)
        self.schema_data = parsed_schema
        print("Schema加載成功")
        return True
    except Exception as e:
        print(f"Schema加載失敗: {e}")
        return False
|
| 133 |
|
| 134 |
# ==================== 檢索系統 ====================
|
|
|
|
| 141 |
print(f"SentenceTransformer 模型加載失敗: {e}")
|
| 142 |
self.embedder = None
|
| 143 |
|
| 144 |
+
def compute_embeddings(self, questions: List[str]):
    """Encode *questions* and cache the result in ``self.question_embeddings``.

    Does nothing when no embedder is available or *questions* is empty.

    Args:
        questions: The corpus questions to encode.
    """
    # Compare against None explicitly (as retrieve_similar does) rather than
    # relying on truthiness: an embedder object that defines __len__ could
    # evaluate as False and be skipped by mistake.
    if self.embedder is None or not questions:
        return

    print(f"正在為 {len(questions)} 個問題計算向量...")
    self.question_embeddings = self.embedder.encode(
        questions, convert_to_tensor=True, show_progress_bar=True
    )
    print("向量計算完成")
|
| 149 |
|
| 150 |
+
def retrieve_similar(self, user_question: str, top_k: int = 1) -> List[Dict]:
    """Return the ``top_k`` corpus hits most similar to *user_question*.

    Args:
        user_question: The question to look up.
        top_k: Number of hits requested from the semantic search.

    Returns:
        The hit list for the single query, or an empty list when the
        embedder or the corpus embeddings are missing, when the search
        returns nothing, or when any step raises (the error is printed).
    """
    model = self.embedder
    corpus_embeddings = self.question_embeddings
    if model is None or corpus_embeddings is None:
        return []

    try:
        query_vector = model.encode(user_question, convert_to_tensor=True)
        per_query_hits = util.semantic_search(query_vector, corpus_embeddings, top_k=top_k)
        if not per_query_hits:
            return []
        # Only one query was submitted, so only the first hit list matters.
        return per_query_hits[0]
    except Exception as e:
        print(f"檢索錯誤: {e}")
        return []
|
|
|
|
| 166 |
self.initialize_system()
|
| 167 |
|
| 168 |
def initialize_system(self):
    """Load the dataset and schema, then embed the loaded questions.

    The loaders' return values are not checked; embeddings are computed
    only when at least one question was loaded.
    """
    print("正在初始化完整數據系統...")

    loader = self.data_loader
    loader.load_complete_dataset()
    loader.load_schema()

    loaded_questions = loader.questions
    if loaded_questions:
        self.retrieval_system.compute_embeddings(loaded_questions)

    print(f"系統初始化完成,載入問題總數: {len(self.data_loader.questions)}")
|
| 175 |
|
| 176 |
+
def extract_year(self, text: str) -> str:
    """Return the first stand-alone four-digit number in *text* as a year.

    Args:
        text: Free text, typically the user's question.

    Returns:
        The four digits as a string, or the current year (``'%Y'``) when
        *text* contains no stand-alone four-digit number.
    """
    # The look-arounds stop a longer digit run (e.g. a job number such as
    # 12345678) from being mistaken for a year by its first four digits.
    year_match = re.search(r'(?<!\d)(\d{4})(?!\d)', text)
    return year_match.group(1) if year_match else datetime.now().strftime('%Y')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
|
| 181 |
def generate_sql_from_question(self, question: str, analysis: Dict) -> str:
    """Return the generic fallback query.

    Used when no specific intent was recognised. Both arguments are
    accepted for interface compatibility but are not used: the same
    template is returned for every question.
    """
    template_lines = [
        "-- 通用查詢範本",
        "SELECT",
        "JobNo as 工作單號,",
        "ApplicantName as 申請方,",
        "OverallRating as 評級",
        "FROM TSR53SampleDescription",
        "LIMIT 20;",
    ]
    return "\n".join(template_lines)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
|
| 192 |
+
def intelligent_repair_sql(self, user_question: str, similar_question: str) -> str:
    """Build a SQL query from the intent detected in *user_question*.

    Args:
        user_question: The question whose intent selects the template.
        similar_question: The retrieved question, quoted in the leading
            SQL comment only; it does not influence the query itself.

    Returns:
        A SQL string starting with a one-line ``--`` comment that names
        *similar_question*, followed by the template for the detected
        intent, or by the generic fallback query when the intent is not
        one of the four handled here.
    """
    analysis = analyze_question_type(user_question)
    intent = analysis["specific_intent"]

    # Collapse all whitespace, newlines included, before quoting the text in
    # a "--" comment: a line break would end the comment and turn the rest
    # of the question into executable SQL.
    quoted_question = " ".join(str(similar_question).split())
    comment = f"-- 根據類似問題 '{quoted_question}' (原SQL無效) 進行智能修復\n"

    if intent == "monthly_completion_count":
        year = self.extract_year(user_question)
        query_lines = (
            f"-- 查詢 {year} 年每月完成的工作單數量",
            "SELECT",
            "strftime('%Y-%m', jt.end_time) as 月份,",
            "COUNT(*) as 完成數量",
            "FROM JobTimeline jt",
            f"WHERE strftime('%Y', jt.end_time) = '{year}' AND jt.end_time IS NOT NULL",
            "GROUP BY strftime('%Y-%m', jt.end_time)",
            "ORDER BY 月份;",
        )
    elif intent == "rating_distribution":
        query_lines = (
            "-- 查詢評級分佈統計",
            "SELECT",
            "OverallRating as 評級,",
            "COUNT(*) as 數量,",
            "ROUND(COUNT(*) * 100.0 / (SELECT COUNT(*) FROM TSR53SampleDescription WHERE OverallRating IS NOT NULL), 2) as 百分比",
            "FROM TSR53SampleDescription",
            "WHERE OverallRating IS NOT NULL",
            "GROUP BY OverallRating",
            "ORDER BY 數量 DESC;",
        )
    elif intent == "amount_ranking":
        query_lines = (
            "-- 查詢工作單金額排名",
            "WITH JobTotalAmount AS (",
            "SELECT JobNo, SUM(LocalAmount) AS TotalAmount",
            "FROM (SELECT DISTINCT JobNo, InvoiceCreditNoteNo, LocalAmount FROM TSR53Invoice WHERE LocalAmount IS NOT NULL)",
            "GROUP BY JobNo",
            ")",
            "SELECT",
            "jta.JobNo as 工作單號,",
            "sd.ApplicantName as 申請方,",
            "jta.TotalAmount as 總金額",
            "FROM JobTotalAmount jta",
            "JOIN TSR53SampleDescription sd ON sd.JobNo = jta.JobNo",
            "ORDER BY jta.TotalAmount DESC",
            "LIMIT 10;",
        )
    elif intent == "company_statistics":
        query_lines = (
            "-- 查詢申請方工作單統計",
            "SELECT",
            "ApplicantName as 申請方名稱,",
            "COUNT(*) as 工作單數量",
            "FROM TSR53SampleDescription",
            "WHERE ApplicantName IS NOT NULL",
            "GROUP BY ApplicantName",
            "ORDER BY 工作單數量 DESC",
            "LIMIT 20;",
        )
    else:
        # No specific intent recognised: fall back to the generic template.
        return comment + self.generate_sql_from_question(user_question, analysis)

    return comment + "\n".join(query_lines)
|
| 250 |
|
| 251 |
def generate_sql(self, user_question: str) -> Tuple[str, str]:
    """Main flow: produce a SQL query and a processing log for a question.

    Steps:
      1. Reject an empty or whitespace-only question.
      2. Retrieve the most similar stored question.
      3. If its score exceeds ``SIMILARITY_THRESHOLD``: return the stored
         SQL when it validates as valid and safe, otherwise return an
         intent-based replacement.
      4. Otherwise generate SQL from the detected intent alone.

    Returns:
        ``(sql_or_message, log_text)`` where ``log_text`` is the collected
        log lines joined with newlines.
    """
    log_messages = [f"⏰ {get_current_time()} 開始處理"]

    if not (user_question and user_question.strip()):
        return "請輸入您的問題。", "錯誤: 問題為空"

    def finish(sql_text: str) -> Tuple[str, str]:
        """Pair *sql_text* with the log collected so far."""
        return sql_text, "\n".join(log_messages)

    # 1. Retrieve the most similar stored question.
    hits = self.retrieval_system.retrieve_similar(user_question)
    if hits:
        best_hit = hits[0]
        similarity_score = best_hit['score']
        log_messages.append(f"🔍 檢索到最相似問題 (相似度: {similarity_score:.3f})")

        if similarity_score > SIMILARITY_THRESHOLD:
            corpus_id = best_hit['corpus_id']
            similar_question = self.data_loader.questions[corpus_id]
            original_sql = self.data_loader.sql_answers[corpus_id]

            validation = validate_sql(original_sql)
            if not (validation["valid"] and validation["is_safe"]):
                log_messages.append(f"⚠️ 相似度高,但原SQL無效 ({', '.join(validation['issues'])})。")
                log_messages.append("🛠️ 啟用智能修復...")
                return finish(self.intelligent_repair_sql(user_question, similar_question))

            log_messages.append("✅ 相似度高,且原SQL有效,直接採用。")
            return finish(original_sql)

    # 2. No hit above the threshold: generate from the question's intent.
    log_messages.append("🤖 未找到高相似度或有效的範本,根據問題直接生成。")
    analysis = analyze_question_type(user_question)
    # The repair routine is itself an intent-driven generator, so it is
    # reused here with a placeholder in place of a similar question.
    intelligent_sql = self.intelligent_repair_sql(user_question, "無相似問題")

    log_messages.append(f"📋 問題意圖分析: {analysis['specific_intent']}")
    log_messages.append("✅ 智能生成完成。")
    return finish(intelligent_sql)
|
| 290 |
|
| 291 |
# ==================== 初始化系統 ====================
|
|
|
|
|
|
|
| 292 |
# Build the system only when a token is configured; otherwise print a
# warning banner and leave the global as None.
if HF_TOKEN is not None:
    text_to_sql_system = CompleteTextToSQLSystem(HF_TOKEN)
else:
    banner = "=" * 60
    print("\n" + banner + "\n⚠️ 警告: Hugging Face Token 未設置。\n" + banner + "\n")
    text_to_sql_system = None
|
|
|
|
| 302 |
return "系統未初始化", error_msg, error_msg
|
| 303 |
|
| 304 |
sql_result, log_message = text_to_sql_system.generate_sql(user_question)
|
| 305 |
+
return sql_result, "✅ 處理完成", log_message
|
| 306 |
|
| 307 |
# Gradio UI: a question box with a submit button, an accordion holding the
# generated SQL, a status line and the detailed log, plus example questions.
with gr.Blocks(title="智慧Text-to-SQL系統", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🚀 智慧 Text-to-SQL 系統 (進階修復版)")
    gr.Markdown("📊 **模式**: 結合「檢索驗證」與「意圖導向生成」,即使資料庫範本有誤也能提供準確查詢。")

    with gr.Row():
        question_input = gr.Textbox(label="📝 請在此輸入您的問題", placeholder="例如:2024年每月完成多少份報告?", lines=3, scale=4)
        submit_btn = gr.Button("🚀 生成SQL", variant="primary", scale=1)

    with gr.Accordion("🔍 結果與日誌", open=True):
        sql_output = gr.Code(label="📊 生成的SQL查詢", language="sql", lines=10)
        status_output = gr.Textbox(label="🔍 執行狀態", interactive=False)
        log_output = gr.Textbox(label="📋 詳細日誌", lines=6, interactive=False)

    gr.Examples(
        examples=[
            "2023 年每月完成多少份報告?",
            "統計一下各種評級的分佈",
            "找出總金額最高的5筆訂單來自哪個申請方",
            "哪個客戶的工作單數量最多?",
        ],
        inputs=question_input
    )

    # Wire the button to the handler. Without this registration a click on
    # "生成SQL" does nothing. process_query returns (sql, status, log), which
    # map onto the three output components in order.
    submit_btn.click(
        process_query,
        inputs=question_input,
        outputs=[sql_output, status_output, log_output]
    )
|
| 329 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
if __name__ == "__main__":
    # Launch the web UI only when the system object exists (it is None when
    # no Hugging Face token was configured).
    if not text_to_sql_system:
        print("無法啟動 Gradio,因為系統初始化失敗。")
    else:
        print("Gradio 介面啟動中...")
        demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
|
| 336 |
+
|