Paul720810 commited on
Commit
85e2894
·
verified ·
1 Parent(s): 931be3f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -29
app.py CHANGED
@@ -89,41 +89,59 @@ class TextToSQLSystem:
89
  """載入 GGUF 模型並處理錯誤"""
90
  try:
91
  self._log("載入 GGUF 模型...")
 
 
92
  model_path = hf_hub_download(
93
  repo_id=GGUF_REPO_ID,
94
  filename=GGUF_FILENAME,
95
- repo_type="dataset"
 
96
  )
97
 
98
- # 檢查文件完整性
99
- file_size = os.path.getsize(model_path)
100
- expected_size = 986 * 1024 * 1024 # 986MB
101
- if file_size != expected_size:
102
- self._log(f"⚠️ 文件大小不匹配: {file_size} != {expected_size}", "WARNING")
103
- # 重新下載
104
- os.remove(model_path)
105
  model_path = hf_hub_download(
106
  repo_id=GGUF_REPO_ID,
107
  filename=GGUF_FILENAME,
108
  repo_type="dataset",
109
  force_download=True
110
  )
 
 
 
 
111
 
112
- # 使用更兼容的參數
113
  self.llm = Llama(
114
  model_path=model_path,
115
- n_ctx=1024,
116
- n_threads=max(2, os.cpu_count() - 1), # 留一個核心給系統
117
- n_batch=256,
118
- verbose=True, # 開啟詳細日誌
119
- n_gpu_layers=0 # 強制使用CPU
 
 
120
  )
121
- self._log("✅ GGUF 模型載入成功")
 
 
 
 
 
 
122
 
123
  except Exception as e:
124
- self._log(f"❌ GGUF 模型載入失敗: {e}", "ERROR")
125
- self._log("嘗試使用備用載入...")
126
- self._load_gguf_model_fallback(model_path)
 
 
 
 
127
  def _load_gguf_model_fallback(self, model_path):
128
  """備用載入方式"""
129
  try:
@@ -320,29 +338,76 @@ class TextToSQLSystem:
320
  return prompt
321
 
322
  def huggingface_api_call(self, prompt: str) -> str:
323
- """使用 GGUF 模型生成"""
324
  if self.llm is None:
325
- return "模型未載入"
 
326
 
327
  try:
328
- # 檢查prompt長度
329
- if len(prompt) > 1800:
330
- prompt = prompt[:1800] + "..."
331
 
332
  output = self.llm(
333
  prompt,
334
- max_tokens=256,
335
- temperature=0.1,
336
- top_p=0.9,
337
- stop=["</s>", "```", ";", "\n\n"],
338
  echo=False
339
  )
340
- return output["choices"][0]["text"].strip()
341
 
 
 
 
 
 
342
  except Exception as e:
343
  self._log(f"❌ 生成失敗: {e}", "ERROR")
344
- return f"生成失敗: {e}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
345
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
346
  def process_question(self, question: str) -> Tuple[str, str]:
347
  """處理使用者問題"""
348
  # 檢查緩存
 
89
  """載入 GGUF 模型並處理錯誤"""
90
  try:
91
  self._log("載入 GGUF 模型...")
92
+
93
+ # 強制重新下載模型
94
  model_path = hf_hub_download(
95
  repo_id=GGUF_REPO_ID,
96
  filename=GGUF_FILENAME,
97
+ repo_type="dataset",
98
+ force_download=True # 強制重新下載
99
  )
100
 
101
+ # 使用驗證方法檢查檔案
102
+ if not self._validate_model_file(model_path):
103
+ self._log("❌ 模型檔案驗證失敗,嘗試重新下載", "ERROR")
104
+ # 刪除損壞的檔案並重新下載
105
+ if os.path.exists(model_path):
106
+ os.remove(model_path)
 
107
  model_path = hf_hub_download(
108
  repo_id=GGUF_REPO_ID,
109
  filename=GGUF_FILENAME,
110
  repo_type="dataset",
111
  force_download=True
112
  )
113
+
114
+ # 再次驗證
115
+ if not self._validate_model_file(model_path):
116
+ raise ValueError("重新下載後檔案仍然無效")
117
 
118
+ # 使用更保守的參數載入模型
119
  self.llm = Llama(
120
  model_path=model_path,
121
+ n_ctx=512, # 減少上下文長度
122
+ n_threads=4, # 固定線程數
123
+ n_batch=128, # 減少批次大小
124
+ verbose=False, # 關閉詳細輸出
125
+ use_mmap=True, # 使用記憶體映射
126
+ use_mlock=False, # 不鎖定記憶體
127
+ n_gpu_layers=0 # 強制使用 CPU
128
  )
129
+
130
+ # 測試模型是否能正常生成
131
+ test_output = self.llm("SELECT", max_tokens=5, temperature=0.1)
132
+ if not test_output or 'choices' not in test_output:
133
+ raise RuntimeError("模型載入後無法正常生成")
134
+
135
+ self._log("✅ GGUF 模型載入並測試成功")
136
 
137
  except Exception as e:
138
+ self._log(f"❌ GGUF 模型載入失敗: {str(e)}", "ERROR")
139
+ self._log("嘗試使用替代...", "INFO")
140
+ self.llm = None
141
+
142
+ # 可以在這裡添加使用其他模型的邏輯
143
+ # 例如使用 Hugging Face Transformers 的備用方案
144
+
145
  def _load_gguf_model_fallback(self, model_path):
146
  """備用載入方式"""
147
  try:
 
338
  return prompt
339
 
340
  def huggingface_api_call(self, prompt: str) -> str:
341
+ """使用 GGUF 模型生成或提供替代方案"""
342
  if self.llm is None:
343
+ # 返回基於規則的簡單 SQL 生成
344
+ return self._generate_fallback_sql(prompt)
345
 
346
  try:
347
+ if len(prompt) > 1500: # 縮短提示長度
348
+ prompt = prompt[:1500] + "..."
 
349
 
350
  output = self.llm(
351
  prompt,
352
+ max_tokens=128, # 減少最大 token 數
353
+ temperature=0.0, # 使用確定性生成
354
+ top_p=0.95,
355
+ stop=["</s>", "```", "\n\n", "問題:"], # 添加更多停止詞
356
  echo=False
357
  )
 
358
 
359
+ if output and 'choices' in output and output['choices']:
360
+ return output["choices"][0]["text"].strip()
361
+ else:
362
+ return "模型生成失敗"
363
+
364
  except Exception as e:
365
  self._log(f"❌ 生成失敗: {e}", "ERROR")
366
+ return self._generate_fallback_sql(prompt)
367
+
368
+ def _generate_fallback_sql(self, prompt: str) -> str:
369
+ """當模型不可用時的備用 SQL 生成"""
370
+ prompt_lower = prompt.lower()
371
+
372
+ # 簡單的關鍵詞匹配生成基本 SQL
373
+ if "統計" in prompt or "數量" in prompt or "多少" in prompt:
374
+ if "月" in prompt:
375
+ return "SELECT strftime('%Y-%m', completed_time) as month, COUNT(*) as count FROM jobtimeline GROUP BY month ORDER BY month;"
376
+ elif "客戶" in prompt:
377
+ return "SELECT applicant, COUNT(*) as count FROM tsr53sampledescription GROUP BY applicant ORDER BY count DESC;"
378
+ else:
379
+ return "SELECT COUNT(*) as total_count FROM jobtimeline WHERE completed_time IS NOT NULL;"
380
+
381
+ elif "金額" in prompt or "總額" in prompt:
382
+ return "SELECT SUM(amount) as total_amount FROM tsr53invoice;"
383
+
384
+ elif "評級" in prompt or "pass" in prompt_lower or "fail" in prompt_lower:
385
+ return "SELECT rating, COUNT(*) as count FROM tsr53sampledescription GROUP BY rating;"
386
+
387
+ else:
388
+ return "SELECT * FROM jobtimeline LIMIT 10;"
389
 
390
+ def _validate_model_file(self, model_path):
391
+ """驗證模型檔案完整性"""
392
+ try:
393
+ if not os.path.exists(model_path):
394
+ return False
395
+
396
+ # 檢查檔案大小(至少應該有幾MB)
397
+ file_size = os.path.getsize(model_path)
398
+ if file_size < 10 * 1024 * 1024: # 小於 10MB 可能有問題
399
+ return False
400
+
401
+ # 檢查 GGUF 檔案頭部
402
+ with open(model_path, 'rb') as f:
403
+ header = f.read(8)
404
+ if not header.startswith(b'GGUF'):
405
+ return False
406
+
407
+ return True
408
+ except Exception:
409
+ return False
410
+
411
  def process_question(self, question: str) -> Tuple[str, str]:
412
  """處理使用者問題"""
413
  # 檢查緩存