Paul720810 committed on
Commit
aa41c39
·
verified ·
1 Parent(s): 1021a18

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +129 -43
app.py CHANGED
@@ -103,6 +103,8 @@ class TextToSQLSystem:
103
  self.log_history = []
104
  self._log("初始化系統...")
105
  self.query_cache = {}
 
 
106
 
107
  # 1. 載入嵌入模型
108
  self._log(f"載入嵌入模型: {embed_model_name}")
@@ -149,7 +151,7 @@ class TextToSQLSystem:
149
  )
150
 
151
  # 使用一組更基礎、更穩定的參數來載入模型
152
- self.llm = Llama(
153
  model_path=model_path,
154
  n_ctx=2048, # 將上下文增加到 2048 以確保 Prompt 不會超長
155
  n_threads=4, # 保持 4 線程
@@ -159,7 +161,8 @@ class TextToSQLSystem:
159
  )
160
 
161
  # 簡單測試模型是否能回應
162
- self.llm("你好", max_tokens=3)
 
163
  self._log("✅ GGUF 模型載入成功")
164
 
165
  except Exception as e:
@@ -176,7 +179,7 @@ class TextToSQLSystem:
176
  repo_type="dataset"
177
  )
178
 
179
- self.llm = Llama(
180
  model_path=model_path,
181
  n_ctx=512,
182
  n_threads=4,
@@ -185,7 +188,7 @@ class TextToSQLSystem:
185
  )
186
 
187
  # 測試生成
188
- test_result = self.llm("SELECT", max_tokens=5)
189
  self._log("✅ GGUF 模型載入成功")
190
  return True
191
 
@@ -223,63 +226,96 @@ class TextToSQLSystem:
223
  pad_token_id=self.transformers_tokenizer.eos_token_id
224
  )
225
 
226
- self.llm = "transformers" # 標記使用 transformers
 
227
  self._log("✅ Transformers 模型載入成功")
228
 
229
  except Exception as e:
230
  self._log(f"❌ Transformers 載入也失敗: {e}", "ERROR")
231
- self.llm = None
232
 
233
  def huggingface_api_call(self, prompt: str) -> str:
234
- """調用 GGUF 模型並加入詳細的原始輸出日誌"""
235
- if self.llm is None:
236
- self._log("模型未載入,返回 fallback SQL。", "ERROR")
237
- return self._generate_fallback_sql(prompt)
238
-
239
- try:
240
- # 重要: 移除 ";" 讓模型可輸出完整查詢(包含結尾分號前所有內容)
241
- output = self.llm(
242
- prompt,
243
- max_tokens=350,
244
- temperature=0.05,
245
- top_p=0.9,
246
- echo=False,
247
- stop=["```", "\n\n", "</s>"]
248
- )
249
-
250
- self._log(f"🧠 模型原始輸出 (Raw Output): {output}", "DEBUG")
251
-
252
-
253
- if output and "choices" in output and len(output["choices"]) > 0:
254
- generated_text = output["choices"][0]["text"]
 
 
 
 
 
 
255
  self._log(f"📝 提取出的生成文本: {generated_text.strip()}", "DEBUG")
256
 
257
- # --- 新增的清理邏輯 ---
258
  lines = generated_text.strip().split('\n')
259
- # 過濾掉所有以 '--' 開頭的註解行
260
  non_comment_lines = [line for line in lines if not line.strip().startswith('--')]
261
  cleaned_text = "\n".join(non_comment_lines).strip()
262
-
263
  if cleaned_text != generated_text.strip():
264
  self._log(f"🧹 清理掉註解後的文本: {cleaned_text}", "DEBUG")
 
 
 
 
 
 
 
265
 
266
- return cleaned_text # <-- 返回清理後的文本
267
- # --- 清理邏輯結束 ---
268
- else:
269
- self._log("❌ 模型的原始輸出格式不正確或為空。", "ERROR")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
  return ""
271
 
272
- except Exception as e:
273
- self._log(f"模型生成過程中發生嚴重錯誤: {e}", "CRITICAL")
274
- import traceback
275
- self._log(traceback.format_exc(), "DEBUG")
276
- return ""
277
 
278
  def _load_gguf_model_fallback(self, model_path):
279
  """備用載入方式"""
280
  try:
281
  # 嘗試不同的參數組合
282
- self.llm = Llama(
283
  model_path=model_path,
284
  n_ctx=512, # 更小的上下文
285
  n_threads=4,
@@ -292,7 +328,7 @@ class TextToSQLSystem:
292
  self._log("✅ 備用方式載入成功")
293
  except Exception as e:
294
  self._log(f"❌ 備用方式也失敗: {e}", "ERROR")
295
- self.llm = None
296
 
297
  def _log(self, message: str, level: str = "INFO"):
298
  self.log_history.append(format_log(message, level))
@@ -548,7 +584,7 @@ class TextToSQLSystem:
548
  }
549
  break
550
 
551
- # ==============================================================================
552
  # 第一層:模組化意圖偵測與動態SQL組合
553
  # ==============================================================================
554
 
@@ -622,6 +658,55 @@ class TextToSQLSystem:
622
  sql_components['where'].append(f"jip.LabGroup = '{db_lab_group}'")
623
  sql_components['log_parts'].append(f"{user_input_group}組(->{db_lab_group})")
624
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
625
  # --- 3. 判斷是否觸發了模板,並動態組合 SQL ---
626
  if 'action' in intents:
627
  sql_components['from'] = "FROM JobTimeline AS jt"
@@ -765,6 +850,7 @@ class TextToSQLSystem:
765
  User question: "{user_q}"
766
  Your single SQLite query response:
767
  ```sql
 
768
  """
769
  self._log(f"📏 Prompt 長度: {len(prompt)} 字符")
770
  # 不再需要複雜的長度截斷邏輯,因為 schema 已經被簡化
@@ -842,7 +928,7 @@ Your single SQLite query response:
842
 
843
  # --- 新增:如果是第二次嘗試,加入修正指令 ---
844
  if attempt > 0:
845
- correction_prompt = "\nYour previous attempt failed because you did not provide a valid SQL query. REMEMBER: ONLY output the SQL code inside a ```sql block. DO NOT write comments or explanations.\nSQL:\n```sql\n"
846
  # 將原本 prompt 的結尾替換成我們的修正指令
847
  prompt = prompt.rsplit("SQL:\n```sql", 1)[0] + correction_prompt
848
 
 
103
  self.log_history = []
104
  self._log("初始化系統...")
105
  self.query_cache = {}
106
+ self.backend = None # 'gguf' | 'transformers' | None
107
+ self.gguf_llm = None # 實際 llama.cpp 物件
108
 
109
  # 1. 載入嵌入模型
110
  self._log(f"載入嵌入模型: {embed_model_name}")
 
151
  )
152
 
153
  # 使用一組更基礎、更穩定的參數來載入模型
154
+ self.gguf_llm = Llama(
155
  model_path=model_path,
156
  n_ctx=2048, # 將上下文增加到 2048 以確保 Prompt 不會超長
157
  n_threads=4, # 保持 4 線程
 
161
  )
162
 
163
  # 簡單測試模型是否能回應
164
+ self.gguf_llm("你好", max_tokens=3)
165
+ self.backend = "gguf"
166
  self._log("✅ GGUF 模型載入成功")
167
 
168
  except Exception as e:
 
179
  repo_type="dataset"
180
  )
181
 
182
+ self.gguf_llm = Llama(
183
  model_path=model_path,
184
  n_ctx=512,
185
  n_threads=4,
 
188
  )
189
 
190
  # 測試生成
191
+ test_result = self.gguf_llm("SELECT", max_tokens=5)
192
  self._log("✅ GGUF 模型載入成功")
193
  return True
194
 
 
226
  pad_token_id=self.transformers_tokenizer.eos_token_id
227
  )
228
 
229
+ # 標記目前後端為 transformers
230
+ self.backend = "transformers"
231
  self._log("✅ Transformers 模型載入成功")
232
 
233
  except Exception as e:
234
  self._log(f"❌ Transformers 載入也失敗: {e}", "ERROR")
 
235
 
236
  def huggingface_api_call(self, prompt: str) -> str:
237
+ """生成 SQL:優先使用 transformers,其次 gguf,最後 fallback"""
238
+ # transformers 後端
239
+ if self.backend == "transformers" and hasattr(self, "generation_pipeline"):
240
+ try:
241
+ gen = self.generation_pipeline(
242
+ prompt,
243
+ max_new_tokens=350,
244
+ do_sample=True,
245
+ temperature=0.05,
246
+ top_p=0.9
247
+ )
248
+ # 盡量從 pipeline 結果提取文字
249
+ generated_text = ""
250
+ try:
251
+ if isinstance(gen, list) and gen:
252
+ first = gen[0]
253
+ if isinstance(first, dict) and "generated_text" in first:
254
+ generated_text = str(first["generated_text"]) # type: ignore[index]
255
+ else:
256
+ generated_text = str(first)
257
+ else:
258
+ generated_text = str(gen)
259
+ except Exception:
260
+ generated_text = str(gen)
261
+ # 若包含 prompt,裁切前綴
262
+ if isinstance(generated_text, str) and generated_text.startswith(prompt):
263
+ generated_text = generated_text[len(prompt):]
264
  self._log(f"📝 提取出的生成文本: {generated_text.strip()}", "DEBUG")
265
 
 
266
  lines = generated_text.strip().split('\n')
 
267
  non_comment_lines = [line for line in lines if not line.strip().startswith('--')]
268
  cleaned_text = "\n".join(non_comment_lines).strip()
 
269
  if cleaned_text != generated_text.strip():
270
  self._log(f"🧹 清理掉註解後的文本: {cleaned_text}", "DEBUG")
271
+ if cleaned_text and not re.match(r"^\s*select\b", cleaned_text, flags=re.IGNORECASE):
272
+ self._log("⚙️ 補上缺失的 'SELECT ' 起手以形成完整查詢", "DEBUG")
273
+ cleaned_text = "SELECT " + cleaned_text.lstrip()
274
+ return cleaned_text
275
+ except Exception as e:
276
+ self._log(f"❌ Transformers 生成失敗: {e}", "ERROR")
277
+ return ""
278
 
279
+ # gguf
280
+ if self.backend == "gguf" and self.gguf_llm is not None and callable(getattr(self.gguf_llm, "__call__", None)):
281
+ try:
282
+ output = self.gguf_llm(
283
+ prompt,
284
+ max_tokens=350,
285
+ temperature=0.05,
286
+ top_p=0.9,
287
+ echo=False,
288
+ stop=["```"]
289
+ )
290
+ self._log(f"🧠 模型原始輸出 (Raw Output): {output}", "DEBUG")
291
+ if output and "choices" in output and len(output["choices"]) > 0:
292
+ generated_text = output["choices"][0]["text"]
293
+ self._log(f"📝 提取出的生成文本: {generated_text.strip()}", "DEBUG")
294
+ lines = str(generated_text).strip().split('\n')
295
+ non_comment_lines = [line for line in lines if not line.strip().startswith('--')]
296
+ cleaned_text = "\n".join(non_comment_lines).strip()
297
+ if cleaned_text != str(generated_text).strip():
298
+ self._log(f"🧹 清理掉註解後的文本: {cleaned_text}", "DEBUG")
299
+ if cleaned_text and not re.match(r"^\s*select\b", cleaned_text, flags=re.IGNORECASE):
300
+ self._log("⚙️ 補上缺失的 'SELECT ' 起手以形成完整查詢", "DEBUG")
301
+ cleaned_text = "SELECT " + cleaned_text.lstrip()
302
+ return cleaned_text
303
+ else:
304
+ self._log("❌ 模型的原始輸出格式不正確或為空。", "ERROR")
305
+ return ""
306
+ except Exception as e:
307
+ self._log(f"❌ GGUF 生成失敗: {e}", "ERROR")
308
  return ""
309
 
310
+ # 後備:都不可用時,回退
311
+ self._log("模型未載入或不可用,返回 fallback SQL。", "ERROR")
312
+ return self._generate_fallback_sql(prompt)
 
 
313
 
314
  def _load_gguf_model_fallback(self, model_path):
315
  """備用載入方式"""
316
  try:
317
  # 嘗試不同的參數組合
318
+ self.gguf_llm = Llama(
319
  model_path=model_path,
320
  n_ctx=512, # 更小的上下文
321
  n_threads=4,
 
328
  self._log("✅ 備用方式載入成功")
329
  except Exception as e:
330
  self._log(f"❌ 備用方式也失敗: {e}", "ERROR")
331
+ self.gguf_llm = None
332
 
333
  def _log(self, message: str, level: str = "INFO"):
334
  self.log_history.append(format_log(message, level))
 
584
  }
585
  break
586
 
587
+ # ==============================================================================
588
  # 第一層:模組化意圖偵測與動態SQL組合
589
  # ==============================================================================
590
 
 
658
  sql_components['where'].append(f"jip.LabGroup = '{db_lab_group}'")
659
  sql_components['log_parts'].append(f"{user_input_group}組(->{db_lab_group})")
660
 
661
+ # --- 2.6: 兩年份比較模板(優先級:高) ---
662
+ # 偵測『比較/vs/對比/相較/相比』字樣,擷取兩個年份與(可選)買家名稱
663
+ compare_hit = any(kw in q_lower for kw in ["比較", "對比", "相較", "相比", "vs", "versus"])
664
+ years_found = re.findall(r"(20\d{2})", question)
665
+ years_unique = []
666
+ for y in years_found:
667
+ if y not in years_unique:
668
+ years_unique.append(y)
669
+ if compare_hit and len(years_unique) >= 2:
670
+ year_a, year_b = years_unique[0], years_unique[1]
671
+ # 嘗試抓買家名稱(英文/數字/符號),若沒有則不加 buyer 條件
672
+ buyer_name = None
673
+ buyer_match = re.search(r"(?:買家|买家|buyer)\s*[::]?\s*([A-Za-z0-9&.\- ]+)", question, re.IGNORECASE)
674
+ if buyer_match:
675
+ buyer_name = buyer_match.group(1).strip()
676
+
677
+ # 判斷偏向金額或件數
678
+ amount_intent = any(kw in q_lower for kw in ["金額", "金钱", "amount", "營收", "業績", "營業額", "銷售額", "revenue"])
679
+
680
+ if amount_intent:
681
+ # 金額版:需要發票表,依架構命名使用 TSR53Invoice 與 LocalAmount;與樣本描述以 JobNo 關聯
682
+ sql = (
683
+ "SELECT strftime('%Y', jt.ReportAuthorization) AS year, "
684
+ "SUM(COALESCE(inv.LocalAmount, 0)) AS total_amount "
685
+ "FROM JobTimeline AS jt "
686
+ "JOIN TSR53SampleDescription AS sd ON sd.JobNo = jt.JobNo "
687
+ "LEFT JOIN TSR53Invoice AS inv ON inv.JobNo = jt.JobNo "
688
+ "WHERE jt.ReportAuthorization IS NOT NULL "
689
+ f"AND strftime('%Y', jt.ReportAuthorization) IN ('{year_a}', '{year_b}') "
690
+ )
691
+ if buyer_name:
692
+ sql += f"AND sd.BuyerName LIKE '%{buyer_name}%' "
693
+ sql += "GROUP BY year ORDER BY year;"
694
+ return self._finalize_sql(sql, f"模板覆寫: 兩年份金額比較 {year_a} vs {year_b}" )
695
+ else:
696
+ # 件數版:以報告數量為主,去重 JobNo
697
+ sql = (
698
+ "SELECT strftime('%Y', jt.ReportAuthorization) AS year, "
699
+ "COUNT(DISTINCT jt.JobNo) AS report_count "
700
+ "FROM JobTimeline AS jt "
701
+ "JOIN TSR53SampleDescription AS sd ON sd.JobNo = jt.JobNo "
702
+ "WHERE jt.ReportAuthorization IS NOT NULL "
703
+ f"AND strftime('%Y', jt.ReportAuthorization) IN ('{year_a}', '{year_b}') "
704
+ )
705
+ if buyer_name:
706
+ sql += f"AND sd.BuyerName LIKE '%{buyer_name}%' "
707
+ sql += "GROUP BY year ORDER BY year;"
708
+ return self._finalize_sql(sql, f"模板覆寫: 兩年份件數比較 {year_a} vs {year_b}" )
709
+
710
  # --- 3. 判斷是否觸發了模板,並動態組合 SQL ---
711
  if 'action' in intents:
712
  sql_components['from'] = "FROM JobTimeline AS jt"
 
850
  User question: "{user_q}"
851
  Your single SQLite query response:
852
  ```sql
853
+ SELECT
854
  """
855
  self._log(f"📏 Prompt 長度: {len(prompt)} 字符")
856
  # 不再需要複雜的長度截斷邏輯,因為 schema 已經被簡化
 
928
 
929
  # --- 新增:如果是第二次嘗試,加入修正指令 ---
930
  if attempt > 0:
931
+ correction_prompt = "\nYour previous attempt failed because you did not provide a valid SQL query. REMEMBER: ONLY output the SQL code inside a ```sql block. DO NOT write comments or explanations.\nSQL:\n```sql\nSELECT "
932
  # 將原本 prompt 的結尾替換成我們的修正指令
933
  prompt = prompt.rsplit("SQL:\n```sql", 1)[0] + correction_prompt
934