Paul720810 commited on
Commit
7b1b963
·
verified ·
1 Parent(s): fcd6422

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +182 -104
app.py CHANGED
@@ -6,11 +6,13 @@ from datasets import load_dataset, Dataset
6
  from sentence_transformers import SentenceTransformer, util
7
  import torch
8
  from huggingface_hub import hf_hub_download
 
9
 
10
  # --- 配置區 ---
11
  HF_TOKEN = os.environ.get("HF_TOKEN")
12
  DATASET_REPO_ID = "Paul720810/Text-to-SQL-Softline"
13
- LLM_API_URL = "https://api-inference.huggingface.co/models/codellama/CodeLlama-34b-Instruct-hf"
 
14
  SIMILARITY_THRESHOLD = 0.90
15
 
16
  print("--- [1/5] 開始初始化應用 ---")
@@ -18,57 +20,90 @@ print("--- [1/5] 開始初始化應用 ---")
18
  # --- 1. 載入知識庫 ---
19
  qa_dataset = None
20
  schema_data = {}
 
 
 
21
  try:
22
  print(f"--- [2/5] 正在從 '{DATASET_REPO_ID}' 載入知識庫... ---")
23
  raw_dataset = load_dataset(DATASET_REPO_ID, token=HF_TOKEN)['train']
24
 
25
- # *** 關鍵修正:智能解析 Dataset ***
26
- # 檢查第一條數據的結構來判斷格式
27
- if raw_dataset and len(raw_dataset) > 0:
28
- first_item = raw_dataset[0]
29
- if 'text' in first_item and 'question' not in first_item:
30
- # 這是舊的 {'text': '...'} 格式,需要解析
31
- print("--- > 檢測到 'text' 格式,正在解析JSON...")
32
- parsed_qa_data = []
33
- for item in raw_dataset:
34
- try:
35
- line_dict = json.loads(item['text'])
36
- parsed_qa_data.append(line_dict)
37
- except (json.JSONDecodeError, KeyError):
38
- continue # 跳過錯誤行
39
- qa_dataset = Dataset.from_list(parsed_qa_data)
40
- elif 'question' in first_item and 'sql' in first_item:
41
- # 這已經是正確的 {'question': ..., 'sql': ...} 格式
42
- print("--- > 檢測到已解析的 'question'/'sql' 格式,直接使用。")
43
- qa_dataset = raw_dataset
44
- else:
45
- raise ValueError(f"未知的Dataset格式: {first_item}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  else:
47
- # 數據集為空
48
- raise ValueError("載入的Dataset為空。")
49
-
50
  # 載入並解析 Schema JSON
51
  schema_file_path = "sqlite_schema_FULL.json"
52
- hf_hub_download(repo_id=DATASET_REPO_ID, filename=schema_file_path, repo_type='dataset', local_dir='.', token=HF_TOKEN)
53
-
54
- with open(schema_file_path, 'r', encoding='utf-8') as f:
55
- schema_data = json.load(f)
 
 
 
 
 
56
 
57
- print(f"--- > 成功載入 {len(qa_dataset)} 條問答範例和 Schema。 ---")
58
 
59
  except Exception as e:
60
- print(f"!!! 致命錯誤: 無法載入或解析 Dataset '{DATASET_REPO_ID}'.")
61
- print(f"詳細錯誤: {e}")
62
- qa_dataset = Dataset.from_dict({"question": ["示例問題"], "sql": ["SELECT 'Dataset failed to load'"]})
 
 
63
 
64
  # --- 2. 構建 DDL 和初始化檢索模型 ---
65
  def load_schema_as_ddl(schema_dict: dict) -> str:
66
- # (此函式無需修改)
67
- ddl_string = ""
68
  for table_name, columns in schema_dict.items():
69
- if not isinstance(columns, list): continue
 
70
  ddl_string += f"CREATE TABLE `{table_name}` (\n"
71
- ddl_cols = [f" `{col.get('name', '')}` {col.get('type', '')} -- {col.get('description', '')}" for col in columns]
 
 
 
 
 
72
  ddl_string += ",\n".join(ddl_cols) + "\n);\n\n"
73
  return ddl_string
74
 
@@ -77,113 +112,156 @@ SCHEMA_DDL = load_schema_as_ddl(schema_data)
77
  print("--- [3/5] 正在載入句向量模型 (all-MiniLM-L6-v2)... ---")
78
  embedder = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
79
 
80
- questions = [item['question'] for item in qa_dataset]
81
- sql_answers = [item['sql'] for item in qa_dataset]
82
-
83
- # 只有在 questions 列表不為空時才進行計算
84
  if questions:
85
  print(f"--- [4/5] 正在為 {len(questions)} 個問題計算向量... ---")
86
  question_embeddings = embedder.encode(questions, convert_to_tensor=True, show_progress_bar=True)
87
  print("--- > 向量計算完成! ---")
88
  else:
89
- print("--- [4/5] 警告:沒有可用的問題來計算向量。檢索功能將不可用。---")
90
  question_embeddings = torch.Tensor([])
91
 
92
-
93
  # --- 3. 混合系統核心邏輯 ---
94
  def get_sql_query(user_question: str):
95
- # (此函式剩餘部分幾乎無需修改)
96
  if not user_question:
97
  return "請輸入您的問題。", "日誌:用戶未輸入問題。"
98
 
99
- # 增加一個檢查,確保知識庫不是空的
100
- if len(questions) == 0:
101
- log_message = "錯誤:知識庫為空,無法進行檢索。"
102
- return "系統錯誤:知識庫未成功載入。", log_message
103
-
104
- question_embedding = embedder.encode(user_question, convert_to_tensor=True)
105
- hits = util.semantic_search(question_embedding, question_embeddings, top_k=5)
106
 
107
- if not hits or not hits[0]:
108
- log_message = "檢索失敗:找不到任何相似的問題。"
109
- # 即使檢索失敗,也應該嘗試調用 LLM
110
- else:
111
- hits = hits[0]
112
- most_similar_hit = hits[0]
113
- similarity_score = most_similar_hit['score']
114
- log_message = f"檢索到最相似問題: '{questions[most_similar_hit['corpus_id']]}' (相似度: {similarity_score:.4f})"
115
 
116
- if similarity_score > SIMILARITY_THRESHOLD:
117
- sql_result = sql_answers[most_similar_hit['corpus_id']]
118
- log_message += f"\n相似度 > {SIMILARITY_THRESHOLD},[模式: 直接返回]。"
119
- return sql_result, log_message
120
-
121
- log_message += f"\n相似度低於閾值或檢索失敗,[模式: LLM生成]。正在構建 Prompt..."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  examples_context = ""
123
- if hits: # 只有在檢索到結果時才添加範例
124
- for hit in hits[:3]:
125
- examples_context += f"### A user asks: {questions[hit['corpus_id']]}\n{sql_answers[hit['corpus_id']]}\n\n"
126
-
127
- prompt = f"""### Task
128
- Generate a SQLite SQL query that answers the following user question.
129
- Your response must contain ONLY the SQL query. Do not add any explanation.
130
 
131
- ### Database Schema
132
  {SCHEMA_DDL}
133
- ### Examples
 
134
  {examples_context}
135
- ### Question
 
136
  {user_question}
137
 
138
- ### SQL Query
 
139
  """
140
- log_message += "\n正在請求雲端 LLM..."
 
 
141
  headers = {"Authorization": f"Bearer {HF_TOKEN}"}
142
- payload = {"inputs": prompt, "parameters": {"max_new_tokens": 512, "temperature": 0.1, "return_full_text": False}}
143
- response_text = ""
 
 
 
 
 
 
 
144
  try:
145
- response = requests.post(LLM_API_URL, headers=headers, json=payload)
146
- response_text = response.text
147
- response.raise_for_status()
148
- generated_text = response.json()[0]['generated_text'].strip()
149
- if "```sql" in generated_text:
150
- generated_text = generated_text.split("```sql")[1].split("```").strip()
151
- if "```" in generated_text:
152
- generated_text = generated_text.replace("```", "").strip()
153
- log_message += f"\nLLM 生成成功!"
154
- return generated_text, log_message
 
 
 
 
 
 
 
 
 
 
155
  except Exception as e:
156
- error_msg = f"LLM API 調用失敗: {e}\nAPI 原始回應: {response_text}"
157
- log_message += f"\n{error_msg}"
158
- return "抱歉,調用雲端 AI 時發生錯誤。", log_message
 
 
 
159
 
160
  # --- 4. 創建 Gradio Web 界面 ---
161
  print("--- [5/5] 正在創建 Gradio Web 界面... ---")
162
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
163
- # (此部分無需修改)
164
- gr.Markdown("# 智能 Text-to-SQL 系統 (混合模式)")
165
- # ... (Gradio界面代碼與之前相同)
166
- gr.Markdown("輸入您的自然語言問題,系統將首先嘗試從知識庫中快速檢索答案。如果問題較新穎,則會調用雲端大語言模型生成SQL。")
167
  with gr.Row():
168
- question_input = gr.Textbox(label="輸入您的問題", placeholder="例如:去年Nike的總業績是多少?", scale=4)
 
 
 
 
 
169
  submit_button = gr.Button("生成SQL", variant="primary", scale=1)
170
- sql_output = gr.Code(label="生成的 SQL 查詢", language="sql")
171
- log_output = gr.Textbox(label="系統日誌 (執行過程)", lines=4, interactive=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  submit_button.click(
173
  fn=get_sql_query,
174
  inputs=question_input,
175
  outputs=[sql_output, log_output]
176
  )
 
177
  gr.Examples(
178
  examples=[
179
- "2024 最好的5個客人以及業績",
180
  "比較2023年跟2024年的業績",
181
- "上禮拜C組 完成幾份報告",
182
- "有沒��快到期的單子?",
183
  "哪個客戶的付款最不及時?"
184
  ],
185
- inputs=question_input
 
186
  )
187
 
188
  print("--- 應用準備啟動 ---")
189
- demo.launch()
 
 
6
  from sentence_transformers import SentenceTransformer, util
7
  import torch
8
  from huggingface_hub import hf_hub_download
9
+ import re
10
 
11
  # --- 配置區 ---
12
  HF_TOKEN = os.environ.get("HF_TOKEN")
13
  DATASET_REPO_ID = "Paul720810/Text-to-SQL-Softline"
14
+ # 使用正確的模型名稱(7B版本更適合免費使用)
15
+ LLM_API_URL = "https://api-inference.huggingface.co/models/codellama/CodeLlama-7b-hf"
16
  SIMILARITY_THRESHOLD = 0.90
17
 
18
  print("--- [1/5] 開始初始化應用 ---")
 
20
  # --- 1. 載入知識庫 ---
21
  qa_dataset = None
22
  schema_data = {}
23
+ questions = []
24
+ sql_answers = []
25
+
26
  try:
27
  print(f"--- [2/5] 正在從 '{DATASET_REPO_ID}' 載入知識庫... ---")
28
  raw_dataset = load_dataset(DATASET_REPO_ID, token=HF_TOKEN)['train']
29
 
30
+ # 解析新的 messages 格式
31
+ print("--- > 檢測到 'messages' 格式,正在解析...")
32
+
33
+ for item in raw_dataset:
34
+ try:
35
+ if 'messages' in item and len(item['messages']) >= 2:
36
+ user_content = item['messages'][0]['content']
37
+ assistant_content = item['messages'][1]['content']
38
+
39
+ # 從用戶消息中提取問題
40
+ question_match = re.search(r'指令:\s*(.*?)(?:\n|$)', user_content)
41
+ if question_match:
42
+ question = question_match.group(1).strip()
43
+ else:
44
+ # 如果沒有找到指令,使用整個內容
45
+ question = user_content
46
+
47
+ # 從助手消息中提取SQL
48
+ sql_match = re.search(r'SQL查詢:\s*(.*?)(?:\n|$)', assistant_content, re.DOTALL)
49
+ if sql_match:
50
+ sql_query = sql_match.group(1).strip()
51
+ # 清理SQL語句
52
+ sql_query = re.sub(r'^sql\s*', '', sql_query) # 移除開頭的sql
53
+ sql_query = re.sub(r'```sql|```', '', sql_query).strip() # 移除代碼塊標記
54
+ else:
55
+ sql_query = assistant_content
56
+
57
+ questions.append(question)
58
+ sql_answers.append(sql_query)
59
+
60
+ except (KeyError, IndexError, TypeError) as e:
61
+ print(f"解析錯誤,跳過該條目: {e}")
62
+ continue
63
+
64
+ # 創建問答數據集
65
+ if questions:
66
+ qa_dataset = Dataset.from_dict({
67
+ 'question': questions,
68
+ 'sql': sql_answers
69
+ })
70
  else:
71
+ raise ValueError("沒有成功解析出任何問答對")
72
+
 
73
  # 載入並解析 Schema JSON
74
  schema_file_path = "sqlite_schema_FULL.json"
75
+ try:
76
+ hf_hub_download(repo_id=DATASET_REPO_ID, filename=schema_file_path,
77
+ repo_type='dataset', local_dir='.', token=HF_TOKEN)
78
+
79
+ with open(schema_file_path, 'r', encoding='utf-8') as f:
80
+ schema_data = json.load(f)
81
+ except Exception as e:
82
+ print(f"警告: 無法載入Schema文件: {e}")
83
+ schema_data = {}
84
 
85
+ print(f"--- > 成功解析 {len(questions)} 條問答範例。 ---")
86
 
87
  except Exception as e:
88
+ print(f"!!! 錯誤: 處理Dataset時發生問題: {e}")
89
+ # 創建備用數據集
90
+ questions = ["示例問題"]
91
+ sql_answers = ["SELECT '請檢查數據集格式' AS error;"]
92
+ qa_dataset = Dataset.from_dict({"question": questions, "sql": sql_answers})
93
 
94
  # --- 2. 構建 DDL 和初始化檢索模型 ---
95
  def load_schema_as_ddl(schema_dict: dict) -> str:
96
+ ddl_string = "/* 數據庫結構 */\n"
 
97
  for table_name, columns in schema_dict.items():
98
+ if not isinstance(columns, list):
99
+ continue
100
  ddl_string += f"CREATE TABLE `{table_name}` (\n"
101
+ ddl_cols = []
102
+ for col in columns:
103
+ col_name = col.get('name', 'unknown')
104
+ col_type = col.get('type', 'TEXT')
105
+ col_desc = col.get('description', '')
106
+ ddl_cols.append(f" `{col_name}` {col_type} -- {col_desc}")
107
  ddl_string += ",\n".join(ddl_cols) + "\n);\n\n"
108
  return ddl_string
109
 
 
112
  print("--- [3/5] 正在載入句向量模型 (all-MiniLM-L6-v2)... ---")
113
  embedder = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
114
 
115
+ # 計算問題向量
 
 
 
116
  if questions:
117
  print(f"--- [4/5] 正在為 {len(questions)} 個問題計算向量... ---")
118
  question_embeddings = embedder.encode(questions, convert_to_tensor=True, show_progress_bar=True)
119
  print("--- > 向量計算完成! ---")
120
  else:
121
+ print("--- [4/5] 警告:沒有可用的問題來計算向量。 ---")
122
  question_embeddings = torch.Tensor([])
123
 
 
124
  # --- 3. 混合系統核心邏輯 ---
125
  def get_sql_query(user_question: str):
 
126
  if not user_question:
127
  return "請輸入您的問題。", "日誌:用戶未輸入問題。"
128
 
129
+ log_messages = []
 
 
 
 
 
 
130
 
131
+ # 檢索相似問題
132
+ if len(questions) > 0:
133
+ question_embedding = embedder.encode(user_question, convert_to_tensor=True)
134
+ hits = util.semantic_search(question_embedding, question_embeddings, top_k=3)
 
 
 
 
135
 
136
+ if hits and hits[0]:
137
+ most_similar_hit = hits[0][0]
138
+ similarity_score = most_similar_hit['score']
139
+ similar_question = questions[most_similar_hit['corpus_id']]
140
+
141
+ log_messages.append(f"檢索到相似問題: '{similar_question}' (相似度: {similarity_score:.4f})")
142
+
143
+ if similarity_score > SIMILARITY_THRESHOLD:
144
+ sql_result = sql_answers[most_similar_hit['corpus_id']]
145
+ log_messages.append(f"相似度 > {SIMILARITY_THRESHOLD},直接返回預先SQL")
146
+ return sql_result, "\n".join(log_messages)
147
+ else:
148
+ log_messages.append("檢索失敗:找不到相似問題")
149
+ else:
150
+ log_messages.append("知識庫為空,跳過檢索")
151
+
152
+ # LLM生成模式
153
+ log_messages.append("進入LLM生成模式...")
154
+
155
+ # 構建示例上下文
156
  examples_context = ""
157
+ if 'hits' in locals() and hits and hits[0]:
158
+ for i, hit in enumerate(hits[0][:2]):
159
+ examples_context += f"問題: {questions[hit['corpus_id']]}\nSQL: {sql_answers[hit['corpus_id']]}\n\n"
160
+
161
+ # 構建提示詞
162
+ prompt = f"""你是一個SQL專家。請根據數據庫結構生成SQL查詢。
 
163
 
164
+ 數據庫結構:
165
  {SCHEMA_DDL}
166
+
167
+ 參考示例:
168
  {examples_context}
169
+
170
+ 請為以下問題生成SQL查詢:
171
  {user_question}
172
 
173
+ 只輸出SQL語句,不要其他內容:
174
+
175
  """
176
+
177
+ log_messages.append("正在請求雲端LLM...")
178
+
179
  headers = {"Authorization": f"Bearer {HF_TOKEN}"}
180
+ payload = {
181
+ "inputs": prompt,
182
+ "parameters": {
183
+ "max_new_tokens": 300,
184
+ "temperature": 0.1,
185
+ "do_sample": False
186
+ }
187
+ }
188
+
189
  try:
190
+ response = requests.post(LLM_API_URL, headers=headers, json=payload, timeout=30)
191
+
192
+ if response.status_code == 200:
193
+ result = response.json()
194
+ if isinstance(result, list) and len(result) > 0:
195
+ generated_text = result[0]['generated_text'].strip()
196
+
197
+ # 清理輸出,只保留SQL
198
+ if "```sql" in generated_text:
199
+ generated_text = generated_text.split("```sql")[1].split("```")[0].strip()
200
+ elif "```" in generated_text:
201
+ generated_text = generated_text.split("```")[1].strip() if len(generated_text.split("```")) > 2 else generated_text
202
+
203
+ log_messages.append("LLM生成成功!")
204
+ return generated_text, "\n".join(log_messages)
205
+ else:
206
+ raise Exception(f"API返回格式異常: {result}")
207
+ else:
208
+ raise Exception(f"API錯誤: {response.status_code} - {response.text}")
209
+
210
  except Exception as e:
211
+ error_msg = f"LLM API調用失敗: {str(e)}"
212
+ log_messages.append(error_msg)
213
+
214
+ # 提供備用答案
215
+ backup_sql = "SELECT 'AI服務暫時不可用,請稍後重試' AS status;"
216
+ return backup_sql, "\n".join(log_messages)
217
 
218
  # --- 4. 創建 Gradio Web 界面 ---
219
  print("--- [5/5] 正在創建 Gradio Web 界面... ---")
220
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
221
+ gr.Markdown("# 🚀 智能 Text-to-SQL 系統 (混合模式)")
222
+ gr.Markdown("輸入自然語言問題,系統會智能生成SQL查詢")
223
+
 
224
  with gr.Row():
225
+ question_input = gr.Textbox(
226
+ label="輸入您的問題",
227
+ placeholder="例如:查詢去年的銷售總額",
228
+ lines=2,
229
+ scale=4
230
+ )
231
  submit_button = gr.Button("生成SQL", variant="primary", scale=1)
232
+
233
+ with gr.Row():
234
+ sql_output = gr.Code(
235
+ label="生成的 SQL 查詢",
236
+ language="sql",
237
+ lines=6
238
+ )
239
+
240
+ with gr.Row():
241
+ log_output = gr.Textbox(
242
+ label="系統日誌",
243
+ lines=4,
244
+ interactive=False
245
+ )
246
+
247
  submit_button.click(
248
  fn=get_sql_query,
249
  inputs=question_input,
250
  outputs=[sql_output, log_output]
251
  )
252
+
253
  gr.Examples(
254
  examples=[
255
+ "2024年最好的5個客戶以及業績",
256
  "比較2023年跟2024年的業績",
257
+ "上週C組完成了幾份報告",
258
+ "有沒有快到期的訂單?",
259
  "哪個客戶的付款最不及時?"
260
  ],
261
+ inputs=question_input,
262
+ label="示例問題"
263
  )
264
 
265
  print("--- 應用準備啟動 ---")
266
+ if __name__ == "__main__":
267
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=False)