Paul720810 commited on
Commit
99cea8f
·
verified ·
1 Parent(s): d254318

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -75
app.py CHANGED
@@ -126,58 +126,35 @@ class TextToSQLSystem:
126
  print(f" - {col['name']} ({col['type']})")
127
  print("=" * 50)
128
 
 
 
129
  def _load_gguf_model(self):
130
- """載入 GGUF 模型,失敗則使用 Transformers 備用方案"""
131
- # 先嘗試原本的 GGUF 載入方式
132
  try:
133
- self._log("載入 GGUF 模型...")
134
  model_path = hf_hub_download(
135
  repo_id=GGUF_REPO_ID,
136
  filename=GGUF_FILENAME,
137
- repo_type="dataset",
138
- force_download=True
139
  )
140
 
141
- # 你原本的載入參數
142
  self.llm = Llama(
143
  model_path=model_path,
144
- n_ctx=1024, # 增加到 1024 2048
145
- n_threads=4,
146
- n_batch=64, # 減少批次大小
147
- verbose=False,
148
- use_mmap=True,
149
- use_mlock=False,
150
- n_gpu_layers=0,
151
- max_tokens=150 # 限制最大生成長度
152
  )
153
 
154
- # 測試是否能正常生成
155
- test_output = self.llm("SELECT", max_tokens=5, temperature=0.1)
156
  self._log("✅ GGUF 模型載入成功")
157
- return
158
 
159
  except Exception as e:
160
  self._log(f"❌ GGUF 載入失敗: {e}", "ERROR")
161
-
162
- # GGUF 失敗,使用 Transformers 載入你的微調模型
163
- try:
164
- self._log("改用 Transformers 載入微調模型...")
165
- from transformers import AutoModelForCausalLM, AutoTokenizer
166
- import torch
167
-
168
- self.transformers_tokenizer = AutoTokenizer.from_pretrained(FINETUNED_MODEL_PATH)
169
- self.transformers_model = AutoModelForCausalLM.from_pretrained(
170
- FINETUNED_MODEL_PATH,
171
- torch_dtype=torch.float32,
172
- device_map="cpu",
173
- trust_remote_code=True
174
- )
175
-
176
- self.llm = "transformers" # 標記使用 transformers
177
- self._log("✅ Transformers 模型載入成功")
178
-
179
- except Exception as e:
180
- self._log(f"❌ Transformers 載入也失敗: {e}", "ERROR")
181
  self.llm = None
182
 
183
  def _try_gguf_loading(self):
@@ -244,53 +221,41 @@ class TextToSQLSystem:
244
  self.llm = None
245
 
246
  def huggingface_api_call(self, prompt: str) -> str:
247
- """使用更嚴格的長度限制"""
248
  if self.llm is None:
 
249
  return self._generate_fallback_sql(prompt)
250
 
251
  try:
252
- # 確保 prompt 不超過限制
253
- if len(prompt) > 600:
254
- prompt = prompt[:600] + "..."
 
 
 
 
 
 
 
255
 
256
- if self.llm == "transformers":
257
- inputs = self.transformers_tokenizer(prompt, return_tensors="pt",
258
- truncation=True, max_length=400) # 減少輸入長度
259
-
260
- with torch.no_grad():
261
- outputs = self.transformers_model.generate(
262
- inputs.input_ids,
263
- attention_mask=inputs.attention_mask,
264
- max_new_tokens=80, # 減少生成長度
265
- temperature=0.1,
266
- do_sample=True,
267
- top_p=0.9,
268
- pad_token_id=self.transformers_tokenizer.eos_token_id,
269
- eos_token_id=self.transformers_tokenizer.eos_token_id
270
- )
271
-
272
- generated_text = self.transformers_tokenizer.decode(
273
- outputs[0][inputs.input_ids.shape[1]:],
274
- skip_special_tokens=True
275
- )
276
-
277
- return generated_text.strip()
278
 
 
 
 
 
 
279
  else:
280
- # GGUF 模型
281
- output = self.llm(
282
- prompt,
283
- max_tokens=100, # 減少最大生成長度
284
- temperature=0.1,
285
- top_p=0.9,
286
- stop=["```", ";", "\n\n", "</s>"],
287
- echo=False
288
- )
289
- return output["choices"][0]["text"].strip()
290
 
291
  except Exception as e:
292
- self._log(f"❌ 生成失敗: {e}", "ERROR")
293
- return self._generate_fallback_sql(prompt)
 
 
294
 
295
  def _load_gguf_model_fallback(self, model_path):
296
  """備用載入方式"""
 
126
  print(f" - {col['name']} ({col['type']})")
127
  print("=" * 50)
128
 
129
+ # in class TextToSQLSystem:
130
+
131
  def _load_gguf_model(self):
132
+ """載入 GGUF 模型,使用更穩定、簡潔的參數"""
 
133
  try:
134
+ self._log("載入 GGUF 模型 (使用穩定性參數)...")
135
  model_path = hf_hub_download(
136
  repo_id=GGUF_REPO_ID,
137
  filename=GGUF_FILENAME,
138
+ repo_type="dataset"
 
139
  )
140
 
141
+ # 使用一組更基礎、更穩定的參數來載入模型
142
  self.llm = Llama(
143
  model_path=model_path,
144
+ n_ctx=2048, # 將上下文增加到 2048 以確保 Prompt 不會超長
145
+ n_threads=4, # 保持 4 線程
146
+ n_batch=512, # 建議值
147
+ verbose=False, # 設為 False 避免 llama.cpp 本身的日誌干擾
148
+ n_gpu_layers=0 # 確認在 CPU 上運行
 
 
 
149
  )
150
 
151
+ # 簡單測試模型是否能回應
152
+ self.llm("你好", max_tokens=3)
153
  self._log("✅ GGUF 模型載入成功")
 
154
 
155
  except Exception as e:
156
  self._log(f"❌ GGUF 載入失敗: {e}", "ERROR")
157
+ self._log("系統將無法生成 SQL。請檢查模型檔案或 llama-cpp-python 安裝。", "CRITICAL")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  self.llm = None
159
 
160
  def _try_gguf_loading(self):
 
221
  self.llm = None
222
 
223
  def huggingface_api_call(self, prompt: str) -> str:
224
+ """調用 GGUF 模型,並加入詳細的原始輸出日誌"""
225
  if self.llm is None:
226
+ self._log("模型未載入,返回 fallback SQL。", "ERROR")
227
  return self._generate_fallback_sql(prompt)
228
 
229
  try:
230
+ # GGUF 模型呼叫
231
+ output = self.llm(
232
+ prompt,
233
+ max_tokens=150, # 給予足夠的生成長度
234
+ temperature=0.1,
235
+ top_p=0.9,
236
+ echo=False,
237
+ # 暫時移除 stop 參數,觀察最原始的輸出
238
+ # stop=["```", ";", "\n\n", "</s>"],
239
+ )
240
 
241
+ # --- 關鍵除錯步驟 ---
242
+ # 印出 llama-cpp-python 返回的完整、原始的 dictionary
243
+ self._log(f"🧠 模型原始輸出 (Raw Output): {output}", "DEBUG")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
 
245
+ if output and "choices" in output and len(output["choices"]) > 0:
246
+ # 從原始輸出中提取文本
247
+ generated_text = output["choices"][0]["text"]
248
+ self._log(f"📝 提取出的生成文本: {generated_text.strip()}", "DEBUG")
249
+ return generated_text.strip()
250
  else:
251
+ self._log("❌ 模型的原始輸出格式不正確或為空。", "ERROR")
252
+ return "" # 返回空字串,讓後續流程處理
 
 
 
 
 
 
 
 
253
 
254
  except Exception as e:
255
+ self._log(f"❌ 模型生成過程中發生嚴重錯誤: {e}", "CRITICAL")
256
+ import traceback
257
+ self._log(traceback.format_exc(), "DEBUG") # 印出詳細的錯誤堆疊
258
+ return "" # 返回空字串
259
 
260
  def _load_gguf_model_fallback(self, model_path):
261
  """備用載入方式"""