Paul720810 committed on
Commit
2671970
·
verified ·
1 Parent(s): 44209e5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +171 -0
app.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import requests
3
+ import json
4
+ import os
5
+ from datasets import load_dataset, Dataset
6
+ from sentence_transformers import SentenceTransformer, util
7
+ import torch
8
+ from huggingface_hub import hf_hub_download
9
+
10
+ # --- 配置區 ---
11
+ # 從 Hugging Face Secrets 獲取 Token,這是最安全的方式
12
+ HF_TOKEN = os.environ.get("HF_TOKEN")
13
+ # 您的 Dataset 倉庫 ID
14
+ DATASET_REPO_ID = "Paul720810/Text-to-SQL-Softline"
15
+ # 雲端 LLM 模型的 API URL (推薦使用 CodeLlama-34b,它更強大)
16
+ LLM_API_URL = "https://api-inference.huggingface.co/models/codellama/CodeLlama-34b-Instruct-hf"
17
+
18
+ # 相似度閾值,高於此值則直接返回答案
19
+ SIMILARITY_THRESHOLD = 0.90
20
+
21
+ print("--- [1/5] 開始初始化應用 ---")
22
+
23
+ # --- 1. 載入知識庫 ---
24
+ try:
25
+ print(f"--- [2/5] 正在從 '{DATASET_REPO_ID}' 載入知識庫... ---")
26
+ # 載入問答範例
27
+ dataset = load_dataset(DATASET_REPO_ID, token=HF_TOKEN, trust_remote_code=True)
28
+ qa_dataset = dataset['train']
29
+
30
+ # 載入並解析 Schema JSON
31
+ schema_file_path = "sqlite_schema_FULL.json"
32
+ hf_hub_download(repo_id=DATASET_REPO_ID, filename=schema_file_path, repo_type='dataset', local_dir='.', token=HF_TOKEN)
33
+
34
+ with open(schema_file_path, 'r', encoding='utf-8') as f:
35
+ schema_data = json.load(f)
36
+
37
+ print(f"--- > 成功載入 {len(qa_dataset)} 條問答範例和 Schema。 ---")
38
+ except Exception as e:
39
+ print(f"!!! 致命錯誤: 無法載入 Dataset '{DATASET_REPO_ID}'. 請檢查:")
40
+ print("1. Dataset 倉庫是否設為 Public,或 HF_TOKEN 是否有讀取 Private 倉庫的權限。")
41
+ print("2. 倉庫中是否包含 training_data.jsonl 和 sqlite_schema_FULL.json。")
42
+ print(f"詳細錯誤: {e}")
43
+ # 如果載入失敗,則使用備用數據避免應用崩潰
44
+ qa_dataset = Dataset.from_dict({"question": ["示例問題"], "sql": ["SELECT 'Dataset failed to load'"]})
45
+ schema_data = {}
46
+
47
+ # --- 2. 構建 DDL 和初始化檢索模型 ---
48
+ def load_schema_as_ddl(schema_dict: dict) -> str:
49
+ ddl_string = ""
50
+ for table_name, columns in schema_dict.items():
51
+ if not isinstance(columns, list): continue
52
+ ddl_string += f"CREATE TABLE `{table_name}` (\n"
53
+ ddl_cols = [f" `{col.get('name', '')}` {col.get('type', '')} -- {col.get('description', '')}" for col in columns]
54
+ ddl_string += ",\n".join(ddl_cols) + "\n);\n\n"
55
+ return ddl_string
56
+
57
+ SCHEMA_DDL = load_schema_as_ddl(schema_data)
58
+
59
+ print("--- [3/5] 正在載入句向量模型 (all-MiniLM-L6-v2)... ---")
60
+ # 輕量級句向量模型,在 CPU 上運行極快
61
+ embedder = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
62
+
63
+ questions = [item['question'] for item in qa_dataset]
64
+ print(f"--- [4/5] 正在為 {len(questions)} 個問題計算向量 (這可能需要幾分鐘)... ---")
65
+ # 預先計算所有問題的向量,這是實現快速檢索的關鍵
66
+ question_embeddings = embedder.encode(questions, convert_to_tensor=True, show_progress_bar=True)
67
+ sql_answers = [item['sql'] for item in qa_dataset]
68
+ print("--- > 向量計算完成! ---")
69
+
70
+ # --- 3. 混合系統核心邏輯 ---
71
+ def get_sql_query(user_question: str):
72
+ if not user_question:
73
+ return "請輸入您的問題。", "日誌:用戶未輸入問題。"
74
+
75
+ # 1. 向量檢索
76
+ question_embedding = embedder.encode(user_question, convert_to_tensor=True)
77
+ hits = util.semantic_search(question_embedding, question_embeddings, top_k=5)
78
+ hits = hits[0] # Get the hits for the first query
79
+
80
+ most_similar_hit = hits[0]
81
+ similarity_score = most_similar_hit['score']
82
+
83
+ log_message = f"檢索到最相似問題: '{questions[most_similar_hit['corpus_id']]}' (相似度: {similarity_score:.4f})"
84
+
85
+ # 2. 如果相似度足夠高,直接返回預定義的 SQL
86
+ if similarity_score > SIMILARITY_THRESHOLD:
87
+ sql_result = sql_answers[most_similar_hit['corpus_id']]
88
+ log_message += f"\n相似度 > {SIMILARITY_THRESHOLD},[模式: 直接返回]。"
89
+ return sql_result, log_message
90
+
91
+ # 3. 否則,檢索幾個相關例子,用 LLM 生成新 SQL
92
+ log_message += f"\n相似度 < {SIMILARITY_THRESHOLD},[模式: LLM生成]。正在構建 Prompt..."
93
+
94
+ # 構建 Prompt
95
+ examples_context = ""
96
+ for hit in hits[:3]: # 取最相關的3個例子
97
+ examples_context += f"### A user asks: {questions[hit['corpus_id']]}\n{sql_answers[hit['corpus_id']]}\n\n"
98
+
99
+ prompt = f"""### Task
100
+ Generate a SQLite SQL query that answers the following user question.
101
+ Your response must contain ONLY the SQL query. Do not add any explanation.
102
+
103
+ ### Database Schema
104
+ {SCHEMA_DDL}
105
+ ### Examples
106
+ {examples_context}
107
+ ### Question
108
+ {user_question}
109
+
110
+ ### SQL Query
111
+ """
112
+
113
+ # 調用 Hugging Face Inference API
114
+ log_message += "\n正在請求雲端 LLM..."
115
+ headers = {"Authorization": f"Bearer {HF_TOKEN}"}
116
+ payload = {"inputs": prompt, "parameters": {"max_new_tokens": 512, "temperature": 0.1, "return_full_text": False}}
117
+ response_text = ""
118
+
119
+ try:
120
+ response = requests.post(LLM_API_URL, headers=headers, json=payload)
121
+ response_text = response.text # 先保存原始響應文本
122
+ response.raise_for_status()
123
+
124
+ generated_text = response.json()[0]['generated_text'].strip()
125
+
126
+ # 清理常見的返回格式問題
127
+ if "```sql" in generated_text:
128
+ generated_text = generated_text.split("```sql")[1].split("```").strip()
129
+ if "```" in generated_text:
130
+ generated_text = generated_text.replace("```", "").strip()
131
+
132
+
133
+ log_message += f"\nLLM 生成成功!"
134
+ return generated_text, log_message
135
+ except Exception as e:
136
+ error_msg = f"LLM API 調用失敗: {e}\nAPI 原始回應: {response_text}"
137
+ log_message += f"\n{error_msg}"
138
+ return "抱歉,調用雲端 AI 時發生錯誤。", log_message
139
+
140
+ # --- 4. 創建 Gradio Web 界面 ---
141
+ print("--- [5/5] 正在創建 Gradio Web 界面... ---")
142
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
143
+ gr.Markdown("# 智能 Text-to-SQL 系統 (混合模式)")
144
+ gr.Markdown("輸入您的自然語言問題,系統將首先嘗試從知識庫中快速檢索答案。如果問題較新穎,則會調用雲端大語言模型生成SQL。")
145
+
146
+ with gr.Row():
147
+ question_input = gr.Textbox(label="輸入您的問題", placeholder="例如:去年Nike的總業績是多少?", scale=4)
148
+ submit_button = gr.Button("生成SQL", variant="primary", scale=1)
149
+
150
+ sql_output = gr.Code(label="生成的 SQL 查詢", language="sql")
151
+ log_output = gr.Textbox(label="系統日誌 (執行過程)", lines=4, interactive=False)
152
+
153
+ submit_button.click(
154
+ fn=get_sql_query,
155
+ inputs=question_input,
156
+ outputs=[sql_output, log_output]
157
+ )
158
+
159
+ gr.Examples(
160
+ examples=[
161
+ "2024 最好的5個客人以及業績",
162
+ "比較2023年跟2024年的業績",
163
+ "上禮拜C組 完成幾份報告",
164
+ "有沒有快到期的單子?",
165
+ "哪個客戶的付款最不及時?"
166
+ ],
167
+ inputs=question_input
168
+ )
169
+
170
+ print("--- 應用準備啟動 ---")
171
+ demo.launch()