Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -63,6 +63,9 @@ print(f"N_BATCH={N_BATCH}")
|
|
| 63 |
print(f"暫存目錄: {TEMP_DIR}")
|
| 64 |
print("=" * 60)
|
| 65 |
|
|
|
|
|
|
|
|
|
|
| 66 |
# ==================== 工具函數 ====================
|
| 67 |
def get_current_time():
|
| 68 |
return datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
|
@@ -136,6 +139,45 @@ def parse_sql_from_response(response_text: str) -> Optional[str]:
|
|
| 136 |
|
| 137 |
return None
|
| 138 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
# ==================== Text-to-SQL 核心類 ====================
|
| 140 |
class TextToSQLSystem:
|
| 141 |
def __init__(self, embed_model_name=EMBED_MODEL_NAME):
|
|
@@ -756,7 +798,10 @@ SELECT
|
|
| 756 |
return None, f"無法解析SQL。原始回應:\n{raw_response}"
|
| 757 |
|
| 758 |
self._log(f"📊 解析出的原始 SQL: {parsed_sql}", "DEBUG")
|
| 759 |
-
|
|
|
|
|
|
|
|
|
|
| 760 |
fixes_applied = []
|
| 761 |
|
| 762 |
# 方言修正
|
|
@@ -783,6 +828,15 @@ SELECT
|
|
| 783 |
fixed_sql = re.sub(pat, correct, fixed_sql, flags=re.IGNORECASE)
|
| 784 |
fixes_applied.append(f"映射 Schema: '{wrong}' -> '{correct}'")
|
| 785 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 786 |
status = "AI 生成並成功修正" if fixes_applied else "AI 生成且無需修正"
|
| 787 |
return self._finalize_sql(fixed_sql, status)
|
| 788 |
|
|
@@ -876,7 +930,12 @@ def process_query(q: str, prompt_override: str = ""):
|
|
| 876 |
return po, "override", logs
|
| 877 |
# 否則當作完整 prompt 丟給 LLM
|
| 878 |
text_to_sql_system._log("使用 prompt_override 直接調用 LLM")
|
| 879 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 880 |
fixed_sql, status_message = text_to_sql_system._validate_and_fix_sql(q or "", response)
|
| 881 |
if not fixed_sql:
|
| 882 |
fixed_sql = text_to_sql_system._generate_fallback_sql(po)
|
|
@@ -927,6 +986,55 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Text-to-SQL 智能助手 (HF Space
|
|
| 927 |
btn.click(process_query, inputs=[inp, prompt_override], outputs=[sql_out, status, logs], api_name="/predict")
|
| 928 |
inp.submit(process_query, inputs=[inp, prompt_override], outputs=[sql_out, status, logs])
|
| 929 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 930 |
if __name__ == "__main__":
|
| 931 |
demo.launch(
|
| 932 |
server_name="0.0.0.0",
|
|
|
|
| 63 |
print(f"暫存目錄: {TEMP_DIR}")
|
| 64 |
print("=" * 60)
|
| 65 |
|
| 66 |
+
# 關閉 Gradio 分析上報,減少不必要的請求與雜訊
|
| 67 |
+
os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "False")
|
| 68 |
+
|
| 69 |
# ==================== 工具函數 ====================
|
| 70 |
def get_current_time():
|
| 71 |
return datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
|
|
|
| 139 |
|
| 140 |
return None
|
| 141 |
|
| 142 |
+
def sanitize_sql(sql_text: str) -> str:
|
| 143 |
+
"""
|
| 144 |
+
將模型輸出清理為更可執行的 SQL:
|
| 145 |
+
- 全形標點轉半形(( ) ; : , 。 等)
|
| 146 |
+
- 過濾清單符號(- 開頭)
|
| 147 |
+
- 僅保留第一個 SELECT 片段直到分號或字串結尾
|
| 148 |
+
- 簡易平衡多餘的右括號
|
| 149 |
+
- 補齊分號
|
| 150 |
+
"""
|
| 151 |
+
if not sql_text:
|
| 152 |
+
return sql_text
|
| 153 |
+
s = sql_text.strip()
|
| 154 |
+
trans = str.maketrans({'(': '(', ')': ')', ';': ';', ':': ':', ',': ',', '。': '.', '【': '(', '】': ')'})
|
| 155 |
+
s = s.translate(trans)
|
| 156 |
+
cleaned_lines = []
|
| 157 |
+
for line in s.splitlines():
|
| 158 |
+
line = line.strip()
|
| 159 |
+
if line.startswith('- '):
|
| 160 |
+
continue
|
| 161 |
+
cleaned_lines.append(line)
|
| 162 |
+
s = ' '.join(cleaned_lines)
|
| 163 |
+
m = re.search(r"(SELECT\s+.*?)(;|$)", s, flags=re.IGNORECASE | re.DOTALL)
|
| 164 |
+
if m:
|
| 165 |
+
s = m.group(1)
|
| 166 |
+
open_cnt, close_cnt = s.count('('), s.count(')')
|
| 167 |
+
if close_cnt > open_cnt:
|
| 168 |
+
excess = close_cnt - open_cnt
|
| 169 |
+
out = []
|
| 170 |
+
for ch in s[::-1]:
|
| 171 |
+
if ch == ')' and excess > 0:
|
| 172 |
+
excess -= 1
|
| 173 |
+
continue
|
| 174 |
+
out.append(ch)
|
| 175 |
+
s = ''.join(out[::-1])
|
| 176 |
+
s = s.rstrip(' .)')
|
| 177 |
+
if s and not s.endswith(';'):
|
| 178 |
+
s += ';'
|
| 179 |
+
return s
|
| 180 |
+
|
| 181 |
# ==================== Text-to-SQL 核心類 ====================
|
| 182 |
class TextToSQLSystem:
|
| 183 |
def __init__(self, embed_model_name=EMBED_MODEL_NAME):
|
|
|
|
| 798 |
return None, f"無法解析SQL。原始回應:\n{raw_response}"
|
| 799 |
|
| 800 |
self._log(f"📊 解析出的原始 SQL: {parsed_sql}", "DEBUG")
|
| 801 |
+
normalized_sql = sanitize_sql(parsed_sql)
|
| 802 |
+
if normalized_sql != parsed_sql:
|
| 803 |
+
self._log(f"🧹 清理後 SQL: {normalized_sql}", "DEBUG")
|
| 804 |
+
fixed_sql = " " + normalized_sql.strip() + " "
|
| 805 |
fixes_applied = []
|
| 806 |
|
| 807 |
# 方言修正
|
|
|
|
| 828 |
fixed_sql = re.sub(pat, correct, fixed_sql, flags=re.IGNORECASE)
|
| 829 |
fixes_applied.append(f"映射 Schema: '{wrong}' -> '{correct}'")
|
| 830 |
|
| 831 |
+
# 若沒有 FROM,補上預設資料來源
|
| 832 |
+
if re.search(r"\bSELECT\b", fixed_sql, re.IGNORECASE) and not re.search(r"\bFROM\b", fixed_sql, re.IGNORECASE):
|
| 833 |
+
if re.search(r"COUNT\s*\(\s*\*\s*\)", fixed_sql, re.IGNORECASE):
|
| 834 |
+
fixed_sql = " SELECT COUNT(DISTINCT jt.JobNo) FROM JobTimeline AS jt WHERE jt.ReportAuthorization IS NOT NULL "
|
| 835 |
+
fixes_applied.append("補上預設 FROM JobTimeline (COUNT 專用)")
|
| 836 |
+
else:
|
| 837 |
+
fixed_sql = " SELECT * FROM JobTimeline AS jt WHERE jt.ReportAuthorization IS NOT NULL "
|
| 838 |
+
fixes_applied.append("補上預設 FROM JobTimeline")
|
| 839 |
+
|
| 840 |
status = "AI 生成並成功修正" if fixes_applied else "AI 生成且無需修正"
|
| 841 |
return self._finalize_sql(fixed_sql, status)
|
| 842 |
|
|
|
|
| 930 |
return po, "override", logs
|
| 931 |
# 否則當作完整 prompt 丟給 LLM
|
| 932 |
text_to_sql_system._log("使用 prompt_override 直接調用 LLM")
|
| 933 |
+
constrained_po = (
|
| 934 |
+
po.rstrip()
|
| 935 |
+
+ "\n\nReturn only the final SQL query in a fenced code block (```sql ... ```). "
|
| 936 |
+
+ "Do not output narration, bullets, or explanations. The SQL must start with SELECT and end with a semicolon."
|
| 937 |
+
)
|
| 938 |
+
response = text_to_sql_system.huggingface_api_call(constrained_po)
|
| 939 |
fixed_sql, status_message = text_to_sql_system._validate_and_fix_sql(q or "", response)
|
| 940 |
if not fixed_sql:
|
| 941 |
fixed_sql = text_to_sql_system._generate_fallback_sql(po)
|
|
|
|
| 986 |
btn.click(process_query, inputs=[inp, prompt_override], outputs=[sql_out, status, logs], api_name="/predict")
|
| 987 |
inp.submit(process_query, inputs=[inp, prompt_override], outputs=[sql_out, status, logs])
|
| 988 |
|
| 989 |
+
# ========== 健康檢查端點 /health ==========
|
| 990 |
+
@demo.add_server_route("/health", methods=["GET"]) # type: ignore[attr-defined]
|
| 991 |
+
def health_endpoint():
|
| 992 |
+
endpoints = []
|
| 993 |
+
try:
|
| 994 |
+
cfg = getattr(demo, "config", None)
|
| 995 |
+
if isinstance(cfg, dict):
|
| 996 |
+
deps = cfg.get("dependencies") or []
|
| 997 |
+
for dep in deps:
|
| 998 |
+
endpoints.append({
|
| 999 |
+
"api_name": dep.get("api_name"),
|
| 1000 |
+
"fn_index": dep.get("fn_index"),
|
| 1001 |
+
"inputs_count": len(dep.get("inputs") or []),
|
| 1002 |
+
"outputs_count": len(dep.get("outputs") or []),
|
| 1003 |
+
})
|
| 1004 |
+
except Exception:
|
| 1005 |
+
pass
|
| 1006 |
+
|
| 1007 |
+
if not endpoints:
|
| 1008 |
+
endpoints.append({
|
| 1009 |
+
"api_name": "/predict",
|
| 1010 |
+
"fn_index": None,
|
| 1011 |
+
"inputs_count": 2,
|
| 1012 |
+
"outputs_count": 3,
|
| 1013 |
+
})
|
| 1014 |
+
|
| 1015 |
+
env_info = {
|
| 1016 |
+
"USE_GPU": USE_GPU,
|
| 1017 |
+
"DEVICE": DEVICE,
|
| 1018 |
+
"N_GPU_LAYERS": N_GPU_LAYERS,
|
| 1019 |
+
"THREADS": THREADS,
|
| 1020 |
+
"CTX": CTX,
|
| 1021 |
+
"MAX_TOKENS": MAX_TOKENS,
|
| 1022 |
+
"FEW_SHOT_EXAMPLES_COUNT": FEW_SHOT_EXAMPLES_COUNT,
|
| 1023 |
+
"ENABLE_INDEX": ENABLE_INDEX,
|
| 1024 |
+
"EMBED_BATCH": EMBED_BATCH,
|
| 1025 |
+
"N_BATCH": N_BATCH,
|
| 1026 |
+
"GGUF_REPO_ID": GGUF_REPO_ID,
|
| 1027 |
+
"GGUF_FILENAME": GGUF_FILENAME,
|
| 1028 |
+
}
|
| 1029 |
+
|
| 1030 |
+
server_info = {
|
| 1031 |
+
"time": get_current_time(),
|
| 1032 |
+
"gradio_version": getattr(gr, "__version__", "unknown"),
|
| 1033 |
+
"pid": os.getpid(),
|
| 1034 |
+
}
|
| 1035 |
+
|
| 1036 |
+
return {"status": "ok", "endpoints": endpoints, "env": env_info, "server": server_info}
|
| 1037 |
+
|
| 1038 |
if __name__ == "__main__":
|
| 1039 |
demo.launch(
|
| 1040 |
server_name="0.0.0.0",
|