Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
def validate_sql(sql_query: str) -> Dict:
    """Check a SQL string for read-only safety and minimal SELECT structure.

    Args:
        sql_query: The SQL text to inspect.

    Returns:
        A dict with four keys:
            "valid":   True only when no issue at all was recorded.
            "issues":  List of human-readable issue strings.
            "is_safe": True when none of the issues is a dangerous-operation
                       issue (missing SELECT/FROM alone does not make the
                       query unsafe).
            "empty":   True when the input was blank or shorter than 5
                       characters after stripping.
    """
    if not sql_query or not sql_query.strip():
        return {"valid": False, "issues": ["SQL語句為空"], "is_safe": False, "empty": True}

    sql_clean = sql_query.strip()
    if len(sql_clean) < 5:
        return {"valid": False, "issues": ["SQL過短"], "is_safe": False, "empty": True}

    security_issues = []
    sql_upper = sql_clean.upper()

    # Match keywords on word boundaries rather than on surrounding spaces, so
    # that a keyword adjacent to a newline, tab, semicolon or parenthesis
    # (e.g. ";\nDROP\tTABLE t") is still detected. Identifiers that merely
    # contain a keyword (LAST_UPDATE, EXECUTED) are not flagged because "_" and
    # letters are word characters. This is deliberately conservative: a keyword
    # inside a quoted string literal is flagged as well.
    dangerous_keywords = ['DROP', 'DELETE', 'INSERT', 'UPDATE', 'ALTER', 'TRUNCATE', 'EXEC', 'EXECUTE']
    for keyword in dangerous_keywords:
        if re.search(rf"\b{keyword}\b", sql_upper):
            security_issues.append(f"危險操作: {keyword}")

    # Structural sanity: a usable query must contain SELECT and FROM as words.
    if not re.search(r"\bSELECT\b", sql_upper):
        security_issues.append("缺少SELECT")
    if not re.search(r"\bFROM\b", sql_upper):
        security_issues.append("缺少FROM")

    is_valid = not security_issues
    is_safe = all('危險' not in issue for issue in security_issues)

    return {"valid": is_valid, "issues": security_issues, "is_safe": is_safe, "empty": False}
|
| 55 |
|
| 56 |
def analyze_question_type(question: str) -> Dict:
|
| 57 |
"""增強的問題分析 - 更精確的意圖識別"""
|
| 58 |
question_lower = question.lower()
|
| 59 |
-
|
| 60 |
analysis = {
|
| 61 |
"type": "unknown",
|
| 62 |
"keywords": [],
|
|
@@ -65,7 +65,7 @@ def analyze_question_type(question: str) -> Dict:
|
|
| 65 |
"has_group": "每" in question_lower or "各" in question_lower or "分組" in question_lower,
|
| 66 |
"specific_intent": "general_query" # 新增:具體意圖,預設為通用查詢
|
| 67 |
}
|
| 68 |
-
|
| 69 |
# **更精確的意圖識別**
|
| 70 |
if "每月" in question_lower and ("完成" in question_lower or "報告" in question_lower or "工作單" in question_lower):
|
| 71 |
analysis["specific_intent"] = "monthly_completion_count"
|
|
@@ -79,7 +79,7 @@ def analyze_question_type(question: str) -> Dict:
|
|
| 79 |
elif ("公司" in question_lower or "客戶" in question_lower or "申請方" in question_lower) and ("統計" in question_lower or "數量" in question_lower or "排名" in question_lower):
|
| 80 |
analysis["specific_intent"] = "company_statistics"
|
| 81 |
analysis["type"] = "statistics"
|
| 82 |
-
|
| 83 |
return analysis
|
| 84 |
|
| 85 |
# ==================== 完整數據加載模塊 ====================
|
|
@@ -90,36 +90,42 @@ class CompleteDataLoader:
|
|
| 90 |
self.sql_answers = []
|
| 91 |
self.sql_quality = []
|
| 92 |
self.schema_data = {}
|
| 93 |
-
|
| 94 |
def load_complete_dataset(self) -> bool:
|
| 95 |
try:
|
| 96 |
print(f"[{get_current_time()}] 正在加載完整數據集 '{DATASET_REPO_ID}'...")
|
| 97 |
raw_dataset = load_dataset(DATASET_REPO_ID, token=self.hf_token)['train']
|
| 98 |
-
|
| 99 |
-
|
|
|
|
|
|
|
|
|
|
| 100 |
try:
|
| 101 |
if 'messages' in item and len(item['messages']) >= 2:
|
| 102 |
user_content = item['messages'][0]['content']
|
| 103 |
assistant_content = item['messages'][1]['content']
|
| 104 |
-
|
| 105 |
question_match = re.search(r'指令:\s*(.*?)(?:\n|$)', user_content)
|
| 106 |
question = question_match.group(1).strip() if question_match else user_content
|
| 107 |
-
|
| 108 |
sql_match = re.search(r'SQL查詢:\s*(.*?)(?:\n|$)', assistant_content, re.DOTALL)
|
| 109 |
sql_query = sql_match.group(1).strip() if sql_match else assistant_content
|
| 110 |
sql_query = re.sub(r'```sql|```', '', sql_query).strip()
|
| 111 |
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
|
|
|
|
|
|
|
|
|
| 115 |
continue
|
| 116 |
-
|
| 117 |
-
print(f"數據加載完成:
|
| 118 |
-
return
|
| 119 |
except Exception as e:
|
| 120 |
print(f"數據集加載失敗: {e}")
|
| 121 |
return False
|
| 122 |
-
|
| 123 |
def load_schema(self) -> bool:
|
| 124 |
try:
|
| 125 |
schema_file_path = hf_hub_download(repo_id=DATASET_REPO_ID, filename="sqlite_schema_FULL.json", repo_type='dataset', token=self.hf_token)
|
|
@@ -146,7 +152,7 @@ class RetrievalSystem:
|
|
| 146 |
print(f"正在為 {len(questions)} 個問題計算向量...")
|
| 147 |
self.question_embeddings = self.embedder.encode(questions, convert_to_tensor=True, show_progress_bar=True)
|
| 148 |
print("向量計算完成")
|
| 149 |
-
|
| 150 |
def retrieve_similar(self, user_question: str, top_k: int = 1) -> List[Dict]:
|
| 151 |
if self.embedder is None or self.question_embeddings is None: return []
|
| 152 |
try:
|
|
@@ -164,7 +170,7 @@ class CompleteTextToSQLSystem:
|
|
| 164 |
self.data_loader = CompleteDataLoader(hf_token)
|
| 165 |
self.retrieval_system = RetrievalSystem()
|
| 166 |
self.initialize_system()
|
| 167 |
-
|
| 168 |
def initialize_system(self):
|
| 169 |
print("正在初始化完整數據系統...")
|
| 170 |
self.data_loader.load_complete_dataset()
|
|
@@ -182,110 +188,135 @@ class CompleteTextToSQLSystem:
|
|
| 182 |
"""通用SQL生成器 (作為最終備用)"""
|
| 183 |
# 此函數現在作為無法識別具體意圖時的通用後備方案
|
| 184 |
return f"""-- 通用查詢範本
|
| 185 |
-
SELECT
|
| 186 |
-
JobNo as 工作單號,
|
| 187 |
-
ApplicantName as 申請方,
|
| 188 |
-
OverallRating as 評級
|
| 189 |
-
FROM TSR53SampleDescription
|
| 190 |
LIMIT 20;"""
|
| 191 |
|
| 192 |
def intelligent_repair_sql(self, user_question: str, similar_question: str) -> str:
|
| 193 |
"""智能修復SQL - 基於當前使用者問題的意圖"""
|
| 194 |
analysis = analyze_question_type(user_question)
|
| 195 |
intent = analysis["specific_intent"]
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
|
|
|
|
|
|
|
|
|
| 199 |
if intent == "monthly_completion_count":
|
| 200 |
year = self.extract_year(user_question)
|
| 201 |
return comment + f"""-- 查詢 {year} 年每月完成的工作單數量
|
| 202 |
-
SELECT
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
FROM JobTimeline jt
|
| 206 |
-
WHERE strftime('%Y', jt.
|
| 207 |
-
|
| 208 |
-
|
|
|
|
| 209 |
|
| 210 |
elif intent == "rating_distribution":
|
| 211 |
return comment + """-- 查詢評級分佈統計
|
| 212 |
-
SELECT
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
|
| 221 |
elif intent == "amount_ranking":
|
| 222 |
return comment + """-- 查詢工作單金額排名
|
| 223 |
-
WITH JobTotalAmount AS (
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
|
| 237 |
elif intent == "company_statistics":
|
| 238 |
return comment + """-- 查詢申請方工作單統計
|
| 239 |
-
SELECT
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
FROM TSR53SampleDescription
|
| 243 |
-
WHERE ApplicantName IS NOT NULL
|
| 244 |
-
GROUP BY ApplicantName
|
| 245 |
-
ORDER BY 工作單數量 DESC
|
| 246 |
-
LIMIT 20;"""
|
| 247 |
-
|
| 248 |
-
#
|
| 249 |
-
return comment +
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
|
| 251 |
def generate_sql(self, user_question: str) -> Tuple[str, str]:
|
| 252 |
"""主流程:生成SQL查詢 (改進版本)"""
|
| 253 |
-
log_messages = [f"⏰ {get_current_time()}
|
| 254 |
-
|
| 255 |
if not user_question or not user_question.strip():
|
| 256 |
-
return "
|
| 257 |
-
|
| 258 |
-
# 1.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 259 |
hits = self.retrieval_system.retrieve_similar(user_question)
|
| 260 |
-
|
| 261 |
if hits:
|
| 262 |
best_hit = hits[0]
|
| 263 |
similarity_score = best_hit['score']
|
| 264 |
-
|
| 265 |
-
|
|
|
|
|
|
|
| 266 |
|
| 267 |
if similarity_score > SIMILARITY_THRESHOLD:
|
| 268 |
-
corpus_id = best_hit['corpus_id']
|
| 269 |
-
similar_question = self.data_loader.questions[corpus_id]
|
| 270 |
original_sql = self.data_loader.sql_answers[corpus_id]
|
| 271 |
-
|
| 272 |
validation = validate_sql(original_sql)
|
|
|
|
| 273 |
if validation["valid"] and validation["is_safe"]:
|
| 274 |
-
log_messages.append("✅
|
| 275 |
return original_sql, "\n".join(log_messages)
|
| 276 |
else:
|
| 277 |
-
log_messages.append(f"⚠️
|
| 278 |
log_messages.append("🛠️ 啟用智能修復...")
|
| 279 |
repaired_sql = self.intelligent_repair_sql(user_question, similar_question)
|
|
|
|
| 280 |
return repaired_sql, "\n".join(log_messages)
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
intelligent_sql = self.intelligent_repair_sql(user_question, "無相似問題")
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
log_messages.append("✅ 智能生成完成。")
|
| 289 |
return intelligent_sql, "\n".join(log_messages)
|
| 290 |
|
| 291 |
# ==================== 初始化系統 ====================
|
|
@@ -307,26 +338,47 @@ def process_query(user_question: str) -> Tuple[str, str, str]:
|
|
| 307 |
with gr.Blocks(title="智慧Text-to-SQL系統", theme=gr.themes.Soft()) as demo:
|
| 308 |
gr.Markdown("# 🚀 智慧 Text-to-SQL 系統 (進階修復版)")
|
| 309 |
gr.Markdown("📊 **模式**: 結合「檢索驗證」與「意圖導向生成」,即使資料庫範本有誤也能提供準確查詢。")
|
| 310 |
-
|
| 311 |
with gr.Row():
|
| 312 |
-
question_input = gr.Textbox(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 313 |
submit_btn = gr.Button("🚀 生成SQL", variant="primary", scale=1)
|
| 314 |
-
|
| 315 |
with gr.Accordion("🔍 結果與日誌", open=True):
|
| 316 |
sql_output = gr.Code(label="📊 生成的SQL查詢", language="sql", lines=10)
|
| 317 |
status_output = gr.Textbox(label="🔍 執行狀態", interactive=False)
|
| 318 |
log_output = gr.Textbox(label="📋 詳細日誌", lines=6, interactive=False)
|
| 319 |
-
|
|
|
|
| 320 |
gr.Examples(
|
| 321 |
examples=[
|
| 322 |
-
"
|
| 323 |
-
"
|
| 324 |
-
"找出總金額最高的
|
| 325 |
-
"
|
|
|
|
|
|
|
| 326 |
],
|
| 327 |
inputs=question_input
|
| 328 |
)
|
| 329 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
if __name__ == "__main__":
|
| 331 |
if text_to_sql_system:
|
| 332 |
print("Gradio 介面啟動中...")
|
|
|
|
| 30 |
"""驗證SQL語句的語法和安全性"""
|
| 31 |
if not sql_query or not sql_query.strip():
|
| 32 |
return {"valid": False, "issues": ["SQL語句為空"], "is_safe": False, "empty": True}
|
| 33 |
+
|
| 34 |
sql_clean = sql_query.strip()
|
| 35 |
if len(sql_clean) < 5:
|
| 36 |
return {"valid": False, "issues": ["SQL過短"], "is_safe": False, "empty": True}
|
| 37 |
+
|
| 38 |
security_issues = []
|
| 39 |
sql_upper = sql_clean.upper()
|
| 40 |
+
|
| 41 |
dangerous_keywords = ['DROP', 'DELETE', 'INSERT', 'UPDATE', 'ALTER', 'TRUNCATE', 'EXEC', 'EXECUTE']
|
| 42 |
for keyword in dangerous_keywords:
|
| 43 |
if f" {keyword} " in f" {sql_upper} ":
|
| 44 |
security_issues.append(f"危險操作: {keyword}")
|
| 45 |
+
|
| 46 |
if "SELECT" not in sql_upper:
|
| 47 |
security_issues.append("缺少SELECT")
|
| 48 |
if "FROM" not in sql_upper:
|
| 49 |
security_issues.append("缺少FROM")
|
| 50 |
+
|
| 51 |
is_valid = not security_issues
|
| 52 |
is_safe = all('危險' not in issue for issue in security_issues)
|
| 53 |
+
|
| 54 |
return {"valid": is_valid, "issues": security_issues, "is_safe": is_safe, "empty": False}
|
| 55 |
|
| 56 |
def analyze_question_type(question: str) -> Dict:
|
| 57 |
"""增強的問題分析 - 更精確的意圖識別"""
|
| 58 |
question_lower = question.lower()
|
| 59 |
+
|
| 60 |
analysis = {
|
| 61 |
"type": "unknown",
|
| 62 |
"keywords": [],
|
|
|
|
| 65 |
"has_group": "每" in question_lower or "各" in question_lower or "分組" in question_lower,
|
| 66 |
"specific_intent": "general_query" # 新增:具體意圖,預設為通用查詢
|
| 67 |
}
|
| 68 |
+
|
| 69 |
# **更精確的意圖識別**
|
| 70 |
if "每月" in question_lower and ("完成" in question_lower or "報告" in question_lower or "工作單" in question_lower):
|
| 71 |
analysis["specific_intent"] = "monthly_completion_count"
|
|
|
|
| 79 |
elif ("公司" in question_lower or "客戶" in question_lower or "申請方" in question_lower) and ("統計" in question_lower or "數量" in question_lower or "排名" in question_lower):
|
| 80 |
analysis["specific_intent"] = "company_statistics"
|
| 81 |
analysis["type"] = "statistics"
|
| 82 |
+
|
| 83 |
return analysis
|
| 84 |
|
| 85 |
# ==================== 完整數據加載模塊 ====================
|
|
|
|
| 90 |
self.sql_answers = []
|
| 91 |
self.sql_quality = []
|
| 92 |
self.schema_data = {}
|
| 93 |
+
|
| 94 |
def load_complete_dataset(self) -> bool:
    """Load question/SQL pairs from the Hugging Face dataset into memory.

    Reads the 'train' split of DATASET_REPO_ID, extracts a question from the
    first message and a SQL query from the second message of every item, and
    appends each non-empty pair to ``self.questions`` / ``self.sql_answers``.
    Items that raise while being parsed are reported and skipped.

    Returns:
        True when at least one pair was loaded; False when nothing was loaded
        or the dataset itself could not be loaded.
    """
    try:
        print(f"[{get_current_time()}] 正在加載完整數據集 '{DATASET_REPO_ID}'...")
        raw_dataset = load_dataset(DATASET_REPO_ID, token=self.hf_token)['train']

        successful_loads = 0
        total_items = len(raw_dataset)

        for idx, item in enumerate(raw_dataset):
            try:
                if 'messages' in item and len(item['messages']) >= 2:
                    user_content = item['messages'][0]['content']
                    assistant_content = item['messages'][1]['content']

                    # The question is the single line following the "指令:" marker;
                    # without the marker the whole user message is used.
                    question_match = re.search(r'指令:\s*(.*?)(?:\n|$)', user_content)
                    question = question_match.group(1).strip() if question_match else user_content

                    # Capture everything after the "SQL查詢:" marker. The previous
                    # pattern ended its lazy group at the first newline, which cut
                    # multi-line SQL down to its first line (and reduced fenced
                    # answers to the bare "```sql" marker).
                    sql_match = re.search(r'SQL查詢:\s*(.*)', assistant_content, re.DOTALL)
                    sql_text = sql_match.group(1) if sql_match else assistant_content

                    # Prefer the contents of a fenced code block when one is present,
                    # so trailing commentary after the fence is not treated as SQL.
                    fenced = re.search(r'```(?:sql)?\s*(.*?)```', sql_text, re.DOTALL | re.IGNORECASE)
                    if fenced:
                        sql_text = fenced.group(1)
                    sql_query = re.sub(r'```sql|```', '', sql_text).strip()

                    if question and sql_query:  # keep only usable question/SQL pairs
                        self.questions.append(question)
                        self.sql_answers.append(sql_query)
                        successful_loads += 1
            except Exception as e:
                # Best effort: one malformed record must not abort the whole load.
                print(f"跳過第 {idx} 項資料,錯誤: {e}")
                continue

        print(f"數據加載完成: 成功載入 {successful_loads}/{total_items} 項")
        return successful_loads > 0
    except Exception as e:
        print(f"數據集加載失敗: {e}")
        return False
|
| 128 |
+
|
| 129 |
def load_schema(self) -> bool:
|
| 130 |
try:
|
| 131 |
schema_file_path = hf_hub_download(repo_id=DATASET_REPO_ID, filename="sqlite_schema_FULL.json", repo_type='dataset', token=self.hf_token)
|
|
|
|
| 152 |
print(f"正在為 {len(questions)} 個問題計算向量...")
|
| 153 |
self.question_embeddings = self.embedder.encode(questions, convert_to_tensor=True, show_progress_bar=True)
|
| 154 |
print("向量計算完成")
|
| 155 |
+
|
| 156 |
def retrieve_similar(self, user_question: str, top_k: int = 1) -> List[Dict]:
|
| 157 |
if self.embedder is None or self.question_embeddings is None: return []
|
| 158 |
try:
|
|
|
|
| 170 |
self.data_loader = CompleteDataLoader(hf_token)
|
| 171 |
self.retrieval_system = RetrievalSystem()
|
| 172 |
self.initialize_system()
|
| 173 |
+
|
| 174 |
def initialize_system(self):
|
| 175 |
print("正在初始化完整數據系統...")
|
| 176 |
self.data_loader.load_complete_dataset()
|
|
|
|
| 188 |
"""通用SQL生成器 (作為最終備用)"""
|
| 189 |
# 此函數現在作為無法識別具體意圖時的通用後備方案
|
| 190 |
return f"""-- 通用查詢範本
|
| 191 |
+
SELECT
|
| 192 |
+
JobNo as 工作單號,
|
| 193 |
+
ApplicantName as 申請方,
|
| 194 |
+
OverallRating as 評級
|
| 195 |
+
FROM TSR53SampleDescription
|
| 196 |
LIMIT 20;"""
|
| 197 |
|
| 198 |
def intelligent_repair_sql(self, user_question: str, similar_question: str) -> str:
    """Build a SQL query from the intent detected in the user's question.

    Args:
        user_question: The question being answered; its intent (as reported by
            ``analyze_question_type``) selects the SQL template.
        similar_question: The retrieved similar question whose stored SQL was
            rejected, or the sentinel "無相似問題" when there is none. It is only
            used in the leading SQL comment.

    Returns:
        A SQL string made of a one-line explanatory ``--`` comment followed by
        the template for the detected intent, or by a generic template when the
        intent matches none of the specific cases.
    """
    analysis = analyze_question_type(user_question)
    intent = analysis["specific_intent"]

    if similar_question != "無相似問題":
        # A "--" comment ends at the first line break. Collapse all whitespace
        # so a multi-line question cannot spill out of the comment and end up
        # as (invalid or unintended) SQL text.
        one_line_question = " ".join(similar_question.split())
        comment = f"-- 根據類似問題 '{one_line_question}' (原SQL無效) 進行智能修復\n"
    else:
        comment = f"-- 根據問題意圖 '{intent}' 智能生成SQL\n"

    if intent == "monthly_completion_count":
        year = self.extract_year(user_question)
        return comment + f"""-- 查詢 {year} 年每月完成的工作單數量
SELECT
strftime('%Y-%m', jt.ReportAuthorization) as 月份,
COUNT(*) as 完成數量
FROM JobTimeline jt
WHERE strftime('%Y', jt.ReportAuthorization) = '{year}'
AND jt.ReportAuthorization IS NOT NULL
GROUP BY strftime('%Y-%m', jt.ReportAuthorization)
ORDER BY 月份;"""

    elif intent == "rating_distribution":
        return comment + """-- 查詢評級分佈統計
SELECT
OverallRating as 評級,
COUNT(*) as 數量,
ROUND(COUNT(*) * 100.0 / (
SELECT COUNT(*)
FROM TSR53SampleDescription
WHERE OverallRating IS NOT NULL
), 2) as 百分比
FROM TSR53SampleDescription
WHERE OverallRating IS NOT NULL
GROUP BY OverallRating
ORDER BY 數量 DESC;"""

    elif intent == "amount_ranking":
        return comment + """-- 查詢工作單金額排名
WITH JobTotalAmount AS (
SELECT JobNo, SUM(LocalAmount) AS TotalAmount
FROM (
SELECT DISTINCT JobNo, InvoiceCreditNoteNo, LocalAmount
FROM TSR53Invoice
WHERE LocalAmount IS NOT NULL
)
GROUP BY JobNo
)
SELECT
jta.JobNo as 工作單號,
sd.ApplicantName as 申請方,
jta.TotalAmount as 總金額
FROM JobTotalAmount jta
JOIN TSR53SampleDescription sd ON sd.JobNo = jta.JobNo
WHERE sd.ApplicantName IS NOT NULL
ORDER BY jta.TotalAmount DESC
LIMIT 10;"""

    elif intent == "company_statistics":
        return comment + """-- 查詢申請方工作單統計
SELECT
ApplicantName as 申請方名稱,
COUNT(*) as 工作單數量
FROM TSR53SampleDescription
WHERE ApplicantName IS NOT NULL
GROUP BY ApplicantName
ORDER BY 工作單數量 DESC
LIMIT 20;"""

    # Generic fallback template for intents without a dedicated query.
    return comment + """-- 通用查詢範本
SELECT
JobNo as 工作單號,
ApplicantName as 申請方,
BuyerName as 買方,
OverallRating as 評級
FROM TSR53SampleDescription
WHERE ApplicantName IS NOT NULL
LIMIT 20;"""
|
| 277 |
|
| 278 |
def generate_sql(self, user_question: str) -> Tuple[str, str]:
    """Main pipeline: produce a SQL query and a processing log for a question.

    Steps:
        1. Reject empty input.
        2. Analyse the question's intent (logged only).
        3. Retrieve the most similar stored question; when its similarity
           exceeds SIMILARITY_THRESHOLD and its stored SQL passes
           ``validate_sql``, return that SQL unchanged, otherwise return an
           intent-based repair.
        4. With no hit, or a hit below the threshold, generate SQL purely from
           the detected intent.

    Args:
        user_question: The natural-language question.

    Returns:
        ``(sql, log)`` where ``log`` is the newline-joined list of progress
        messages. For empty input a placeholder query and a short error text
        are returned instead.
    """
    # Guard first: the log line below slices the question, which would raise
    # TypeError for None before the emptiness check could ever run.
    if not user_question or not user_question.strip():
        return "-- 錯誤: 請輸入有效問題\nSELECT '請輸入您的問題' as 錯誤信息;", "錯誤: 問題為空"

    log_messages = [f"⏰ {get_current_time()} 開始處理問題: '{user_question[:50]}...'"]

    # 1. Analyse the question.
    analysis = analyze_question_type(user_question)
    log_messages.append(f"📋 問題分析 - 意圖: {analysis['specific_intent']}, 類型: {analysis['type']}")

    # 2. Retrieve the most similar stored question.
    hits = self.retrieval_system.retrieve_similar(user_question)

    if hits:
        best_hit = hits[0]
        similarity_score = best_hit['score']
        corpus_id = best_hit['corpus_id']
        similar_question = self.data_loader.questions[corpus_id]

        log_messages.append(f"🔍 找到相似問題 (相似度: {similarity_score:.3f}): '{similar_question[:50]}...'")

        if similarity_score > SIMILARITY_THRESHOLD:
            original_sql = self.data_loader.sql_answers[corpus_id]
            validation = validate_sql(original_sql)

            if validation["valid"] and validation["is_safe"]:
                log_messages.append("✅ 相似度高且原SQL有效,直接採用")
                return original_sql, "\n".join(log_messages)
            else:
                # The stored SQL failed validation: rebuild from the intent.
                log_messages.append(f"⚠️ 原SQL有問題: {', '.join(validation['issues'])}")
                log_messages.append("🛠️ 啟用智能修復...")
                repaired_sql = self.intelligent_repair_sql(user_question, similar_question)
                log_messages.append("✅ 智能修復完成")
                return repaired_sql, "\n".join(log_messages)
        else:
            log_messages.append(f"📉 相似度 ({similarity_score:.3f}) 低於閾值 ({SIMILARITY_THRESHOLD})")

    # 3. No usable template: generate from the intent alone.
    log_messages.append("🤖 未找到合適範本,使用意圖生成")
    intelligent_sql = self.intelligent_repair_sql(user_question, "無相似問題")
    log_messages.append("✅ 智能生成完成")

    return intelligent_sql, "\n".join(log_messages)
|
| 321 |
|
| 322 |
# ==================== 初始化系統 ====================
|
|
|
|
| 338 |
# Gradio UI: a question box with a submit button, an accordion showing the
# generated SQL, status and log, a set of example questions, and two event
# bindings (button click and Enter in the textbox) that both run process_query.
with gr.Blocks(title="智慧Text-to-SQL系統", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🚀 智慧 Text-to-SQL 系統 (進階修復版)")
    gr.Markdown("📊 **模式**: 結合「檢索驗證」與「意圖導向生成」,即使資料庫範本有誤也能提供準確查詢。")

    # Input row: wide textbox next to a narrow primary button.
    with gr.Row():
        question_input = gr.Textbox(
            label="📝 請在此輸入您的問題",
            placeholder="例如:2024年每月完成多少份報告?",
            lines=3,
            scale=4,
        )
        submit_btn = gr.Button("🚀 生成SQL", variant="primary", scale=1)

    # Output area, expanded by default.
    with gr.Accordion("🔍 結果與日誌", open=True):
        sql_output = gr.Code(label="📊 生成的SQL查詢", language="sql", lines=10)
        status_output = gr.Textbox(label="🔍 執行狀態", interactive=False)
        log_output = gr.Textbox(label="📋 詳細日誌", lines=6, interactive=False)

    # Example questions that can be loaded into the textbox.
    example_questions = [
        "2024年每月完成多少份報告?",
        "統計各種評級(Pass/Fail)的分布情況",
        "找出總金額最高的10個工作單來自哪些申請方",
        "哪些客戶的工作單數量最多?",
        "A組昨天完成了多少個測試項目?",
        "2024年Q1期間評級為Fail且總金額超過10000的工作單",
    ]
    gr.Examples(examples=example_questions, inputs=question_input)

    # Bind the same handler to the button click and to textbox submission,
    # in that order.
    for register_event in (submit_btn.click, question_input.submit):
        register_event(
            fn=process_query,
            inputs=[question_input],
            outputs=[sql_output, status_output, log_output],
        )
|
| 381 |
+
|
| 382 |
if __name__ == "__main__":
|
| 383 |
if text_to_sql_system:
|
| 384 |
print("Gradio 介面啟動中...")
|