Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from fastapi import FastAPI | |
| import os | |
| import re | |
| import json | |
| import torch | |
| import numpy as np | |
| import psutil | |
| import gc | |
| import tempfile | |
| from datetime import datetime | |
| from datasets import load_dataset | |
| from huggingface_hub import hf_hub_download | |
| from llama_cpp import Llama | |
| from typing import List, Dict, Tuple, Optional | |
| import faiss | |
| from functools import lru_cache | |
| # 使用 transformers 替代 sentence-transformers | |
| from transformers import AutoModel, AutoTokenizer | |
| import torch.nn.functional as F | |
# ==================== Configuration ====================
DATASET_REPO_ID = "Paul720810/Text-to-SQL-Softline"
GGUF_REPO_ID = "Paul720810/gguf-models"
GGUF_FILENAME = "qwen2.5-coder-1.5b-sql-finetuned.q8_0.gguf"
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

# Optional GPU support (the free Hugging Face tier is usually CPU-only).
# USE_GPU accepts 1/true/yes/y (case-insensitive); anything else means False.
USE_GPU = str(os.getenv("USE_GPU", "0")).lower() in ("1", "true", "yes", "y")

# Number of llama.cpp layers to offload to the GPU; unparsable values fall back to 0.
_n_gpu_layers_raw = os.getenv("N_GPU_LAYERS", "0")
try:
    N_GPU_LAYERS = int(_n_gpu_layers_raw)
except Exception:
    N_GPU_LAYERS = 0

# CUDA is used only when it was requested AND is available; otherwise run on CPU.
_cuda_requested_and_present = USE_GPU and torch.cuda.is_available()
DEVICE = "cuda" if _cuda_requested_and_present else "cpu"
| # CPU 專用優化(可由環境變數覆蓋) | |
| def _int_env(name: str, default_val: int) -> int: | |
| try: | |
| return int(os.getenv(name, str(default_val))) | |
| except Exception: | |
| return default_val | |
THREADS = _int_env("THREADS", min(4, os.cpu_count() or 2))  # llama.cpp thread count
_ON_CPU = DEVICE == "cpu"
CTX = _int_env("CTX", 768 if _ON_CPU else 1024)  # context window length
MAX_TOKENS = _int_env("MAX_TOKENS", 60)  # cap on generated tokens
FEW_SHOT_EXAMPLES_COUNT = _int_env("FEW_SHOT", 0 if _ON_CPU else 1)
ENABLE_INDEX = str(os.getenv("ENABLE_INDEX", "0" if _ON_CPU else "1")).lower() in {"1", "true", "yes", "y"}
EMBED_BATCH = _int_env("EMBED_BATCH", 8 if _ON_CPU else 16)
N_BATCH = _int_env("N_BATCH", 128 if _ON_CPU else 256)

# Scratch / cache directory under /tmp.
TEMP_DIR = "/tmp/text_to_sql_cache"
os.makedirs(TEMP_DIR, exist_ok=True)

# Startup banner, one entry per printed line.
_startup_banner = (
    "=" * 60,
    "Text-to-SQL 系統啟動中 (HF 版本)...",
    f"數據集: {DATASET_REPO_ID}",
    f"嵌入模型: {EMBED_MODEL_NAME}",
    f"設備: {DEVICE} (USE_GPU={USE_GPU}, N_GPU_LAYERS={N_GPU_LAYERS})",
    f"THREADS={THREADS}, CTX={CTX}, MAX_TOKENS={MAX_TOKENS}, FEW_SHOT={FEW_SHOT_EXAMPLES_COUNT}, ENABLE_INDEX={ENABLE_INDEX}, EMBED_BATCH={EMBED_BATCH}",
    f"N_BATCH={N_BATCH}",
    f"暫存目錄: {TEMP_DIR}",
    "=" * 60,
)
for _banner_line in _startup_banner:
    print(_banner_line)

# Turn off Gradio analytics reporting to avoid needless requests and noise.
# NOTE(review): gradio is imported before this is set; confirm gradio reads the
# variable lazily (at Blocks creation) rather than at import time.
os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "False")
# ==================== Utility functions ====================
def get_current_time():
    """Return the current local time formatted as ``YYYY-MM-DD HH:MM:SS``."""
    now = datetime.now()
    return f"{now:%Y-%m-%d %H:%M:%S}"
def format_log(message: str, level: str = "INFO") -> str:
    """Build one log line of the form ``[timestamp] [LEVEL] message``.

    The level is upper-cased; the timestamp comes from ``get_current_time``.
    """
    stamp = get_current_time()
    return "[{}] [{}] {}".format(stamp, level.upper(), message)
def check_memory_usage():
    """Return a one-line memory usage summary read from ``/proc/meminfo``.

    psutil is deliberately not used here. ``MemAvailable`` is preferred for the
    "available" figure, with ``MemFree`` as a fallback. The numeric field of each
    line is divided by 1024**2 to report GB (this assumes the kB unit that
    /proc/meminfo uses). If the file is missing (e.g. non-Linux) or cannot be
    parsed, a generic placeholder message is returned instead of raising.
    """
    wanted = ('MemTotal:', 'MemFree:', 'MemAvailable:')
    mem_info = {}
    try:
        with open('/proc/meminfo', 'r') as f:
            for line in f:
                if line.startswith(wanted):
                    key, value = line.split(':', 1)
                    mem_info[key.strip()] = int(value.split()[0])
    except (OSError, ValueError, IndexError):
        # Narrow handler: a bare `except:` would also swallow
        # KeyboardInterrupt / SystemExit.
        return "內存信息: 無法獲取詳細信息"

    total_gb = mem_info.get('MemTotal', 0) / (1024**2)
    available_gb = mem_info.get('MemAvailable', mem_info.get('MemFree', 0)) / (1024**2)
    # Guard against a missing MemTotal line (division by zero).
    used_percent = ((total_gb - available_gb) / total_gb * 100) if total_gb > 0 else 0
    return f"內存使用率: {used_percent:.1f}% (可用: {available_gb:.1f}GB/{total_gb:.1f}GB)"
def parse_sql_from_response(response_text: str) -> Optional[str]:
    """Extract a SQL statement from raw model output.

    Strategies, tried in order (the first hit wins):
      1. the body of a ```sql fenced block;
      2. the body of any ``` fenced block, if it starts with SELECT;
      3. the first ``SELECT ... ;`` statement;
      4. a ``SELECT`` statement without a semicolon. It continues across lines
         that start with whitespace, ``,``, ``(``, ``)`` or a SQL clause keyword,
         and ends at a blank line, a code fence, any other new line, or the end
         of the text. A ``;`` is appended;
      5. the first line starting with ``SELECT`` (``;`` appended).

    Returns:
        The extracted SQL, or None for empty input or when no candidate is found.
    """
    if not response_text:
        return None
    response_text = response_text.strip()

    # 1. ```sql ... ```
    match = re.search(r"```sql\s*\n(.*?)\n```", response_text, re.DOTALL | re.IGNORECASE)
    if match:
        return match.group(1).strip()

    # 2. Any ``` ... ``` block whose content is a SELECT.
    match = re.search(r"```\s*\n?(.*?)\n?```", response_text, re.DOTALL)
    if match:
        sql_candidate = match.group(1).strip()
        if sql_candidate.upper().startswith('SELECT'):
            return sql_candidate

    # 3. A SELECT terminated by a semicolon.
    match = re.search(r"(SELECT\s+.*?;)", response_text, re.DOTALL | re.IGNORECASE)
    if match:
        return match.group(1).strip()

    # 4. A SELECT without a semicolon. A new line continues the statement when it
    #    starts with whitespace/punctuation or a clause keyword. Otherwise it ends
    #    the statement. Stopping at *any* unindented line would cut ordinary
    #    multi-line SQL such as "SELECT a\nFROM t" down to "SELECT a".
    continuation = (
        r"[\s,()]|(?:FROM|WHERE|JOIN|INNER|LEFT|RIGHT|FULL|CROSS|ON|AND|OR|GROUP|ORDER"
        r"|HAVING|LIMIT|OFFSET|UNION|INTERSECT|EXCEPT|CASE|WHEN|THEN|ELSE|END)\b"
    )
    match = re.search(
        r"(SELECT\s+.*?)(?=\n\n|\n```|\Z|\n(?!" + continuation + r"))",
        response_text,
        re.DOTALL | re.IGNORECASE,
    )
    if match:
        sql = match.group(1).strip()
        if not sql.endswith(';'):
            sql += ';'
        return sql

    # 5. Last resort: the first line that starts with SELECT.
    if 'SELECT' in response_text.upper():
        for line in response_text.split('\n'):
            line = line.strip()
            if line.upper().startswith('SELECT'):
                if not line.endswith(';'):
                    line += ';'
                return line
    return None
def sanitize_sql(sql_text: str) -> str:
    """Clean model output into SQL that is more likely to execute.

    Steps:
      - convert full-width punctuation (（ ） ； ： ， 。 and 【 】) to ASCII;
      - drop list-bullet lines that start with "- ";
      - join the remaining lines with spaces and keep only the first
        ``SELECT ...`` fragment, up to a semicolon or the end of the text;
      - remove surplus closing parentheses, rightmost first;
      - strip trailing spaces/periods and ensure a terminating semicolon.

    Empty input ("" or None) is returned unchanged.
    """
    if not sql_text:
        return sql_text

    # Escapes keep the mapping robust against tools that normalise full-width
    # characters; a normalised copy would silently become an identity no-op.
    fullwidth_to_ascii = str.maketrans({
        "\uff08": "(",  # FULLWIDTH LEFT PARENTHESIS
        "\uff09": ")",  # FULLWIDTH RIGHT PARENTHESIS
        "\uff1b": ";",  # FULLWIDTH SEMICOLON
        "\uff1a": ":",  # FULLWIDTH COLON
        "\uff0c": ",",  # FULLWIDTH COMMA
        "\u3002": ".",  # IDEOGRAPHIC FULL STOP
        "\u3010": "(",  # LEFT BLACK LENTICULAR BRACKET
        "\u3011": ")",  # RIGHT BLACK LENTICULAR BRACKET
    })
    s = sql_text.strip().translate(fullwidth_to_ascii)

    # Drop bullet-list lines and flatten to one line.
    kept_lines = [line.strip() for line in s.splitlines() if not line.strip().startswith('- ')]
    s = ' '.join(kept_lines)

    m = re.search(r"(SELECT\s+.*?)(;|$)", s, flags=re.IGNORECASE | re.DOTALL)
    if m:
        s = m.group(1)

    # Remove surplus ')' starting from the right.
    excess = s.count(')') - s.count('(')
    if excess > 0:
        kept_chars = []
        for ch in reversed(s):
            if ch == ')' and excess > 0:
                excess -= 1
                continue
            kept_chars.append(ch)
        s = ''.join(reversed(kept_chars))

    # Parentheses are balanced above, so ')' must NOT be stripped here:
    # doing so breaks valid endings such as "COUNT(*)" or "IN (1, 2)".
    s = s.rstrip(' .')
    if s and not s.endswith(';'):
        s += ';'
    return s
| # ==================== Text-to-SQL 核心類 ==================== | |
| class TextToSQLSystem: | |
| def __init__(self, embed_model_name=EMBED_MODEL_NAME): | |
| self.log_history = [] | |
| self._log("初始化系統...") | |
| self.query_cache = {} | |
| self.embed_device = DEVICE | |
| # 檢查內存狀況 | |
| self._log(check_memory_usage()) | |
| # 1. 嵌入模型(在禁用索引時略過以節省記憶體) | |
| if ENABLE_INDEX: | |
| self._log(f"載入嵌入模型: {embed_model_name}") | |
| self.embed_tokenizer = AutoTokenizer.from_pretrained(embed_model_name) | |
| self.embed_model = AutoModel.from_pretrained(embed_model_name) | |
| try: | |
| self.embed_model.to(self.embed_device) | |
| self._log(f"嵌入模型設備: {self.embed_device}") | |
| except Exception as e: | |
| self._log(f"將嵌入模型移動到設備失敗: {e}", "WARNING") | |
| self.embed_device = "cpu" | |
| else: | |
| self.embed_tokenizer = None | |
| self.embed_model = None | |
| self._log("ENABLE_INDEX=0,略過嵌入模型載入以節省記憶體") | |
| # 2. 載入數據庫結構 | |
| self.schema = self._load_schema() | |
| # 3. 載入數據集並建立索引 | |
| self.dataset, self.faiss_index = self._load_and_index_dataset() | |
| # 4. 載入 GGUF 模型(新增錯誤處理) | |
| self._load_gguf_model() | |
| self._log("系統初始化完成") | |
| def _log(self, message: str, level: str = "INFO"): | |
| self.log_history.append(format_log(message, level)) | |
| print(format_log(message, level)) | |
| def _load_gguf_model(self): | |
| """載入 GGUF 模型,針對 Paperspace 環境優化""" | |
| try: | |
| self._log("開始下載 GGUF 模型到 /tmp...") | |
| # 檢查模型是否已存在於 /tmp | |
| model_cache_path = os.path.join(TEMP_DIR, GGUF_FILENAME) | |
| if os.path.exists(model_cache_path) and self._validate_model_file(model_cache_path): | |
| self._log(f"發現快取模型: {model_cache_path}") | |
| model_path = model_cache_path | |
| else: | |
| self._log("下載新模型...") | |
| model_path = hf_hub_download( | |
| repo_id=GGUF_REPO_ID, | |
| filename=GGUF_FILENAME, | |
| repo_type="dataset", | |
| cache_dir=TEMP_DIR, | |
| resume_download=True | |
| ) | |
| self._log(f"模型下載完成: {model_path}") | |
| # 檢查內存情況 | |
| self._log(check_memory_usage()) | |
| # 使用 CPU 友好的參數載入模型(可選 GPU layers) | |
| ngl = N_GPU_LAYERS if (DEVICE == "cuda" and N_GPU_LAYERS > 0) else 0 | |
| self._log(f"載入 GGUF 模型 (n_gpu_layers={ngl}, n_threads={THREADS}, n_ctx={CTX})...") | |
| self.llm = Llama( | |
| model_path=model_path, | |
| n_ctx=CTX, # 上下文長度(CPU 默認更小) | |
| n_threads=THREADS, # 使用多執行緒 | |
| n_batch=N_BATCH, # 批處理大小(可配置) | |
| verbose=False, | |
| n_gpu_layers=ngl, # 可選 GPU 加速 | |
| use_mmap=True, # 使用內存映射減少內存占用 | |
| use_mlock=False, # 不鎖定內存 | |
| low_vram=True # 啟用低內存模式 | |
| ) | |
| # 簡單測試模型 | |
| test_result = self.llm("SELECT", max_tokens=3) | |
| self._log("GGUF 模型載入成功") | |
| # 再次檢查內存 | |
| self._log(check_memory_usage()) | |
| except Exception as e: | |
| self._log(f"GGUF 載入失敗: {e}", "ERROR") | |
| self._log("系統將無法生成 SQL。請檢查模型檔案或內存情況。", "CRITICAL") | |
| self.llm = None | |
| def _validate_model_file(self, model_path): | |
| """驗證模型檔案完整性""" | |
| try: | |
| if not os.path.exists(model_path): | |
| return False | |
| # 檢查檔案大小(至少應該有幾百MB) | |
| file_size = os.path.getsize(model_path) | |
| if file_size < 50 * 1024 * 1024: # 小於 50MB 可能有問題 | |
| return False | |
| # 檢查 GGUF 檔案頭部 | |
| with open(model_path, 'rb') as f: | |
| header = f.read(8) | |
| if not header.startswith(b'GGUF'): | |
| return False | |
| return True | |
| except Exception: | |
| return False | |
| def huggingface_api_call(self, prompt: str) -> str: | |
| """調用 GGUF 模型,並加入詳細的原始輸出日誌""" | |
| if self.llm is None: | |
| self._log("模型未載入,返回 fallback SQL。", "ERROR") | |
| return self._generate_fallback_sql(prompt) | |
| try: | |
| # 清理垃圾收集 | |
| gc.collect() | |
| start_ts = datetime.now() | |
| output = self.llm( | |
| prompt, | |
| max_tokens=MAX_TOKENS, # 生成長度可配置 | |
| temperature=0.1, | |
| top_p=0.9, | |
| echo=False, | |
| # 避免在分號處截斷 | |
| stop=["```", "\n\n", "</s>"], | |
| ) | |
| elapsed = (datetime.now() - start_ts).total_seconds() | |
| self._log(f"推論耗時: {elapsed:.2f}s", "DEBUG") | |
| self._log(f"模型原始輸出: {str(output)[:200]}...", "DEBUG") | |
| if output and "choices" in output and len(output["choices"]) > 0: | |
| generated_text = output["choices"][0]["text"] | |
| self._log(f"提取出的生成文本: {generated_text.strip()}", "DEBUG") | |
| return generated_text.strip() | |
| else: | |
| self._log("模型的原始輸出格式不正確或為空。", "ERROR") | |
| return "" | |
| except Exception as e: | |
| self._log(f"模型生成過程中發生嚴重錯誤: {e}", "CRITICAL") | |
| import traceback | |
| self._log(traceback.format_exc(), "DEBUG") | |
| return "" | |
| def _load_schema(self) -> Dict: | |
| """載入數據庫結構""" | |
| try: | |
| schema_path = hf_hub_download( | |
| repo_id=DATASET_REPO_ID, | |
| filename="sqlite_schema_FULL.json", | |
| repo_type="dataset", | |
| cache_dir=TEMP_DIR | |
| ) | |
| with open(schema_path, "r", encoding="utf-8") as f: | |
| schema_data = json.load(f) | |
| self._log(f"Schema 載入成功,包含 {len(schema_data)} 個表格:") | |
| for table_name, columns in schema_data.items(): | |
| self._log(f" - {table_name}: {len(columns)} 個欄位") | |
| self._log("數據庫結構載入完成") | |
| return schema_data | |
| except Exception as e: | |
| self._log(f"載入 schema 失敗: {e}", "ERROR") | |
| return {} | |
| def _encode_texts(self, texts): | |
| """編碼文本為嵌入向量""" | |
| if isinstance(texts, str): | |
| texts = [texts] | |
| if (self.embed_model is None) or (self.embed_tokenizer is None): | |
| # 在禁用索引情況下不應被呼叫;保險處理 | |
| self._log("嵌入模型未載入(ENABLE_INDEX=0),_encode_texts 被略過。", "WARNING") | |
| return torch.empty((len(texts), 384)) # 回傳空張量佔位 | |
| inputs = self.embed_tokenizer(texts, padding=True, truncation=True, | |
| return_tensors="pt", max_length=512) | |
| # 移動到對應設備 | |
| try: | |
| inputs = {k: v.to(self.embed_device) for k, v in inputs.items()} | |
| except Exception: | |
| pass | |
| with torch.no_grad(): | |
| outputs = self.embed_model(**inputs) | |
| # 使用平均池化 | |
| embeddings = outputs.last_hidden_state.mean(dim=1) | |
| return embeddings.detach().cpu() | |
| def _load_and_index_dataset(self): | |
| """載入數據集並建立 FAISS 索引""" | |
| try: | |
| if not ENABLE_INDEX: | |
| self._log("已禁用相似範例索引(ENABLE_INDEX=0)。啟動更快,將不使用 few-shot。") | |
| return None, None | |
| dataset = load_dataset( | |
| DATASET_REPO_ID, | |
| data_files="training_data.jsonl", | |
| split="train", | |
| cache_dir=TEMP_DIR | |
| ) | |
| # 過濾不完整樣本 | |
| original_count = len(dataset) | |
| dataset = dataset.filter( | |
| lambda ex: isinstance(ex.get("messages"), list) | |
| and len(ex["messages"]) >= 2 | |
| and all( | |
| isinstance(m.get("content"), str) and m.get("content") and m["content"].strip() | |
| for m in ex["messages"][:2] | |
| ) | |
| ) | |
| self._log(f"資料集清理: 原始 {original_count} 筆, 過濾後 {len(dataset)} 筆") | |
| if len(dataset) == 0: | |
| self._log("清理後資料集為空,無法建立索引。", "ERROR") | |
| return None, None | |
| corpus = [item['messages'][0]['content'] for item in dataset] | |
| self._log(f"正在編碼 {len(corpus)} 個問題...") | |
| # 批量編碼以節省內存 | |
| embeddings_list = [] | |
| batch_size = EMBED_BATCH # 可配置的批次大小(CPU 預設更小) | |
| for i in range(0, len(corpus), batch_size): | |
| batch_texts = corpus[i:i+batch_size] | |
| batch_embeddings = self._encode_texts(batch_texts) | |
| embeddings_list.append(batch_embeddings) | |
| # 清理內存 | |
| if i % (batch_size * 4) == 0: | |
| gc.collect() | |
| self._log(f"已編碼 {min(i+batch_size, len(corpus))}/{len(corpus)}") | |
| all_embeddings = torch.cat(embeddings_list, dim=0).numpy() | |
| # 建立 FAISS 索引 | |
| index = faiss.IndexFlatIP(all_embeddings.shape[1]) | |
| index.add(all_embeddings.astype('float32')) | |
| # 清理內存 | |
| del embeddings_list, all_embeddings | |
| gc.collect() | |
| self._log("向量索引建立完成") | |
| return dataset, index | |
| except Exception as e: | |
| self._log(f"載入數據失敗: {e}", "ERROR") | |
| return None, None | |
| def _identify_relevant_tables(self, question: str) -> List[str]: | |
| """根據實際 Schema 識別相關表格""" | |
| question_lower = question.lower() | |
| relevant_tables = [] | |
| # 根據實際表格的關鍵詞映射 | |
| keyword_to_table = { | |
| 'TSR53SampleDescription': ['客戶', '買方', '申請', '發票對象', 'customer', 'invoice', 'sample'], | |
| 'JobsInProgress': ['進行中', '買家', '申請方', 'buyer', 'applicant', 'progress', '工作狀態'], | |
| 'JobTimeline': ['時間', '完成', '創建', '實驗室', 'timeline', 'creation', 'lab'], | |
| 'TSR53Invoice': ['發票', '金額', '費用', 'invoice', 'credit', 'amount'], | |
| 'JobEventsLog': ['事件', '操作', '用戶', 'event', 'log', 'user'], | |
| 'calendar_days': ['工作日', '假期', 'workday', 'holiday', 'calendar'] | |
| } | |
| for table, keywords in keyword_to_table.items(): | |
| if any(keyword in question_lower for keyword in keywords): | |
| relevant_tables.append(table) | |
| # 預設重要表格 | |
| if not relevant_tables: | |
| if any(word in question_lower for word in ['客戶', '買家', '申請', '工作單', '數量']): | |
| return ['TSR53SampleDescription', 'JobsInProgress'] | |
| else: | |
| return ['JobTimeline', 'TSR53SampleDescription'] | |
| return relevant_tables[:3] # 最多返回3個相關表格 | |
| def _format_relevant_schema(self, table_names: List[str]) -> str: | |
| """生成一個簡化的 Schema 字符串""" | |
| if not self.schema: | |
| return "No schema available.\n" | |
| actual_table_names_map = {name.lower(): name for name in self.schema.keys()} | |
| real_table_names = [] | |
| for table in table_names: | |
| actual_name = actual_table_names_map.get(table.lower()) | |
| if actual_name: | |
| real_table_names.append(actual_name) | |
| elif table in self.schema: | |
| real_table_names.append(table) | |
| if not real_table_names: | |
| self._log("未識別到相關表格,使用預設核心表格。", "WARNING") | |
| real_table_names = ['TSR53SampleDescription', 'JobTimeline', 'JobsInProgress'] | |
| formatted = "" | |
| for table in real_table_names: | |
| if table in self.schema: | |
| formatted += f"Table: {table}\n" | |
| cols_str = [] | |
| # 只顯示前 8 個關鍵欄位以節省內存 | |
| for col in self.schema[table][:8]: | |
| col_name = col['name'] | |
| col_type = col['type'] | |
| cols_str.append(f"{col_name} ({col_type})") | |
| formatted += f"Columns: {', '.join(cols_str)}\n\n" | |
| return formatted.strip() | |
| def find_most_similar(self, question: str, top_k: int) -> List[Dict]: | |
| """使用 FAISS 快速檢索相似問題""" | |
| if self.faiss_index is None or self.dataset is None: | |
| return [] | |
| try: | |
| # 編碼問題 | |
| q_embedding = self._encode_texts([question]).numpy().astype('float32') | |
| # FAISS 搜索 | |
| distances, indices = self.faiss_index.search(q_embedding, min(top_k + 2, len(self.dataset))) | |
| results = [] | |
| seen_questions = set() | |
| for i, idx in enumerate(indices[0]): | |
| if len(results) >= top_k: | |
| break | |
| idx = int(idx) | |
| if idx >= len(self.dataset): | |
| continue | |
| item = self.dataset[idx] | |
| if not isinstance(item.get('messages'), list) or len(item['messages']) < 2: | |
| continue | |
| q_content = (item['messages'][0].get('content') or '').strip() | |
| a_content = (item['messages'][1].get('content') or '').strip() | |
| if not q_content or not a_content: | |
| continue | |
| # 提取純淨問題 | |
| clean_q = re.sub(r"以下是一個SQL查詢任務:\s*指令:\s*", "", q_content).strip() | |
| if clean_q in seen_questions: | |
| continue | |
| seen_questions.add(clean_q) | |
| sql = parse_sql_from_response(a_content) or "無法解析範例SQL" | |
| results.append({ | |
| "similarity": float(distances[0][i]), | |
| "question": clean_q, | |
| "sql": sql | |
| }) | |
| return results | |
| except Exception as e: | |
| self._log(f"檢索失敗: {e}", "ERROR") | |
| return [] | |
| def _build_prompt(self, user_q: str, examples: List[Dict]) -> str: | |
| """建立簡化的提示詞""" | |
| relevant_tables = self._identify_relevant_tables(user_q) | |
| schema_str = self._format_relevant_schema(relevant_tables) | |
| example_str = "No example available." | |
| if examples: | |
| best_example = examples[0] | |
| example_str = f"Question: {best_example['question']}\nSQL:\n```sql\n{best_example['sql']}\n```" | |
| # 簡化的 prompt,減少 token 使用 | |
| prompt = f"""### TASK ### | |
| Generate SQLite query for the question below. | |
| ### SCHEMA ### | |
| {schema_str} | |
| ### EXAMPLE ### | |
| {example_str} | |
| ### QUESTION ### | |
| {user_q} | |
| SQL: | |
| ```sql | |
| SELECT | |
| """ | |
| return prompt | |
| def _rule_based_sql(self, question: str) -> Optional[str]: | |
| """規則先行:對常見查詢用模板直接生成 SQL,繞過 LLM。""" | |
| q = (question or "").strip() | |
| q_lower = q.lower() | |
| # 兩年比較(完成數量、每月) | |
| m = re.search(r"(20\d{2}).{0,6}(?:與|和|跟)\s*(20\d{2}).{0,10}(比較|對比).{0,10}(完成|報告|數量|件|工單)", q) | |
| if m: | |
| y1, y2 = m.group(1), m.group(2) | |
| return ( | |
| "SELECT strftime('%Y-%m', jt.ReportAuthorization) AS month, " | |
| f"COUNT(DISTINCT CASE WHEN strftime('%Y', jt.ReportAuthorization)='{y1}' THEN jt.JobNo END) AS count_{y1}, " | |
| f"COUNT(DISTINCT CASE WHEN strftime('%Y', jt.ReportAuthorization)='{y2}' THEN jt.JobNo END) AS count_{y2} " | |
| "FROM JobTimeline AS jt " | |
| "WHERE jt.ReportAuthorization IS NOT NULL " | |
| f"AND strftime('%Y', jt.ReportAuthorization) IN ('{y1}','{y2}') " | |
| "GROUP BY month ORDER BY month;" | |
| ) | |
| # 指定年份每月完成數量 | |
| m = re.search(r"(20\d{2})年.*每月.*(完成|報告|數量|件|工單)", q) | |
| if m: | |
| year = m.group(1) | |
| return ( | |
| "SELECT strftime('%Y-%m', jt.ReportAuthorization) AS month, COUNT(DISTINCT jt.JobNo) AS count " | |
| "FROM JobTimeline AS jt " | |
| "WHERE jt.ReportAuthorization IS NOT NULL " | |
| f"AND strftime('%Y', jt.ReportAuthorization)='{year}' " | |
| "GROUP BY month ORDER BY month;" | |
| ) | |
| # 評級分布(Pass/Fail) | |
| if ("評級" in q) or ("pass" in q_lower) or ("fail" in q_lower): | |
| return ( | |
| "SELECT sd.OverallRating AS rating, COUNT(*) AS count " | |
| "FROM TSR53SampleDescription AS sd " | |
| "GROUP BY sd.OverallRating;" | |
| ) | |
| # 金額最高 Top N(預設 10) | |
| m = re.search(r"金額.*?(?:最高|前|top)\s*(\d+)?", q_lower) | |
| if m: | |
| n = m.group(1) or "10" | |
| return f"SELECT iv.* FROM TSR53Invoice AS iv ORDER BY iv.LocalAmount DESC LIMIT {n};" | |
| # 客戶工作單數量最多 Top N | |
| m = re.search(r"客戶.*?(?:最多|top|前)\s*(\d+)?", q_lower) | |
| if m: | |
| n = m.group(1) or "10" | |
| return ( | |
| f"SELECT sd.ApplicantName AS applicant, COUNT(DISTINCT jt.JobNo) AS count " | |
| "FROM JobTimeline AS jt " | |
| "JOIN TSR53SampleDescription AS sd ON jt.JobNo = sd.JobNo " | |
| "WHERE jt.ReportAuthorization IS NOT NULL " | |
| "GROUP BY sd.ApplicantName ORDER BY count DESC " | |
| f"LIMIT {n};" | |
| ) | |
| # 昨天完成多少 | |
| if "昨天" in q: | |
| return ( | |
| "SELECT COUNT(DISTINCT jt.JobNo) AS count FROM JobTimeline AS jt " | |
| "WHERE jt.ReportAuthorization IS NOT NULL " | |
| "AND date(jt.ReportAuthorization)=date('now','-1 day');" | |
| ) | |
| return None | |
| def _finalize_sql(self, sql_text: str, status: str) -> Tuple[str, str]: | |
| """最終整理 SQL:補分號、去除多餘空白並回傳 (sql, 狀態)。""" | |
| try: | |
| sql_clean = (sql_text or "").strip() | |
| if sql_clean and not sql_clean.endswith(";"): | |
| sql_clean += ";" | |
| return sql_clean, status | |
| except Exception as e: | |
| self._log(f"最終整理 SQL 失敗: {e}", "ERROR") | |
| return (sql_text or ""), status | |
| def _regenerate_sql_strict(self, question: str) -> Optional[str]: | |
| """當模型輸出非 SQL 或無法解析時,使用嚴格限制的提示詞重生一次。""" | |
| try: | |
| rel = self._identify_relevant_tables(question) | |
| schema_str = self._format_relevant_schema(rel) | |
| strict_prompt = ( | |
| "You are a SQLite SQL generator.\n" | |
| + "Given the schema below and the question, output ONE valid SQL query only.\n\n" | |
| + "SCHEMA:\n" + schema_str + "\n\n" | |
| + "QUESTION:\n" + (question or "").strip() + "\n\n" | |
| + "Return only the final SQL query in a fenced code block (```sql ... ```). " | |
| + "The SQL must start with SELECT and end with a semicolon. No explanation." | |
| ) | |
| raw = self.huggingface_api_call(strict_prompt) | |
| sql = parse_sql_from_response(raw) | |
| if sql: | |
| self._log("🔁 嚴格模式重生成功。") | |
| return sql | |
| except Exception as e: | |
| self._log(f"嚴格模式重生失敗: {e}", "ERROR") | |
| return None | |
| def _validate_and_fix_sql(self, question: str, raw_response: str) -> Tuple[Optional[str], str]: | |
| """ | |
| (V29 / 穩健正則 + 智能計數) 多層次 SQL 生成: | |
| 1) 嘗試規則/模板動態組合 | |
| 2) 失敗則解析 AI 輸出並做方言/Schema 修正 | |
| 回傳: (sql 或 None, 狀態描述) | |
| """ | |
| q = question or "" | |
| q_lower = q.lower() | |
| # 先嘗試內建的規則先行器 | |
| rb = self._rule_based_sql(q) | |
| if rb: | |
| self._log("_validate_and_fix_sql 命中規則模板") | |
| return self._finalize_sql(rb, "規則生成") | |
| # 統一實體識別(簡化版) | |
| entity_match_data = None | |
| entity_patterns = [ | |
| {'pattern': r"(買家|买家|buyer)\s*(?:id|代號|代碼|代号|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.BuyerID', 'type': '買家ID'}, | |
| {'pattern': r"(申請方|申请方|申請廠商|申请厂商|applicant)\s*(?:id|代號|代碼|代号|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.ApplicantID', 'type': '申請方ID'}, | |
| {'pattern': r"(付款方|付款厂商|invoiceto)\s*(?:id|代號|代碼|代号|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.InvoiceToID', 'type': '付款方ID'}, | |
| {'pattern': r"(代理商|agent)\s*(?:id|代號|代碼|代号|代码)\s*'\"?\b([A-Z]\d{4}[A-Z])\b'\"?", 'column': 'sd.AgentID', 'type': '代理商ID'}, | |
| {'pattern': r"(買家|买家|buyer|客戶)\s+([a-zA-Z0-9&.-]+)", 'column': 'sd.BuyerName', 'type': '買家'}, | |
| {'pattern': r"(申請方|申请方|申請廠商|申请厂商|applicant)\s+([a-zA-Z0-9&.-]+)", 'column': 'sd.ApplicantName', 'type': '申請方'}, | |
| {'pattern': r"(付款方|付款厂商|invoiceto)\s+([a-zA-Z0-9&.-]+)", 'column': 'sd.InvoiceToName', 'type': '付款方'}, | |
| {'pattern': r"(代理商|agent)\s+([a-zA-Z0-9&.-]+)", 'column': 'sd.AgentName', 'type': '代理商'}, | |
| {'pattern': r"\b([A-Z]\d{4}[A-Z])\b", 'column': 'sd.ApplicantID', 'type': 'ID'} | |
| ] | |
| for p in entity_patterns: | |
| m = re.search(p['pattern'], q, re.IGNORECASE) | |
| if m: | |
| entity_value = m.group(2) if len(m.groups()) > 1 else m.group(1) | |
| entity_match_data = {"type": p['type'], "name": entity_value.strip().upper(), "column": p['column']} | |
| break | |
| # 模組化意圖偵測與動態 SQL 組合 | |
| intents: Dict[str, str] = {} | |
| sql = { | |
| 'select': [], 'from': '', 'joins': [], 'where': [], | |
| 'group_by': [], 'order_by': [], 'log_parts': [] | |
| } | |
| # 先處理多年份比較:如 "2021 與 2022 比較"、"2021年跟2022年對比" | |
| years = re.findall(r"(20\d{2})\s*年?", q) | |
| is_compare = re.search(r"比較|對比|對照|compare|versus|vs\.?", q) | |
| if len(set(years)) >= 2 and is_compare: | |
| ys = sorted(set(years))[:4] | |
| want_items = ("測試項目" in q) or ("item" in q_lower) | |
| select_expr = "COUNT(jip.ItemCode) AS item_count" if want_items else "COUNT(DISTINCT jt.JobNo) AS report_count" | |
| join_items = "JOIN JobItemsInProgress AS jip ON jt.JobNo = jip.JobNo" if want_items else "" | |
| where_years = " OR ".join([f"strftime('%Y', jt.ReportAuthorization) = '{y}'" for y in ys]) | |
| template = f"SELECT strftime('%Y', jt.ReportAuthorization) AS 年份, {select_expr} FROM JobTimeline AS jt {join_items} WHERE jt.ReportAuthorization IS NOT NULL AND ({where_years}) GROUP BY 年份 ORDER BY 年份;" | |
| self._log(f"🔄 多年份比較模板: years={','.join(ys)} items={want_items}") | |
| return self._finalize_sql(template, f"模板覆寫: {','.join(ys)} 年比較") | |
| # 動作意圖:count / list | |
| if any(kw in q_lower for kw in ['幾份', '份數', '份数', '多少', '數量', '總數', 'how many', 'count']): | |
| intents['action'] = 'count' | |
| if ("測試項目" in q) or ("test item" in q_lower): | |
| sql['select'].append("COUNT(jip.ItemCode) AS item_count") | |
| sql['log_parts'].append("測試項目總數") | |
| else: | |
| sql['select'].append("COUNT(DISTINCT jt.JobNo) AS report_count") | |
| sql['log_parts'].append("報告總數") | |
| elif any(kw in q_lower for kw in ['報告號碼', '報告清單', '列出報告', 'report number', 'list of reports']): | |
| intents['action'] = 'list' | |
| sql['select'].append("jt.JobNo, jt.ReportAuthorization") | |
| sql['order_by'].append("jt.ReportAuthorization DESC") | |
| sql['log_parts'].append("報告列表") | |
| # 時間意圖:年/月 | |
| ym = re.search(r'(\d{4})\s*年?', q) | |
| mm = re.search(r'(\d{1,2})\s*月', q) | |
| if ym: | |
| year = ym.group(1) | |
| sql['where'].append(f"strftime('%Y', jt.ReportAuthorization) = '{year}'") | |
| sql['log_parts'].append(f"{year}年") | |
| if mm: | |
| month = mm.group(1).zfill(2) | |
| sql['where'].append(f"strftime('%m', jt.ReportAuthorization) = '{month}'") | |
| sql['log_parts'].append(f"{month}月") | |
| # 實體意圖 | |
| if entity_match_data: | |
| if "TSR53SampleDescription" not in " ".join(sql['joins']): | |
| sql['joins'].append("JOIN TSR53SampleDescription AS sd ON jt.JobNo = sd.JobNo") | |
| entity_name, column_name = entity_match_data['name'], entity_match_data['column'] | |
| match_op = '=' if column_name.endswith('ID') else 'LIKE' | |
| entity_val = f"'%{entity_name}%'" if match_op == 'LIKE' else f"'{entity_name}'" | |
| collate = " COLLATE NOCASE" if match_op == 'LIKE' else "" | |
| sql['where'].append(f"{column_name} {match_op} {entity_val}{collate}") | |
| sql['log_parts'].append(entity_match_data['type'] + ":" + entity_name) | |
| if intents.get('action') == 'list': | |
| sql['select'].append("sd.BuyerName") | |
| # 評級意圖 | |
| if ('fail' in q_lower) or ('失敗' in q_lower): | |
| if "TSR53SampleDescription" not in " ".join(sql['joins']): | |
| sql['joins'].append("JOIN TSR53SampleDescription AS sd ON jt.JobNo = sd.JobNo") | |
| sql['where'].append("sd.OverallRating = 'Fail'") | |
| sql['log_parts'].append("Fail") | |
| elif ('pass' in q_lower) or ('通過' in q_lower): | |
| if "TSR53SampleDescription" not in " ".join(sql['joins']): | |
| sql['joins'].append("JOIN TSR53SampleDescription AS sd ON jt.JobNo = sd.JobNo") | |
| sql['where'].append("sd.OverallRating = 'Pass'") | |
| sql['log_parts'].append("Pass") | |
| # 實驗組 (LabGroup) | |
| lab_group_mapping = {'A': 'TA', 'B': 'TB', 'C': 'TC', 'D': 'TD', 'E': 'TE', 'Y': 'TY'} | |
| lgm = re.search(r'([A-Z]{1,2})組', q, re.IGNORECASE) | |
| if lgm: | |
| user_group = lgm.group(1).upper() | |
| db_group = lab_group_mapping.get(user_group, user_group) | |
| sql['joins'].append("JOIN JobItemsInProgress AS jip ON jt.JobNo = jip.JobNo") | |
| sql['where'].append(f"jip.LabGroup = '{db_group}'") | |
| sql['log_parts'].append(f"{user_group}組(->{db_group})") | |
| # 若動作已決定,組裝模板 SQL | |
| if 'action' in intents: | |
| sql['from'] = "FROM JobTimeline AS jt" | |
| if sql['where']: | |
| sql['where'].insert(0, "jt.ReportAuthorization IS NOT NULL") | |
| select_clause = "SELECT " + ", ".join(sorted(list(set(sql['select'])))) if sql['select'] else "SELECT *" | |
| from_clause = sql['from'] | |
| joins_clause = " ".join(sql['joins']) | |
| where_clause = ("WHERE " + " AND ".join(sql['where'])) if sql['where'] else "" | |
| orderby_clause = ("ORDER BY " + ", ".join(sql['order_by'])) if sql['order_by'] else "" | |
| template_sql = f"{select_clause} {from_clause} {joins_clause} {where_clause} {orderby_clause};" | |
| query_log = " ".join(sql['log_parts']) | |
| self._log(f"🔄 偵測到組合意圖【{query_log}】,啟用動態模板。") | |
| return self._finalize_sql(template_sql, f"模板覆寫: {query_log} 查詢") | |
| # 第二層:解析 AI 輸出並修正 | |
| self._log("未觸發任何模板,嘗試解析並修正 AI 輸出…") | |
| parsed_sql = parse_sql_from_response(raw_response) | |
| if not parsed_sql: | |
| # 嘗試救援:模型可能省略了開頭的 SELECT(因為 Prompt 已種子 SELECT) | |
| resp = (raw_response or '').strip() | |
| if resp and not resp.upper().startswith('SELECT') and re.search(r'\bFROM\b', resp, re.IGNORECASE): | |
| self._log("嘗試自動補上 SELECT 以修復不完整輸出", "INFO") | |
| salvage_sql = 'SELECT ' + resp | |
| parsed_sql = parse_sql_from_response(salvage_sql) or salvage_sql | |
| if not parsed_sql: | |
| self._log("模型輸出非 SQL,啟用嚴格模式重生一次…") | |
| parsed_sql = self._regenerate_sql_strict(q) | |
| if not parsed_sql: | |
| self._log(f"❌ 未能從模型回應中解析出任何 SQL。原始回應: {raw_response}", "ERROR") | |
| return None, f"無法解析SQL。原始回應:\n{raw_response}" | |
| self._log(f"📊 解析出的原始 SQL: {parsed_sql}", "DEBUG") | |
| normalized_sql = sanitize_sql(parsed_sql) | |
| if normalized_sql != parsed_sql: | |
| self._log(f"🧹 清理後 SQL: {normalized_sql}", "DEBUG") | |
| fixed_sql = " " + normalized_sql.strip() + " " | |
| fixes_applied = [] | |
| # 方言修正 | |
| dialect_corrections = {r'YEAR\s*\(([^)]+)\)': r"strftime('%Y', \1)"} | |
| for pat, rep in dialect_corrections.items(): | |
| if re.search(pat, fixed_sql, re.IGNORECASE): | |
| fixed_sql = re.sub(pat, rep, fixed_sql, flags=re.IGNORECASE) | |
| fixes_applied.append(f"修正方言: {pat}") | |
| # Schema 名稱修正(常見別名 => 真實欄位) | |
| schema_map = { | |
| 'TSR53Report':'TSR53SampleDescription', | |
| 'TSR53InvoiceReportNo':'JobNo', | |
| 'TSR53ReportNo':'JobNo', | |
| 'TSR53InvoiceNo':'JobNo', | |
| 'TSR53InvoiceCreditNoteNo':'InvoiceCreditNoteNo', | |
| 'TSR53InvoiceLocalAmount':'LocalAmount', | |
| 'Status':'OverallRating', | |
| 'ReportStatus':'OverallRating' | |
| } | |
| for wrong, correct in schema_map.items(): | |
| pat = r'\b' + re.escape(wrong) + r'\b' | |
| if re.search(pat, fixed_sql, re.IGNORECASE): | |
| fixed_sql = re.sub(pat, correct, fixed_sql, flags=re.IGNORECASE) | |
| fixes_applied.append(f"映射 Schema: '{wrong}' -> '{correct}'") | |
| # 若沒有 FROM,補上預設資料來源 | |
| if re.search(r"\bSELECT\b", fixed_sql, re.IGNORECASE) and not re.search(r"\bFROM\b", fixed_sql, re.IGNORECASE): | |
| if re.search(r"COUNT\s*\(\s*\*\s*\)", fixed_sql, re.IGNORECASE): | |
| fixed_sql = " SELECT COUNT(DISTINCT jt.JobNo) FROM JobTimeline AS jt WHERE jt.ReportAuthorization IS NOT NULL " | |
| fixes_applied.append("補上預設 FROM JobTimeline (COUNT 專用)") | |
| else: | |
| fixed_sql = " SELECT * FROM JobTimeline AS jt WHERE jt.ReportAuthorization IS NOT NULL " | |
| fixes_applied.append("補上預設 FROM JobTimeline") | |
| status = "AI 生成並成功修正" if fixes_applied else "AI 生成且無需修正" | |
| return self._finalize_sql(fixed_sql, status) | |
| def _generate_fallback_sql(self, prompt: str) -> str: | |
| """當模型不可用時的備用 SQL 生成""" | |
| prompt_lower = (prompt or "").lower() | |
| # 統計類:優先使用 JobTimeline.ReportAuthorization,避免不存在的 completed_time 欄位 | |
| if ("統計" in prompt) or ("數量" in prompt) or ("多少" in prompt) or ("count" in prompt_lower): | |
| if ("月" in prompt) or ("per month" in prompt_lower) or ("monthly" in prompt_lower): | |
| return ( | |
| "SELECT strftime('%Y-%m', jt.ReportAuthorization) AS month, " | |
| "COUNT(DISTINCT jt.JobNo) AS count " | |
| "FROM JobTimeline AS jt " | |
| "WHERE jt.ReportAuthorization IS NOT NULL " | |
| "GROUP BY month ORDER BY month;" | |
| ) | |
| elif ("客戶" in prompt) or ("buyer" in prompt_lower) or ("applicant" in prompt_lower): | |
| return ( | |
| "SELECT sd.ApplicantName AS applicant, COUNT(DISTINCT jt.JobNo) AS count " | |
| "FROM JobTimeline AS jt " | |
| "JOIN TSR53SampleDescription AS sd ON jt.JobNo = sd.JobNo " | |
| "WHERE jt.ReportAuthorization IS NOT NULL " | |
| "GROUP BY sd.ApplicantName ORDER BY count DESC;" | |
| ) | |
| else: | |
| return ( | |
| "SELECT COUNT(DISTINCT jt.JobNo) AS total_count " | |
| "FROM JobTimeline AS jt " | |
| "WHERE jt.ReportAuthorization IS NOT NULL;" | |
| ) | |
| # 金額彙總 | |
| if ("金額" in prompt) or ("總額" in prompt) or ("amount" in prompt_lower) or ("sum" in prompt_lower): | |
| return "SELECT SUM(LocalAmount) AS total_amount FROM TSR53Invoice;" | |
| # 評級分布 | |
| if ("評級" in prompt) or ("rating" in prompt_lower) or ("pass" in prompt_lower) or ("fail" in prompt_lower): | |
| return "SELECT OverallRating AS rating, COUNT(*) AS count FROM TSR53SampleDescription GROUP BY OverallRating;" | |
| # 通用後備:最近 10 筆報告 | |
| return ( | |
| "SELECT jt.JobNo, jt.ReportAuthorization " | |
| "FROM JobTimeline AS jt " | |
| "WHERE jt.ReportAuthorization IS NOT NULL " | |
| "ORDER BY jt.ReportAuthorization DESC LIMIT 10;" | |
| ) | |
def process_question(self, question: str) -> Tuple[str, str]:
    """Convert a natural-language question into ``(sql, status)``.

    Pipeline:
        cache lookup -> rule-based SQL (a match skips the LLM) ->
        similar-example retrieval -> prompt building -> LLM call ->
        SQL validation/fixing.

    Successful results are cached keyed by the exact question string.

    Args:
        question: The user's question.

    Returns:
        ``(sql, status_message)``. When no valid SQL is produced, ``sql`` is a
        placeholder ``SELECT`` carrying a message for the user. That
        placeholder is NOT cached, so asking again retries generation.
    """
    # Start a fresh per-request log before anything is logged. Previously a
    # cache hit appended to the previous request's log (which the UI then
    # showed as this request's log), and the list grew on every repeated hit.
    self.log_history = []

    # Cache hit: return the stored result directly.
    if question in self.query_cache:
        self._log("使用緩存結果")
        return self.query_cache[question]

    self._log(f"處理問題: {question}")
    self._log(check_memory_usage())

    # 0. Rule-based generation first; a match returns without calling the LLM.
    rule_sql = self._rule_based_sql(question)
    if rule_sql:
        self._log("規則命中,直接生成 SQL(跳過 LLM)")
        self._log(f"最終 SQL: {rule_sql}")
        result = (rule_sql, "規則生成")
        self.query_cache[question] = result
        gc.collect()
        return result

    # 1. Retrieve similar few-shot examples.
    self._log("尋找相似範例...")
    examples = self.find_most_similar(question, FEW_SHOT_EXAMPLES_COUNT)
    if examples:
        self._log(f"找到 {len(examples)} 個相似範例")

    # 2. Build the prompt.
    self._log("建立 Prompt...")
    prompt = self._build_prompt(question, examples)

    # 3. Generate the model response.
    self._log("開始生成 AI 回應...")
    response = self.huggingface_api_call(prompt)

    # 4. Validate / repair the SQL.
    fixed_sql, status_message = self._validate_and_fix_sql(question, response)
    generated = bool(fixed_sql)
    if not generated:
        fixed_sql = "SELECT '未能生成有效的SQL,請嘗試換個問題描述';"
        status_message = status_message or "生成失敗"

    self._log(f"最終 SQL: {fixed_sql}")
    result = (fixed_sql, status_message)

    # Cache only real results. Caching the failure placeholder would make
    # every retry of this question return the failure without calling the model.
    if generated:
        self.query_cache[question] = result

    # Release memory between requests.
    gc.collect()
    return result
# ==================== Gradio UI and API ====================
# Create the single shared system instance at import time. process_query below
# reads and uses this module-level object. Whatever TextToSQLSystem.__init__
# loads (defined elsewhere in this file) happens here, before the UI is built.
print("正在初始化 Text-to-SQL 系統...")
text_to_sql_system = TextToSQLSystem()
def process_query(q: str, prompt_override: str = ""):
    """Gradio/API handler: produce SQL for a question or a raw prompt override.

    Args:
        q: Natural-language question from the textbox. May be empty or None
            when an API caller only supplies ``prompt_override``.
        prompt_override: Optional hidden input for desktop clients.
            - If it starts with ``SELECT`` (case-insensitive), it is returned
              as the SQL, with ``;`` appended if missing.
            - Otherwise it is sent to the LLM as the complete prompt, with
              output-format instructions appended.

    Returns:
        ``(sql, status, logs)`` for the three Gradio outputs. ``logs`` holds
        the last 15 log lines of this request.
    """
    question = q or ""
    override = (prompt_override or "").strip()

    # Reject only when BOTH inputs are blank. The former check
    # `(q or prompt_override).strip()` let a whitespace-only `q` shadow a real
    # override (rejecting the request) and raised AttributeError when both
    # inputs were None.
    if not question.strip() and not override:
        return "", "等待輸入", "請輸入問題或提供 prompt_override"

    if override:
        # Fresh log for this request, as process_question does, so the
        # returned logs do not contain lines from an earlier request.
        text_to_sql_system.log_history = []

        # The override is already SQL: pass it straight through.
        if override.upper().startswith("SELECT"):
            if not override.endswith(";"):
                override += ";"
            text_to_sql_system._log("使用 prompt_override 直接回傳 SQL")
            logs = "\n".join(text_to_sql_system.log_history[-15:])
            return override, "override", logs

        # Otherwise treat the override as a complete LLM prompt, and constrain
        # the output format so the SQL can be extracted.
        text_to_sql_system._log("使用 prompt_override 直接調用 LLM")
        constrained_po = (
            override
            + "\n\nReturn only the final SQL query in a fenced code block (```sql ... ```). "
            + "Do not output narration, bullets, or explanations. The SQL must start with SELECT and end with a semicolon."
        )
        response = text_to_sql_system.huggingface_api_call(constrained_po)
        fixed_sql, _status = text_to_sql_system._validate_and_fix_sql(question, response)
        if not fixed_sql:
            # Model output unusable: fall back to keyword-based canned SQL.
            fixed_sql = text_to_sql_system._generate_fallback_sql(override)
        text_to_sql_system._log(f"最終 SQL: {fixed_sql}")
        logs = "\n".join(text_to_sql_system.log_history[-15:])
        # On this path the status is always reported as "override",
        # regardless of the validator's own status.
        return fixed_sql, "override", logs

    sql, status = text_to_sql_system.process_question(q)
    logs = "\n".join(text_to_sql_system.log_history[-15:])  # last 15 log lines
    return sql, status, logs
# Example questions shown under the input box; clicking one fills `inp`.
examples = [
    "2024年每月完成多少份報告?",
    "統計各種評級(Pass/Fail)的分布情況",
    "找出總金額最高的10個工作單",
    "哪些客戶的工作單數量最多?",
    "A組昨天完成了多少個測試項目?"
]
# Build the Gradio UI. Component creation order (and nesting under the
# Row/Column/Accordion context managers) determines the page layout.
# NOTE(review): the source indentation was lost; the nesting below (two columns
# in one row, the log accordion inside the right column, examples and event
# wiring at Blocks level) is reconstructed — confirm against the original file.
with gr.Blocks(theme=gr.themes.Soft(), title="Text-to-SQL 智能助手 (HF Space)") as demo:
    gr.Markdown("# Text-to-SQL 智能助手 (Hugging Face Space)")
    gr.Markdown("輸入自然語言問題,自動生成SQL查詢語句。使用 /tmp 暫存,每次啟動重新下載模型。支援桌面端透過 /predict API 呼叫。")
    with gr.Row():
        # Left column: question input, trigger button, and status.
        with gr.Column(scale=2):
            inp = gr.Textbox(lines=3, label="您的問題", placeholder="例如:2024年每月完成多少份報告?")
            btn = gr.Button("生成 SQL", variant="primary")
            status = gr.Textbox(label="狀態", interactive=False)
            # Hidden prompt_override input; it is not shown in the UI, but it
            # is the second input of the handler so API clients can set it.
            prompt_override = gr.Textbox(label="prompt_override", visible=False)
        # Right column: generated SQL plus a collapsible processing log.
        with gr.Column(scale=3):
            sql_out = gr.Code(label="生成的 SQL", language="sql", lines=8)
            with gr.Accordion("處理日誌", open=False):
                logs = gr.Textbox(lines=10, label="日誌", interactive=False)
    # Clickable example questions.
    gr.Examples(
        examples=examples,
        inputs=inp,
        label="點擊試用範例問題"
    )
    # Wire events. The button click is published under api_name "/predict";
    # pressing Enter in the textbox runs the same handler with the same I/O.
    btn.click(process_query, inputs=[inp, prompt_override], outputs=[sql_out, status, logs], api_name="/predict")
    inp.submit(process_query, inputs=[inp, prompt_override], outputs=[sql_out, status, logs])
# ========== Host Gradio under FastAPI and provide /health ==========
_fastapi_app = FastAPI()


# The route decorator is what actually exposes /health; without it the
# function is just an unused module-level def. It is registered here, before
# the Gradio app is mounted at "/" further down, so this route is matched first.
@_fastapi_app.get("/health")
def health_endpoint():
    """Report service status, the Gradio API endpoints, and runtime config.

    Returns:
        A JSON-serializable dict with these keys:
        - ``status``: always "ok".
        - ``endpoints``: for each dependency in ``demo.config`` (when
          readable), its api_name, fn_index, and input/output counts.
          If none can be read, a static description of ``/predict``.
        - ``env``: the module's configuration constants.
        - ``server``: current time, Gradio version, and process id.
    """
    endpoints = []
    try:
        cfg = getattr(demo, "config", None)
        if isinstance(cfg, dict):
            deps = cfg.get("dependencies") or []
            for dep in deps:
                endpoints.append({
                    "api_name": dep.get("api_name"),
                    "fn_index": dep.get("fn_index"),
                    "inputs_count": len(dep.get("inputs") or []),
                    "outputs_count": len(dep.get("outputs") or []),
                })
    except Exception:
        # Deliberately best-effort: the health check must still answer even
        # if the Gradio config has an unexpected shape. The static fallback
        # below is used instead.
        pass
    if not endpoints:
        endpoints.append({
            "api_name": "/predict",
            "fn_index": None,
            "inputs_count": 2,
            "outputs_count": 3,
        })
    env_info = {
        "USE_GPU": USE_GPU,
        "DEVICE": DEVICE,
        "N_GPU_LAYERS": N_GPU_LAYERS,
        "THREADS": THREADS,
        "CTX": CTX,
        "MAX_TOKENS": MAX_TOKENS,
        "FEW_SHOT_EXAMPLES_COUNT": FEW_SHOT_EXAMPLES_COUNT,
        "ENABLE_INDEX": ENABLE_INDEX,
        "EMBED_BATCH": EMBED_BATCH,
        "N_BATCH": N_BATCH,
        "GGUF_REPO_ID": GGUF_REPO_ID,
        "GGUF_FILENAME": GGUF_FILENAME,
    }
    server_info = {
        "time": get_current_time(),
        "gradio_version": getattr(gr, "__version__", "unknown"),
        "pid": os.getpid(),
    }
    return {"status": "ok", "endpoints": endpoints, "env": env_info, "server": server_info}
# Mount the Gradio Blocks at the root path of the FastAPI app. The resulting
# `app` is the ASGI application served below (or by an external runner).
app = gr.mount_gradio_app(app=_fastapi_app, blocks=demo, path="/")

if __name__ == "__main__":
    # Direct execution: serve the combined app on all interfaces, port 7860.
    import uvicorn

    uvicorn.run(app=app, host="0.0.0.0", port=7860)