Paul720810 commited on
Commit
6fcaae4
·
verified ·
1 Parent(s): ccd921a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +155 -65
app.py CHANGED
@@ -21,6 +21,9 @@ DATASET_REPO_ID = "Paul720810/Text-to-SQL-Softline"
21
  GGUF_REPO_ID = "Paul720810/gguf-models"
22
  GGUF_FILENAME = "qwen2.5-coder-1.5b-sql-finetuned.q4_k_m.gguf"
23
 
 
 
 
24
  FEW_SHOT_EXAMPLES_COUNT = 1
25
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
26
  EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
@@ -85,62 +88,176 @@ class TextToSQLSystem:
85
  self._load_gguf_model()
86
 
87
  self._log("✅ 系統初始化完成")
 
88
  def _load_gguf_model(self):
89
- """載入 GGUF 模型並處理錯誤"""
 
90
  try:
91
  self._log("載入 GGUF 模型...")
92
-
93
- # 強制重新下載模型
94
  model_path = hf_hub_download(
95
  repo_id=GGUF_REPO_ID,
96
  filename=GGUF_FILENAME,
97
  repo_type="dataset",
98
- force_download=True # 強制重新下載
99
  )
100
 
101
- # 使用驗證方法檢查檔案
102
- if not self._validate_model_file(model_path):
103
- self._log("❌ 模型檔案驗證失敗,嘗試重新下載", "ERROR")
104
- # 刪除損壞的檔案並重新下載
105
- if os.path.exists(model_path):
106
- os.remove(model_path)
107
- model_path = hf_hub_download(
108
- repo_id=GGUF_REPO_ID,
109
- filename=GGUF_FILENAME,
110
- repo_type="dataset",
111
- force_download=True
112
- )
113
-
114
- # 再次驗證
115
- if not self._validate_model_file(model_path):
116
- raise ValueError("重新下載後檔案仍然無效")
117
-
118
- # 使用更保守的參數載入模型
119
  self.llm = Llama(
120
  model_path=model_path,
121
- n_ctx=512, # 減少上下文長度
122
- n_threads=4, # 固定線程數
123
- n_batch=128, # 減少批次大小
124
- verbose=False, # 關閉詳細輸出
125
- use_mmap=True, # 使用記憶體映射
126
- use_mlock=False, # 不鎖定記憶體
127
- n_gpu_layers=0 # 強制使用 CPU
128
  )
129
 
130
- # 測試模型是否能正常生成
131
  test_output = self.llm("SELECT", max_tokens=5, temperature=0.1)
132
- if not test_output or 'choices' not in test_output:
133
- raise RuntimeError("模型載入後無法正常生成")
134
-
135
- self._log("✅ GGUF 模型載入並測試成功")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
 
137
  except Exception as e:
138
- self._log(f"❌ GGUF 模型載入失敗: {str(e)}", "ERROR")
139
- self._log("嘗試使用替代方案...", "INFO")
140
  self.llm = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
- # 可以在這裡添加使用其他模型的邏輯
143
- # 例如使用 Hugging Face Transformers 的備用方案
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
  def _load_gguf_model_fallback(self, model_path):
146
  """備用載入方式"""
@@ -337,33 +454,6 @@ class TextToSQLSystem:
337
 
338
  return prompt
339
 
340
- def huggingface_api_call(self, prompt: str) -> str:
341
- """使用 GGUF 模型生成或提供替代方案"""
342
- if self.llm is None:
343
- # 返回基於規則的簡單 SQL 生成
344
- return self._generate_fallback_sql(prompt)
345
-
346
- try:
347
- if len(prompt) > 1500: # 縮短提示長度
348
- prompt = prompt[:1500] + "..."
349
-
350
- output = self.llm(
351
- prompt,
352
- max_tokens=128, # 減少最大 token 數
353
- temperature=0.0, # 使用確定性生成
354
- top_p=0.95,
355
- stop=["</s>", "```", "\n\n", "問題:"], # 添加更多停止詞
356
- echo=False
357
- )
358
-
359
- if output and 'choices' in output and output['choices']:
360
- return output["choices"][0]["text"].strip()
361
- else:
362
- return "模型生成失敗"
363
-
364
- except Exception as e:
365
- self._log(f"❌ 生成失敗: {e}", "ERROR")
366
- return self._generate_fallback_sql(prompt)
367
 
368
  def _generate_fallback_sql(self, prompt: str) -> str:
369
  """當模型不可用時的備用 SQL 生成"""
 
21
  GGUF_REPO_ID = "Paul720810/gguf-models"
22
  GGUF_FILENAME = "qwen2.5-coder-1.5b-sql-finetuned.q4_k_m.gguf"
23
 
24
+ # 添加這一行:你的原始微調模型路徑
25
+ FINETUNED_MODEL_PATH = "Paul720810/qwen2.5-coder-1.5b-sql-finetuned" # ← 新增這行
26
+
27
  FEW_SHOT_EXAMPLES_COUNT = 1
28
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
29
  EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
 
88
  self._load_gguf_model()
89
 
90
  self._log("✅ 系統初始化完成")
91
+
92
  def _load_gguf_model(self):
93
+ """載入 GGUF 模型,失敗則使用 Transformers 備用方案"""
94
+ # 先嘗試原本的 GGUF 載入方式
95
  try:
96
  self._log("載入 GGUF 模型...")
 
 
97
  model_path = hf_hub_download(
98
  repo_id=GGUF_REPO_ID,
99
  filename=GGUF_FILENAME,
100
  repo_type="dataset",
101
+ force_download=True
102
  )
103
 
104
+ # 你原本的載入參數
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  self.llm = Llama(
106
  model_path=model_path,
107
+ n_ctx=512, # 減少上下文長度
108
+ n_threads=4, # 固定線程數
109
+ n_batch=128, # 減少批次大小
110
+ verbose=False, # 關閉詳細輸出
111
+ use_mmap=True, # 使用記憶體映射
112
+ use_mlock=False, # 不鎖定記憶體
113
+ n_gpu_layers=0 # 強制使用 CPU
114
  )
115
 
116
+ # 測試是否能正常生成
117
  test_output = self.llm("SELECT", max_tokens=5, temperature=0.1)
118
+ self._log("✅ GGUF 模型載入成功")
119
+ return
120
+
121
+ except Exception as e:
122
+ self._log(f"❌ GGUF 載入失敗: {e}", "ERROR")
123
+
124
+ # GGUF 失敗,使用 Transformers 載入你的微調模型
125
+ try:
126
+ self._log("改用 Transformers 載入微調模型...")
127
+ from transformers import AutoModelForCausalLM, AutoTokenizer
128
+ import torch
129
+
130
+ self.transformers_tokenizer = AutoTokenizer.from_pretrained(FINETUNED_MODEL_PATH)
131
+ self.transformers_model = AutoModelForCausalLM.from_pretrained(
132
+ FINETUNED_MODEL_PATH,
133
+ torch_dtype=torch.float32,
134
+ device_map="cpu",
135
+ trust_remote_code=True
136
+ )
137
+
138
+ self.llm = "transformers" # 標記使用 transformers
139
+ self._log("✅ Transformers 模型載入成功")
140
+
141
+ except Exception as e:
142
+ self._log(f"❌ Transformers 載入也失敗: {e}", "ERROR")
143
+ self.llm = None
144
+
145
+ def _try_gguf_loading(self):
146
+ """嘗試載入 GGUF"""
147
+ try:
148
+ model_path = hf_hub_download(
149
+ repo_id=GGUF_REPO_ID,
150
+ filename=GGUF_FILENAME,
151
+ repo_type="dataset"
152
+ )
153
+
154
+ self.llm = Llama(
155
+ model_path=model_path,
156
+ n_ctx=512,
157
+ n_threads=4,
158
+ verbose=False,
159
+ n_gpu_layers=0
160
+ )
161
+
162
+ # 測試生成
163
+ test_result = self.llm("SELECT", max_tokens=5)
164
+ self._log("✅ GGUF 模型載入成功")
165
+ return True
166
+
167
+ except Exception as e:
168
+ self._log(f"GGUF 載入失敗: {e}", "WARNING")
169
+ return False
170
+
171
+ def _load_transformers_model(self):
172
+ """使用 Transformers 載入你的微調模型"""
173
+ try:
174
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
175
+ import torch
176
+
177
+ self._log(f"載入 Transformers 模型: {FINETUNED_MODEL_PATH}")
178
+
179
+ # 載入你的微調模型
180
+ self.transformers_tokenizer = AutoTokenizer.from_pretrained(FINETUNED_MODEL_PATH)
181
+ self.transformers_model = AutoModelForCausalLM.from_pretrained(
182
+ FINETUNED_MODEL_PATH,
183
+ torch_dtype=torch.float32, # CPU 使用 float32
184
+ device_map="cpu", # 強制使用 CPU
185
+ trust_remote_code=True # Qwen 模型可能需要
186
+ )
187
+
188
+ # 創建生成管道
189
+ self.generation_pipeline = pipeline(
190
+ "text-generation",
191
+ model=self.transformers_model,
192
+ tokenizer=self.transformers_tokenizer,
193
+ device=-1, # CPU
194
+ max_length=512,
195
+ do_sample=True,
196
+ temperature=0.1,
197
+ top_p=0.9,
198
+ pad_token_id=self.transformers_tokenizer.eos_token_id
199
+ )
200
+
201
+ self.llm = "transformers" # 標記使用 transformers
202
+ self._log("✅ Transformers 模型載入成功")
203
 
204
  except Exception as e:
205
+ self._log(f"❌ Transformers 載入也失敗: {e}", "ERROR")
 
206
  self.llm = None
207
+
208
+ def huggingface_api_call(self, prompt: str) -> str:
209
+ """使用 GGUF 或 Transformers 生成"""
210
+ if self.llm is None:
211
+ return self._generate_fallback_sql(prompt)
212
+
213
+ try:
214
+ # 如果是 Transformers 模型
215
+ if self.llm == "transformers":
216
+ # 限制 prompt 長度
217
+ if len(prompt) > 1000:
218
+ prompt = prompt[:1000]
219
+
220
+ # 使用 Transformers 生成
221
+ inputs = self.transformers_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
222
+
223
+ with torch.no_grad():
224
+ outputs = self.transformers_model.generate(
225
+ inputs.input_ids,
226
+ attention_mask=inputs.attention_mask,
227
+ max_new_tokens=128,
228
+ temperature=0.1,
229
+ do_sample=True,
230
+ top_p=0.9,
231
+ pad_token_id=self.transformers_tokenizer.eos_token_id,
232
+ eos_token_id=self.transformers_tokenizer.eos_token_id
233
+ )
234
+
235
+ # 解碼生成的文本,只取新生成的部分
236
+ generated_text = self.transformers_tokenizer.decode(
237
+ outputs[0][inputs.input_ids.shape[1]:],
238
+ skip_special_tokens=True
239
+ )
240
+
241
+ return generated_text.strip()
242
 
243
+ # 如果是 GGUF 模型(你原本的代碼)
244
+ else:
245
+ if len(prompt) > 1800:
246
+ prompt = prompt[:1800] + "..."
247
+
248
+ output = self.llm(
249
+ prompt,
250
+ max_tokens=256,
251
+ temperature=0.1,
252
+ top_p=0.9,
253
+ stop=["</s>", "```", ";", "\n\n"],
254
+ echo=False
255
+ )
256
+ return output["choices"][0]["text"].strip()
257
+
258
+ except Exception as e:
259
+ self._log(f"❌ 生成失敗: {e}", "ERROR")
260
+ return f"生成失敗: {e}"
261
 
262
  def _load_gguf_model_fallback(self, model_path):
263
  """備用載入方式"""
 
454
 
455
  return prompt
456
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
457
 
458
  def _generate_fallback_sql(self, prompt: str) -> str:
459
  """當模型不可用時的備用 SQL 生成"""