Paul720810 committed on
Commit
5df0985
·
verified ·
1 Parent(s): 1cb3792

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -42
app.py CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  import gradio as gr
2
  import requests
3
  import json
@@ -8,14 +11,9 @@ import torch
8
  from huggingface_hub import hf_hub_download
9
 
10
  # --- 配置區 ---
11
- # 從 Hugging Face Secrets 獲取 Token,這是最安全的方式
12
  HF_TOKEN = os.environ.get("HF_TOKEN")
13
- # 您的 Dataset 倉庫 ID
14
  DATASET_REPO_ID = "Paul720810/Text-to-SQL-Softline"
15
- # 雲端 LLM 模型的 API URL (推薦使用 CodeLlama-34b,它更強大)
16
  LLM_API_URL = "https://api-inference.huggingface.co/models/codellama/CodeLlama-34b-Instruct-hf"
17
-
18
- # 相似度閾值,高於此值則直接返回答案
19
  SIMILARITY_THRESHOLD = 0.90
20
 
21
  print("--- [1/5] 開始初始化應用 ---")
@@ -23,9 +21,22 @@ print("--- [1/5] 開始初始化應用 ---")
23
  # --- 1. 載入知識庫 ---
24
  try:
25
  print(f"--- [2/5] 正在從 '{DATASET_REPO_ID}' 載入知識庫... ---")
26
- # 載入問答範例
27
- dataset = load_dataset(DATASET_REPO_ID, token=HF_TOKEN, trust_remote_code=True)
28
- qa_dataset = dataset['train']
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  # 載入並解析 Schema JSON
31
  schema_file_path = "sqlite_schema_FULL.json"
@@ -34,13 +45,10 @@ try:
34
  with open(schema_file_path, 'r', encoding='utf-8') as f:
35
  schema_data = json.load(f)
36
 
37
- print(f"--- > 成功載入 {len(qa_dataset)} 條問答範例和 Schema。 ---")
38
  except Exception as e:
39
- print(f"!!! 致命錯誤: 無法載入 Dataset '{DATASET_REPO_ID}'. 請檢查:")
40
- print("1. Dataset 倉庫是否設為 Public,或 HF_TOKEN 是否有讀取 Private 倉庫的權限。")
41
- print("2. 倉庫中是否包含 training_data.jsonl 和 sqlite_schema_FULL.json。")
42
  print(f"詳細錯誤: {e}")
43
- # 如果載入失敗,則使用備用數據避免應用崩潰
44
  qa_dataset = Dataset.from_dict({"question": ["示例問題"], "sql": ["SELECT 'Dataset failed to load'"]})
45
  schema_data = {}
46
 
@@ -57,45 +65,35 @@ def load_schema_as_ddl(schema_dict: dict) -> str:
57
  SCHEMA_DDL = load_schema_as_ddl(schema_data)
58
 
59
  print("--- [3/5] 正在載入句向量模型 (all-MiniLM-L6-v2)... ---")
60
- # 輕量級句向量模型,在 CPU 上運行極快
61
  embedder = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
62
 
 
63
  questions = [item['question'] for item in qa_dataset]
 
 
64
  print(f"--- [4/5] 正在為 {len(questions)} 個問題計算向量 (這可能需要幾分鐘)... ---")
65
- # 預先計算所有問題的向量,這是實現快速檢索的關鍵
66
  question_embeddings = embedder.encode(questions, convert_to_tensor=True, show_progress_bar=True)
67
- sql_answers = [item['sql'] for item in qa_dataset]
68
  print("--- > 向量計算完成! ---")
69
 
70
  # --- 3. 混合系統核心邏輯 ---
71
  def get_sql_query(user_question: str):
 
72
  if not user_question:
73
  return "請輸入您的問題。", "日誌:用戶未輸入問題。"
74
-
75
- # 1. 向量檢索
76
  question_embedding = embedder.encode(user_question, convert_to_tensor=True)
77
  hits = util.semantic_search(question_embedding, question_embeddings, top_k=5)
78
- hits = hits[0] # Get the hits for the first query
79
-
80
  most_similar_hit = hits[0]
81
  similarity_score = most_similar_hit['score']
82
-
83
  log_message = f"檢索到最相似問題: '{questions[most_similar_hit['corpus_id']]}' (相似度: {similarity_score:.4f})"
84
-
85
- # 2. 如果相似度足夠高,直接返回預定義的 SQL
86
  if similarity_score > SIMILARITY_THRESHOLD:
87
  sql_result = sql_answers[most_similar_hit['corpus_id']]
88
  log_message += f"\n相似度 > {SIMILARITY_THRESHOLD},[模式: 直接返回]。"
89
  return sql_result, log_message
90
-
91
- # 3. 否則,檢索幾個相關例子,用 LLM 生成新 SQL
92
  log_message += f"\n相似度 < {SIMILARITY_THRESHOLD},[模式: LLM生成]。正在構建 Prompt..."
93
-
94
- # 構建 Prompt
95
  examples_context = ""
96
- for hit in hits[:3]: # 取最相關的3個例子
97
  examples_context += f"### A user asks: {questions[hit['corpus_id']]}\n{sql_answers[hit['corpus_id']]}\n\n"
98
-
99
  prompt = f"""### Task
100
  Generate a SQLite SQL query that answers the following user question.
101
  Your response must contain ONLY the SQL query. Do not add any explanation.
@@ -109,27 +107,19 @@ Your response must contain ONLY the SQL query. Do not add any explanation.
109
 
110
  ### SQL Query
111
  """
112
-
113
- # 調用 Hugging Face Inference API
114
  log_message += "\n正在請求雲端 LLM..."
115
  headers = {"Authorization": f"Bearer {HF_TOKEN}"}
116
  payload = {"inputs": prompt, "parameters": {"max_new_tokens": 512, "temperature": 0.1, "return_full_text": False}}
117
  response_text = ""
118
-
119
  try:
120
  response = requests.post(LLM_API_URL, headers=headers, json=payload)
121
- response_text = response.text # 先保存原始響應文本
122
  response.raise_for_status()
123
-
124
  generated_text = response.json()[0]['generated_text'].strip()
125
-
126
- # 清理常見的返回格式問題
127
  if "```sql" in generated_text:
128
  generated_text = generated_text.split("```sql")[1].split("```").strip()
129
  if "```" in generated_text:
130
  generated_text = generated_text.replace("```", "").strip()
131
-
132
-
133
  log_message += f"\nLLM 生成成功!"
134
  return generated_text, log_message
135
  except Exception as e:
@@ -140,22 +130,19 @@ Your response must contain ONLY the SQL query. Do not add any explanation.
140
  # --- 4. 創建 Gradio Web 界面 ---
141
  print("--- [5/5] 正在創建 Gradio Web 界面... ---")
142
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
 
143
  gr.Markdown("# 智能 Text-to-SQL 系統 (混合模式)")
144
  gr.Markdown("輸入您的自然語言問題,系統將首先嘗試從知識庫中快速檢索答案。如果問題較新穎,則會調用雲端大語言模型生成SQL。")
145
-
146
  with gr.Row():
147
  question_input = gr.Textbox(label="輸入您的問題", placeholder="例如:去年Nike的總業績是多少?", scale=4)
148
  submit_button = gr.Button("生成SQL", variant="primary", scale=1)
149
-
150
  sql_output = gr.Code(label="生成的 SQL 查詢", language="sql")
151
  log_output = gr.Textbox(label="系統日誌 (執行過程)", lines=4, interactive=False)
152
-
153
  submit_button.click(
154
  fn=get_sql_query,
155
  inputs=question_input,
156
  outputs=[sql_output, log_output]
157
  )
158
-
159
  gr.Examples(
160
  examples=[
161
  "2024 最好的5個客人以及業績",
 
1
+ # 檔案名稱: app.py
2
+ # 部署在 Hugging Face Spaces (已修正 KeyError)
3
+
4
  import gradio as gr
5
  import requests
6
  import json
 
11
  from huggingface_hub import hf_hub_download
12
 
13
  # --- 配置區 ---
 
14
  HF_TOKEN = os.environ.get("HF_TOKEN")
 
15
  DATASET_REPO_ID = "Paul720810/Text-to-SQL-Softline"
 
16
  LLM_API_URL = "https://api-inference.huggingface.co/models/codellama/CodeLlama-34b-Instruct-hf"
 
 
17
  SIMILARITY_THRESHOLD = 0.90
18
 
19
  print("--- [1/5] 開始初始化應用 ---")
 
21
  # --- 1. 載入知識庫 ---
22
  try:
23
  print(f"--- [2/5] 正在從 '{DATASET_REPO_ID}' 載入知識庫... ---")
24
+ # 載入問答範例, 移除已過時的 trust_remote_code 參數
25
+ dataset = load_dataset(DATASET_REPO_ID, token=HF_TOKEN)
26
+ raw_qa_dataset = dataset['train']
27
+
28
+ # *** 關鍵修正:解析被包裹在 'text' 欄位中的 JSON ***
29
+ parsed_qa_data = []
30
+ for item in raw_qa_dataset:
31
+ try:
32
+ # item 現在是 {'text': '{"question": "...", "sql": "..."}'}
33
+ line_dict = json.loads(item['text'])
34
+ parsed_qa_data.append(line_dict)
35
+ except (json.JSONDecodeError, KeyError) as e:
36
+ print(f"警告:跳過一行無法解析的數據: {item}, 錯誤: {e}")
37
+
38
+ # 使用解析後的數據創建一個新的、格式正確的 Dataset 對象
39
+ qa_dataset = Dataset.from_list(parsed_qa_data)
40
 
41
  # 載入並解析 Schema JSON
42
  schema_file_path = "sqlite_schema_FULL.json"
 
45
  with open(schema_file_path, 'r', encoding='utf-8') as f:
46
  schema_data = json.load(f)
47
 
48
+ print(f"--- > 成功解析 {len(qa_dataset)} 條問答範例和 Schema。 ---")
49
  except Exception as e:
50
+ print(f"!!! 致命錯誤: 無法載入或解析 Dataset '{DATASET_REPO_ID}'.")
 
 
51
  print(f"詳細錯誤: {e}")
 
52
  qa_dataset = Dataset.from_dict({"question": ["示例問題"], "sql": ["SELECT 'Dataset failed to load'"]})
53
  schema_data = {}
54
 
 
65
  SCHEMA_DDL = load_schema_as_ddl(schema_data)
66
 
67
  print("--- [3/5] 正在載入句向量模型 (all-MiniLM-L6-v2)... ---")
 
68
  embedder = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
69
 
70
+ # *** 關鍵修正:現在 qa_dataset 的結構是正確的了 ***
71
  questions = [item['question'] for item in qa_dataset]
72
+ sql_answers = [item['sql'] for item in qa_dataset]
73
+
74
  print(f"--- [4/5] 正在為 {len(questions)} 個問題計算向量 (這可能需要幾分鐘)... ---")
 
75
  question_embeddings = embedder.encode(questions, convert_to_tensor=True, show_progress_bar=True)
 
76
  print("--- > 向量計算完成! ---")
77
 
78
  # --- 3. 混合系統核心邏輯 ---
79
  def get_sql_query(user_question: str):
80
+ # (此函式剩餘部分無需修改,保持原樣)
81
  if not user_question:
82
  return "請輸入您的問題。", "日誌:用戶未輸入問題。"
 
 
83
  question_embedding = embedder.encode(user_question, convert_to_tensor=True)
84
  hits = util.semantic_search(question_embedding, question_embeddings, top_k=5)
85
+ hits = hits[0]
 
86
  most_similar_hit = hits[0]
87
  similarity_score = most_similar_hit['score']
 
88
  log_message = f"檢索到最相似問題: '{questions[most_similar_hit['corpus_id']]}' (相似度: {similarity_score:.4f})"
 
 
89
  if similarity_score > SIMILARITY_THRESHOLD:
90
  sql_result = sql_answers[most_similar_hit['corpus_id']]
91
  log_message += f"\n相似度 > {SIMILARITY_THRESHOLD},[模式: 直接返回]。"
92
  return sql_result, log_message
 
 
93
  log_message += f"\n相似度 < {SIMILARITY_THRESHOLD},[模式: LLM生成]。正在構建 Prompt..."
 
 
94
  examples_context = ""
95
+ for hit in hits[:3]:
96
  examples_context += f"### A user asks: {questions[hit['corpus_id']]}\n{sql_answers[hit['corpus_id']]}\n\n"
 
97
  prompt = f"""### Task
98
  Generate a SQLite SQL query that answers the following user question.
99
  Your response must contain ONLY the SQL query. Do not add any explanation.
 
107
 
108
  ### SQL Query
109
  """
 
 
110
  log_message += "\n正在請求雲端 LLM..."
111
  headers = {"Authorization": f"Bearer {HF_TOKEN}"}
112
  payload = {"inputs": prompt, "parameters": {"max_new_tokens": 512, "temperature": 0.1, "return_full_text": False}}
113
  response_text = ""
 
114
  try:
115
  response = requests.post(LLM_API_URL, headers=headers, json=payload)
116
+ response_text = response.text
117
  response.raise_for_status()
 
118
  generated_text = response.json()[0]['generated_text'].strip()
 
 
119
  if "```sql" in generated_text:
120
  generated_text = generated_text.split("```sql")[1].split("```")[0].strip()
121
  if "```" in generated_text:
122
  generated_text = generated_text.replace("```", "").strip()
 
 
123
  log_message += f"\nLLM 生成成功!"
124
  return generated_text, log_message
125
  except Exception as e:
 
130
  # --- 4. 創建 Gradio Web 界面 ---
131
  print("--- [5/5] 正在創建 Gradio Web 界面... ---")
132
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
133
+ # (此部分無需修改,保持原樣)
134
  gr.Markdown("# 智能 Text-to-SQL 系統 (混合模式)")
135
  gr.Markdown("輸入您的自然語言問題,系統將首先嘗試從知識庫中快速檢索答案。如果問題較新穎,則會調用雲端大語言模型生成SQL。")
 
136
  with gr.Row():
137
  question_input = gr.Textbox(label="輸入您的問題", placeholder="例如:去年Nike的總業績是多少?", scale=4)
138
  submit_button = gr.Button("生成SQL", variant="primary", scale=1)
 
139
  sql_output = gr.Code(label="生成的 SQL 查詢", language="sql")
140
  log_output = gr.Textbox(label="系統日誌 (執行過程)", lines=4, interactive=False)
 
141
  submit_button.click(
142
  fn=get_sql_query,
143
  inputs=question_input,
144
  outputs=[sql_output, log_output]
145
  )
 
146
  gr.Examples(
147
  examples=[
148
  "2024 最好的5個客人以及業績",