Paul720810 committed on
Commit
51d333a
·
verified ·
1 Parent(s): 07d3f8d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -46
app.py CHANGED
@@ -61,15 +61,51 @@ def parse_sql_from_response(response_text: str) -> Optional[str]:
61
  return None
62
 
63
  # ==================== 核心 Text-to-SQL 系統類別 ====================
 
 
64
  class TextToSQLSystem:
65
  def __init__(self, model_name='sentence-transformers/paraphrase-multilingual-mpnet-base-v2'):
66
  self.log_history = []
67
  self._log("初始化系統...")
68
- self.schema = self._load_schema() # 📌 自動載入 SQLite schema
 
 
69
  self.model = SentenceTransformer(model_name, device=DEVICE)
70
  self.dataset, self.corpus_embeddings = self._load_and_encode_dataset()
 
 
 
 
 
 
 
 
 
 
71
  self._log("✅ 系統初始化完成,已準備就緒。")
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  def _log(self, message: str, level: str = "INFO"):
74
  self.log_history.append(format_log(message, level))
75
  print(format_log(message, level))
@@ -141,51 +177,6 @@ class TextToSQLSystem:
141
  })
142
  return similar_examples
143
 
144
- def huggingface_api_call(self, prompt: str) -> str:
145
- """呼叫 Hugging Face Inference API"""
146
- # === 修正開始 ===
147
- # 確保 API_URL 是一個乾淨的字串,不包含任何 Markdown "[ ]" 或其他特殊字元
148
- API_URL = "https://api-inference.huggingface.co/models/Paul720810/qwen2.5-coder-1.5b-sql-finetuned"
149
-
150
- # === 修正結束 ===
151
-
152
- headers = {"Authorization": f"Bearer {HF_TOKEN}"}
153
- payload = {
154
- "inputs": prompt,
155
- "parameters": {
156
- "max_new_tokens": 1024,
157
- "return_full_text": False
158
- }
159
- }
160
- try:
161
- # === 新增除錯日誌 ===
162
- # 在發送請求前,打印出最終要使用的 URL,以供檢查
163
- self._log(f"準備向 API 端點發送請求: {API_URL}")
164
-
165
- self._log("正在呼叫 Hugging Face API...")
166
- response = requests.post(API_URL, headers=headers, json=payload, timeout=90) # 延長超時時間
167
- response.raise_for_status() # 如果 API 回傳錯誤碼 (如 4xx, 5xx),會在此拋出例外
168
-
169
- self._log("✅ API 成功回應。")
170
- return response.json()[0]['generated_text']
171
-
172
- except requests.exceptions.RequestException as e:
173
- self._log(f"❌ API 呼叫失敗: {e}", "ERROR")
174
-
175
- # 嘗試解析回應內容,看是否是模型載入中的常見錯誤
176
- try:
177
- # 即使請求失敗,有時回應本文中仍有 JSON 錯誤訊息
178
- error_content = e.response.json() if e.response else {}
179
- if "error" in error_content and "estimated_time" in error_content["error"]:
180
- loading_time = error_content["error"]["estimated_time"]
181
- self._log(f" - 提示: 模型可能正在載入中,預計需要 {loading_time:.1f} 秒。請稍後重試。", "WARNING")
182
- return f"API 錯誤: 模型正在載入中,請稍後再試一次。"
183
- except (ValueError, AttributeError):
184
- # 如果回應不是 JSON 或沒有回應本文,就忽略
185
- pass
186
-
187
- return f"API 連線錯誤: {e}"
188
-
189
  # === 修改開始: 重寫核心處理邏輯 ===
190
  def _build_prompt_for_generation(self, user_question: str, examples: List[Dict]) -> str:
191
  """
 
61
  return None
62
 
63
  # ==================== 核心 Text-to-SQL 系統類別 ====================
64
+ from transformers import AutoModelForCausalLM, AutoTokenizer
65
+
66
  class TextToSQLSystem:
67
  def __init__(self, model_name='sentence-transformers/paraphrase-multilingual-mpnet-base-v2'):
68
  self.log_history = []
69
  self._log("初始化系統...")
70
+
71
+ # 載入檢索模型
72
+ self.schema = self._load_schema()
73
  self.model = SentenceTransformer(model_name, device=DEVICE)
74
  self.dataset, self.corpus_embeddings = self._load_and_encode_dataset()
75
+
76
+ # ✅ 載入你自己的 Hugging Face 模型
77
+ self.generation_model_id = "Paul720810/qwen2.5-coder-1.5b-sql-finetuned"
78
+ self.tokenizer = AutoTokenizer.from_pretrained(self.generation_model_id)
79
+ self.generation_model = AutoModelForCausalLM.from_pretrained(
80
+ self.generation_model_id,
81
+ device_map="auto",
82
+ torch_dtype="auto"
83
+ )
84
+
85
  self._log("✅ 系統初始化完成,已準備就緒。")
86
 
87
+ def huggingface_api_call(self, prompt: str) -> str:
88
+ """直接使用本地載入的模型生成結果"""
89
+ try:
90
+ self._log("🧠 開始本地生成 SQL...")
91
+
92
+ inputs = self.tokenizer(prompt, return_tensors="pt").to(self.generation_model.device)
93
+ outputs = self.generation_model.generate(
94
+ **inputs,
95
+ max_new_tokens=512,
96
+ do_sample=True,
97
+ temperature=0.7,
98
+ top_p=0.9
99
+ )
100
+ result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
101
+
102
+ self._log("✅ 本地生成完成。")
103
+ return result
104
+
105
+ except Exception as e:
106
+ self._log(f"❌ 本地生成失敗: {e}", "ERROR")
107
+ return f"本地生成錯誤: {e}"
108
+
109
  def _log(self, message: str, level: str = "INFO"):
110
  self.log_history.append(format_log(message, level))
111
  print(format_log(message, level))
 
177
  })
178
  return similar_examples
179
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
  # === 修改開始: 重寫核心處理邏輯 ===
181
  def _build_prompt_for_generation(self, user_question: str, examples: List[Dict]) -> str:
182
  """