Paul720810 commited on
Commit
40f1973
·
verified ·
1 Parent(s): 892956c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -2
app.py CHANGED
@@ -63,6 +63,9 @@ print(f"N_BATCH={N_BATCH}")
63
  print(f"暫存目錄: {TEMP_DIR}")
64
  print("=" * 60)
65
 
 
 
 
66
  # ==================== 工具函數 ====================
67
  def get_current_time():
68
  return datetime.now().strftime('%Y-%m-%d %H:%M:%S')
@@ -136,6 +139,45 @@ def parse_sql_from_response(response_text: str) -> Optional[str]:
136
 
137
  return None
138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  # ==================== Text-to-SQL 核心類 ====================
140
  class TextToSQLSystem:
141
  def __init__(self, embed_model_name=EMBED_MODEL_NAME):
@@ -756,7 +798,10 @@ SELECT
756
  return None, f"無法解析SQL。原始回應:\n{raw_response}"
757
 
758
  self._log(f"📊 解析出的原始 SQL: {parsed_sql}", "DEBUG")
759
- fixed_sql = " " + parsed_sql.strip() + " "
 
 
 
760
  fixes_applied = []
761
 
762
  # 方言修正
@@ -783,6 +828,15 @@ SELECT
783
  fixed_sql = re.sub(pat, correct, fixed_sql, flags=re.IGNORECASE)
784
  fixes_applied.append(f"映射 Schema: '{wrong}' -> '{correct}'")
785
 
 
 
 
 
 
 
 
 
 
786
  status = "AI 生成並成功修正" if fixes_applied else "AI 生成且無需修正"
787
  return self._finalize_sql(fixed_sql, status)
788
 
@@ -876,7 +930,12 @@ def process_query(q: str, prompt_override: str = ""):
876
  return po, "override", logs
877
  # 否則當作完整 prompt 丟給 LLM
878
  text_to_sql_system._log("使用 prompt_override 直接調用 LLM")
879
- response = text_to_sql_system.huggingface_api_call(po)
 
 
 
 
 
880
  fixed_sql, status_message = text_to_sql_system._validate_and_fix_sql(q or "", response)
881
  if not fixed_sql:
882
  fixed_sql = text_to_sql_system._generate_fallback_sql(po)
@@ -927,6 +986,55 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Text-to-SQL 智能助手 (HF Space
927
  btn.click(process_query, inputs=[inp, prompt_override], outputs=[sql_out, status, logs], api_name="/predict")
928
  inp.submit(process_query, inputs=[inp, prompt_override], outputs=[sql_out, status, logs])
929
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
930
  if __name__ == "__main__":
931
  demo.launch(
932
  server_name="0.0.0.0",
 
63
  print(f"暫存目錄: {TEMP_DIR}")
64
  print("=" * 60)
65
 
66
+ # 關閉 Gradio 分析上報,減少不必要的請求與雜訊
67
+ os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "False")
68
+
69
  # ==================== 工具函數 ====================
70
  def get_current_time():
71
  return datetime.now().strftime('%Y-%m-%d %H:%M:%S')
 
139
 
140
  return None
141
 
142
+ def sanitize_sql(sql_text: str) -> str:
143
+ """
144
+ 將模型輸出清理為更可執行的 SQL:
145
+ - 全形標點轉半形(( ) ; : , 。 等)
146
+ - 過濾清單符號(- 開頭)
147
+ - 僅保留第一個 SELECT 片段直到分號或字串結尾
148
+ - 簡易平衡多餘的右括號
149
+ - 補齊分號
150
+ """
151
+ if not sql_text:
152
+ return sql_text
153
+ s = sql_text.strip()
154
+ trans = str.maketrans({'(': '(', ')': ')', ';': ';', ':': ':', ',': ',', '。': '.', '【': '(', '】': ')'})
155
+ s = s.translate(trans)
156
+ cleaned_lines = []
157
+ for line in s.splitlines():
158
+ line = line.strip()
159
+ if line.startswith('- '):
160
+ continue
161
+ cleaned_lines.append(line)
162
+ s = ' '.join(cleaned_lines)
163
+ m = re.search(r"(SELECT\s+.*?)(;|$)", s, flags=re.IGNORECASE | re.DOTALL)
164
+ if m:
165
+ s = m.group(1)
166
+ open_cnt, close_cnt = s.count('('), s.count(')')
167
+ if close_cnt > open_cnt:
168
+ excess = close_cnt - open_cnt
169
+ out = []
170
+ for ch in s[::-1]:
171
+ if ch == ')' and excess > 0:
172
+ excess -= 1
173
+ continue
174
+ out.append(ch)
175
+ s = ''.join(out[::-1])
176
+ s = s.rstrip(' .)')
177
+ if s and not s.endswith(';'):
178
+ s += ';'
179
+ return s
180
+
181
  # ==================== Text-to-SQL 核心類 ====================
182
  class TextToSQLSystem:
183
  def __init__(self, embed_model_name=EMBED_MODEL_NAME):
 
798
  return None, f"無法解析SQL。原始回應:\n{raw_response}"
799
 
800
  self._log(f"📊 解析出的原始 SQL: {parsed_sql}", "DEBUG")
801
+ normalized_sql = sanitize_sql(parsed_sql)
802
+ if normalized_sql != parsed_sql:
803
+ self._log(f"🧹 清理後 SQL: {normalized_sql}", "DEBUG")
804
+ fixed_sql = " " + normalized_sql.strip() + " "
805
  fixes_applied = []
806
 
807
  # 方言修正
 
828
  fixed_sql = re.sub(pat, correct, fixed_sql, flags=re.IGNORECASE)
829
  fixes_applied.append(f"映射 Schema: '{wrong}' -> '{correct}'")
830
 
831
+ # 若沒有 FROM,補上預設資料來源
832
+ if re.search(r"\bSELECT\b", fixed_sql, re.IGNORECASE) and not re.search(r"\bFROM\b", fixed_sql, re.IGNORECASE):
833
+ if re.search(r"COUNT\s*\(\s*\*\s*\)", fixed_sql, re.IGNORECASE):
834
+ fixed_sql = " SELECT COUNT(DISTINCT jt.JobNo) FROM JobTimeline AS jt WHERE jt.ReportAuthorization IS NOT NULL "
835
+ fixes_applied.append("補上預設 FROM JobTimeline (COUNT 專用)")
836
+ else:
837
+ fixed_sql = " SELECT * FROM JobTimeline AS jt WHERE jt.ReportAuthorization IS NOT NULL "
838
+ fixes_applied.append("補上預設 FROM JobTimeline")
839
+
840
  status = "AI 生成並成功修正" if fixes_applied else "AI 生成且無需修正"
841
  return self._finalize_sql(fixed_sql, status)
842
 
 
930
  return po, "override", logs
931
  # 否則當作完整 prompt 丟給 LLM
932
  text_to_sql_system._log("使用 prompt_override 直接調用 LLM")
933
+ constrained_po = (
934
+ po.rstrip()
935
+ + "\n\nReturn only the final SQL query in a fenced code block (```sql ... ```). "
936
+ + "Do not output narration, bullets, or explanations. The SQL must start with SELECT and end with a semicolon."
937
+ )
938
+ response = text_to_sql_system.huggingface_api_call(constrained_po)
939
  fixed_sql, status_message = text_to_sql_system._validate_and_fix_sql(q or "", response)
940
  if not fixed_sql:
941
  fixed_sql = text_to_sql_system._generate_fallback_sql(po)
 
986
  btn.click(process_query, inputs=[inp, prompt_override], outputs=[sql_out, status, logs], api_name="/predict")
987
  inp.submit(process_query, inputs=[inp, prompt_override], outputs=[sql_out, status, logs])
988
 
989
+ # ========== 健康檢查端點 /health ==========
990
+ @demo.add_server_route("/health", methods=["GET"]) # type: ignore[attr-defined]
991
+ def health_endpoint():
992
+ endpoints = []
993
+ try:
994
+ cfg = getattr(demo, "config", None)
995
+ if isinstance(cfg, dict):
996
+ deps = cfg.get("dependencies") or []
997
+ for dep in deps:
998
+ endpoints.append({
999
+ "api_name": dep.get("api_name"),
1000
+ "fn_index": dep.get("fn_index"),
1001
+ "inputs_count": len(dep.get("inputs") or []),
1002
+ "outputs_count": len(dep.get("outputs") or []),
1003
+ })
1004
+ except Exception:
1005
+ pass
1006
+
1007
+ if not endpoints:
1008
+ endpoints.append({
1009
+ "api_name": "/predict",
1010
+ "fn_index": None,
1011
+ "inputs_count": 2,
1012
+ "outputs_count": 3,
1013
+ })
1014
+
1015
+ env_info = {
1016
+ "USE_GPU": USE_GPU,
1017
+ "DEVICE": DEVICE,
1018
+ "N_GPU_LAYERS": N_GPU_LAYERS,
1019
+ "THREADS": THREADS,
1020
+ "CTX": CTX,
1021
+ "MAX_TOKENS": MAX_TOKENS,
1022
+ "FEW_SHOT_EXAMPLES_COUNT": FEW_SHOT_EXAMPLES_COUNT,
1023
+ "ENABLE_INDEX": ENABLE_INDEX,
1024
+ "EMBED_BATCH": EMBED_BATCH,
1025
+ "N_BATCH": N_BATCH,
1026
+ "GGUF_REPO_ID": GGUF_REPO_ID,
1027
+ "GGUF_FILENAME": GGUF_FILENAME,
1028
+ }
1029
+
1030
+ server_info = {
1031
+ "time": get_current_time(),
1032
+ "gradio_version": getattr(gr, "__version__", "unknown"),
1033
+ "pid": os.getpid(),
1034
+ }
1035
+
1036
+ return {"status": "ok", "endpoints": endpoints, "env": env_info, "server": server_info}
1037
+
1038
  if __name__ == "__main__":
1039
  demo.launch(
1040
  server_name="0.0.0.0",