Paul720810 committed on
Commit
9fbce62
·
verified ·
1 Parent(s): b27f7e5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -73
app.py CHANGED
@@ -6,25 +6,29 @@ import torch
6
  import numpy as np
7
  from datetime import datetime
8
  from datasets import load_dataset
9
- from sentence_transformers import SentenceTransformer, util
10
  from huggingface_hub import hf_hub_download
11
  from llama_cpp import Llama
12
  from typing import List, Dict, Tuple, Optional
13
  import faiss
14
  from functools import lru_cache
15
 
 
 
 
 
16
  # ==================== 配置區 ====================
17
  DATASET_REPO_ID = "Paul720810/Text-to-SQL-Softline"
18
  GGUF_REPO_ID = "Paul720810/gguf-models"
19
  GGUF_FILENAME = "qwen2.5-coder-1.5b-sql-finetuned.q4_k_m.gguf"
20
 
21
- FEW_SHOT_EXAMPLES_COUNT = 1 # 只使用1个最相关的范例
22
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
23
 
24
  print("=" * 60)
25
- print("🤖 Text-to-SQL (GGUF) 極速版系統啟動中...")
26
  print(f"📊 數據集: {DATASET_REPO_ID}")
27
- print(f"🤖 GGUF 模型: {GGUF_REPO_ID}/{GGUF_FILENAME}")
28
  print(f"💻 設備: {DEVICE}")
29
  print("=" * 60)
30
 
@@ -59,38 +63,45 @@ def parse_sql_from_response(response_text: str) -> Optional[str]:
59
 
60
  # ==================== Text-to-SQL 核心類 ====================
61
  class TextToSQLSystem:
62
- def __init__(self, embed_model='all-MiniLM-L6-v2'):
63
  self.log_history = []
64
- self._log("初始化極速系統...")
65
  self.query_cache = {}
66
 
67
- # 並行載入所有組件
68
- import threading
69
- self.schema = {}
70
- self.model = None
71
- self.dataset = None
72
- self.corpus_embeddings = None
73
- self.faiss_index = None
74
- self.llm = None
75
-
76
- threads = [
77
- threading.Thread(target=self._load_schema),
78
- threading.Thread(target=self._load_embedding_model),
79
- threading.Thread(target=self._load_gguf_model)
80
- ]
81
-
82
- for t in threads:
83
- t.start()
84
- for t in threads:
85
- t.join()
86
-
87
- self._log("✅ 所有組件載入完成")
 
 
 
 
 
 
 
88
 
89
  def _log(self, message: str, level: str = "INFO"):
90
  self.log_history.append(format_log(message, level))
91
  print(format_log(message, level))
92
 
93
- def _load_schema(self):
94
  """載入數據庫結構"""
95
  try:
96
  schema_path = hf_hub_download(
@@ -99,51 +110,58 @@ class TextToSQLSystem:
99
  repo_type="dataset"
100
  )
101
  with open(schema_path, "r", encoding="utf-8") as f:
102
- self.schema = json.load(f)
103
  self._log("✅ 數據庫結構載入完成")
 
104
  except Exception as e:
105
  self._log(f"❌ 載入 schema 失敗: {e}", "ERROR")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
- def _load_embedding_model(self):
108
- """載入檢索模型和數據"""
109
  try:
110
- self.model = SentenceTransformer('all-MiniLM-L6-v2', device=DEVICE)
111
  dataset = load_dataset(DATASET_REPO_ID, data_files="training_data.jsonl", split="train")
112
- self.dataset = dataset
113
  corpus = [item['messages'][0]['content'] for item in dataset]
114
  self._log(f"正在編碼 {len(corpus)} 個問題...")
115
 
116
- embeddings = self.model.encode(corpus, convert_to_tensor=True, device=DEVICE)
117
- self.corpus_embeddings = embeddings
 
 
 
 
 
 
 
 
 
118
 
119
  # 建立 FAISS 索引
120
- embeddings_np = embeddings.cpu().numpy()
121
- self.faiss_index = faiss.IndexFlatIP(embeddings_np.shape[1])
122
- self.faiss_index.add(embeddings_np)
 
 
123
 
124
- self._log("✅ FAISS 向量索引建立完成")
125
- except Exception as e:
126
- self._log(f"❌ 載入檢索模型失敗: {e}", "ERROR")
127
-
128
- def _load_gguf_model(self):
129
- """載入 GGUF 模型"""
130
- try:
131
- model_path = hf_hub_download(
132
- repo_id=GGUF_REPO_ID,
133
- filename=GGUF_FILENAME,
134
- repo_type="dataset"
135
- )
136
- self.llm = Llama(
137
- model_path=model_path,
138
- n_ctx=1024,
139
- n_threads=os.cpu_count(),
140
- n_batch=512,
141
- n_gpu_layers=0,
142
- verbose=False
143
- )
144
- self._log("✅ GGUF 模型載入完成")
145
  except Exception as e:
146
- self._log(f"❌ 載入 GGUF 模型失敗: {e}", "ERROR")
 
147
 
148
  def _identify_relevant_tables(self, question: str) -> List[str]:
149
  """智能識別問題相關的表"""
@@ -171,13 +189,13 @@ class TextToSQLSystem:
171
  if not self.schema:
172
  return "無數據庫結構信息"
173
 
174
- formatted = "相關表結構:\n"
175
  for table in table_names:
176
  if table in self.schema:
177
- formatted += f"## {table}\n"
178
  for col in self.schema[table][:6]: # 只顯示前6個列
179
  col_desc = col.get('description', '')
180
- formatted += f"- {col['name']} ({col['type']})"
181
  if col_desc:
182
  formatted += f" # {col_desc}"
183
  formatted += "\n"
@@ -185,18 +203,17 @@ class TextToSQLSystem:
185
 
186
  return formatted
187
 
188
- @lru_cache(maxsize=100)
189
  def find_most_similar(self, question: str, top_k: int) -> List[Dict]:
190
  """使用 FAISS 快速檢索相似問題"""
191
  if self.faiss_index is None or self.dataset is None:
192
  return []
193
 
194
  try:
195
- q_emb = self.model.encode(question, convert_to_tensor=True, device=DEVICE)
196
- q_emb_np = q_emb.cpu().numpy().reshape(1, -1)
197
 
198
  # FAISS 搜索
199
- distances, indices = self.faiss_index.search(q_emb_np, min(top_k + 2, len(self.dataset)))
200
 
201
  results = []
202
  seen_questions = set()
@@ -205,6 +222,9 @@ class TextToSQLSystem:
205
  if len(results) >= top_k:
206
  break
207
 
 
 
 
208
  item = self.dataset[idx]
209
  q_content = item['messages'][0]['content']
210
  a_content = item['messages'][1]['content']
@@ -236,15 +256,15 @@ class TextToSQLSystem:
236
  schema_str = self._format_relevant_schema(relevant_tables)
237
 
238
  # 極簡指令
239
- system_instruction = "生成SQL查詢。只輸出```sql...```內容。確保SQL語法正確。"
240
 
241
  # 只顯示一個最有用的範例
242
  ex_str = ""
243
  if examples:
244
  best_example = examples[0]
245
- ex_str = f"參考範例:\n問題: {best_example['question']}\nSQL: ```sql\n{best_example['sql']}\n```\n\n"
246
 
247
- prompt = f"{system_instruction}\n{schema_str}\n{ex_str}問題: {user_q}\nSQL:"
248
 
249
  # 檢查長度,如果太長則進一步精簡
250
  if len(prompt) > 1500:
@@ -295,6 +315,8 @@ class TextToSQLSystem:
295
  # 檢索相似範例
296
  self._log("🔍 尋找相似範例...")
297
  examples = self.find_most_similar(question, FEW_SHOT_EXAMPLES_COUNT)
 
 
298
 
299
  # 建立提示詞
300
  self._log("📝 建立 Prompt...")
@@ -339,9 +361,9 @@ examples = [
339
  "A組昨天完成了多少個測試項目?"
340
  ]
341
 
342
- with gr.Blocks(theme=gr.themes.Soft(), title="Text-to-SQL 極速助手") as demo:
343
- gr.Markdown("# ⚡ Text-to-SQL 極速助手 (GGUF)")
344
- gr.Markdown("輸入自然語言問題,自動生成SQL查詢")
345
 
346
  with gr.Row():
347
  with gr.Column(scale=2):
 
6
  import numpy as np
7
  from datetime import datetime
8
  from datasets import load_dataset
 
9
  from huggingface_hub import hf_hub_download
10
  from llama_cpp import Llama
11
  from typing import List, Dict, Tuple, Optional
12
  import faiss
13
  from functools import lru_cache
14
 
15
+ # 使用 transformers 替代 sentence-transformers
16
+ from transformers import AutoModel, AutoTokenizer
17
+ import torch.nn.functional as F
18
+
19
  # ==================== 配置區 ====================
20
  DATASET_REPO_ID = "Paul720810/Text-to-SQL-Softline"
21
  GGUF_REPO_ID = "Paul720810/gguf-models"
22
  GGUF_FILENAME = "qwen2.5-coder-1.5b-sql-finetuned.q4_k_m.gguf"
23
 
24
+ FEW_SHOT_EXAMPLES_COUNT = 1
25
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
26
+ EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
27
 
28
  print("=" * 60)
29
+ print("🤖 Text-to-SQL 系統啟動中...")
30
  print(f"📊 數據集: {DATASET_REPO_ID}")
31
+ print(f"🤖 嵌入模型: {EMBED_MODEL_NAME}")
32
  print(f"💻 設備: {DEVICE}")
33
  print("=" * 60)
34
 
 
63
 
64
  # ==================== Text-to-SQL 核心類 ====================
65
  class TextToSQLSystem:
66
+ def __init__(self, embed_model_name=EMBED_MODEL_NAME):
67
  self.log_history = []
68
+ self._log("初始化系統...")
69
  self.query_cache = {}
70
 
71
+ # 1. 載入嵌入模型(使用 transformers)
72
+ self._log(f"載入嵌入模型: {embed_model_name}")
73
+ self.embed_tokenizer = AutoTokenizer.from_pretrained(embed_model_name)
74
+ self.embed_model = AutoModel.from_pretrained(embed_model_name)
75
+ if DEVICE == "cuda":
76
+ self.embed_model = self.embed_model.cuda()
77
+
78
+ # 2. 載入數據庫結構
79
+ self.schema = self._load_schema()
80
+
81
+ # 3. 載入數據集並建立索引
82
+ self.dataset, self.faiss_index = self._load_and_index_dataset()
83
+
84
+ # 4. 載入 GGUF 模型
85
+ self._log("載入 GGUF 模型...")
86
+ model_path = hf_hub_download(
87
+ repo_id=GGUF_REPO_ID,
88
+ filename=GGUF_FILENAME,
89
+ repo_type="dataset"
90
+ )
91
+ self.llm = Llama(
92
+ model_path=model_path,
93
+ n_ctx=1024,
94
+ n_threads=os.cpu_count(),
95
+ n_batch=512,
96
+ verbose=False
97
+ )
98
+ self._log("✅ 系統初始化完成")
99
 
100
  def _log(self, message: str, level: str = "INFO"):
101
  self.log_history.append(format_log(message, level))
102
  print(format_log(message, level))
103
 
104
+ def _load_schema(self) -> Dict:
105
  """載入數據庫結構"""
106
  try:
107
  schema_path = hf_hub_download(
 
110
  repo_type="dataset"
111
  )
112
  with open(schema_path, "r", encoding="utf-8") as f:
 
113
  self._log("✅ 數據庫結構載入完成")
114
+ return json.load(f)
115
  except Exception as e:
116
  self._log(f"❌ 載入 schema 失敗: {e}", "ERROR")
117
+ return {}
118
+
119
+ def _encode_texts(self, texts):
120
+ """編碼文本為嵌入向量"""
121
+ if isinstance(texts, str):
122
+ texts = [texts]
123
+
124
+ inputs = self.embed_tokenizer(texts, padding=True, truncation=True,
125
+ return_tensors="pt", max_length=512)
126
+ if DEVICE == "cuda":
127
+ inputs = {k: v.cuda() for k, v in inputs.items()}
128
+
129
+ with torch.no_grad():
130
+ outputs = self.embed_model(**inputs)
131
+
132
+ # 使用平均池化
133
+ embeddings = outputs.last_hidden_state.mean(dim=1)
134
+ return embeddings.cpu()
135
 
136
+ def _load_and_index_dataset(self):
137
+ """載入數據集並建立 FAISS 索引"""
138
  try:
 
139
  dataset = load_dataset(DATASET_REPO_ID, data_files="training_data.jsonl", split="train")
 
140
  corpus = [item['messages'][0]['content'] for item in dataset]
141
  self._log(f"正在編碼 {len(corpus)} 個問題...")
142
 
143
+ # 批量編碼
144
+ embeddings_list = []
145
+ batch_size = 32
146
+
147
+ for i in range(0, len(corpus), batch_size):
148
+ batch_texts = corpus[i:i+batch_size]
149
+ batch_embeddings = self._encode_texts(batch_texts)
150
+ embeddings_list.append(batch_embeddings)
151
+ self._log(f"已編碼 {min(i+batch_size, len(corpus))}/{len(corpus)}")
152
+
153
+ all_embeddings = torch.cat(embeddings_list, dim=0).numpy()
154
 
155
  # 建立 FAISS 索引
156
+ index = faiss.IndexFlatIP(all_embeddings.shape[1])
157
+ index.add(all_embeddings.astype('float32'))
158
+
159
+ self._log("✅ 向量索引建立完成")
160
+ return dataset, index
161
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  except Exception as e:
163
+ self._log(f"❌ 載入數據失敗: {e}", "ERROR")
164
+ return None, None
165
 
166
  def _identify_relevant_tables(self, question: str) -> List[str]:
167
  """智能識別問題相關的表"""
 
189
  if not self.schema:
190
  return "無數據庫結構信息"
191
 
192
+ formatted = "## 相關表結構:\n\n"
193
  for table in table_names:
194
  if table in self.schema:
195
+ formatted += f"### {table}\n"
196
  for col in self.schema[table][:6]: # 只顯示前6個列
197
  col_desc = col.get('description', '')
198
+ formatted += f"- **{col['name']}** ({col['type']})"
199
  if col_desc:
200
  formatted += f" # {col_desc}"
201
  formatted += "\n"
 
203
 
204
  return formatted
205
 
 
206
  def find_most_similar(self, question: str, top_k: int) -> List[Dict]:
207
  """使用 FAISS 快速檢索相似問題"""
208
  if self.faiss_index is None or self.dataset is None:
209
  return []
210
 
211
  try:
212
+ # 編碼問題
213
+ q_embedding = self._encode_texts([question]).numpy().astype('float32')
214
 
215
  # FAISS 搜索
216
+ distances, indices = self.faiss_index.search(q_embedding, min(top_k + 2, len(self.dataset)))
217
 
218
  results = []
219
  seen_questions = set()
 
222
  if len(results) >= top_k:
223
  break
224
 
225
+ if idx >= len(self.dataset): # 確保索引有效
226
+ continue
227
+
228
  item = self.dataset[idx]
229
  q_content = item['messages'][0]['content']
230
  a_content = item['messages'][1]['content']
 
256
  schema_str = self._format_relevant_schema(relevant_tables)
257
 
258
  # 極簡指令
259
+ system_instruction = "你是一位SQL專家。請生成準確的SQLite查詢語句。只輸出```sql...```內容。"
260
 
261
  # 只顯示一個最有用的範例
262
  ex_str = ""
263
  if examples:
264
  best_example = examples[0]
265
+ ex_str = f"## 參考範例:\n問題: {best_example['question']}\nSQL: ```sql\n{best_example['sql']}\n```\n\n"
266
 
267
+ prompt = f"{system_instruction}\n\n{schema_str}\n{ex_str}## 當前問題:\n{user_q}\n\n## SQL查詢:"
268
 
269
  # 檢查長度,如果太長則進一步精簡
270
  if len(prompt) > 1500:
 
315
  # 檢索相似範例
316
  self._log("🔍 尋找相似範例...")
317
  examples = self.find_most_similar(question, FEW_SHOT_EXAMPLES_COUNT)
318
+ if examples:
319
+ self._log(f"✅ 找到 {len(examples)} 個相似範例")
320
 
321
  # 建立提示詞
322
  self._log("📝 建立 Prompt...")
 
361
  "A組昨天完成了多少個測試項目?"
362
  ]
363
 
364
+ with gr.Blocks(theme=gr.themes.Soft(), title="Text-to-SQL 智能助手") as demo:
365
+ gr.Markdown("# ⚡ Text-to-SQL 智能助手")
366
+ gr.Markdown("輸入自然語言問題,自動生成SQL查詢語句")
367
 
368
  with gr.Row():
369
  with gr.Column(scale=2):