Allex21 commited on
Commit
c4646a0
·
verified ·
1 Parent(s): 3d1e8cc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -16
app.py CHANGED
@@ -35,14 +35,20 @@ class LoRAImageTrainer:
35
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
  self.training_jobs = {}
37
  self.models_cache = {}
38
- Path("./jobs").mkdir(exist_ok=True)
39
- Path("./lora_models").mkdir(exist_ok=True)
 
 
40
 
41
  def _save_job_state(self, job_id: str):
42
  """Salva o estado do job em disco."""
43
  job_file = Path(f"./jobs/{job_id}.json")
44
- with open(job_file, "w") as f:
45
- json.dump(self.training_jobs[job_id], f, indent=2, default=str)
 
 
 
 
46
 
47
  def _load_job_state(self, job_id: str) -> Optional[Dict]:
48
  """Carrega o estado do job do disco."""
@@ -50,9 +56,13 @@ class LoRAImageTrainer:
50
  if job_file.exists():
51
  try:
52
  with open(job_file, "r") as f:
53
- return json.load(f)
 
 
54
  except Exception as e:
55
  logger.error(f"Erro ao carregar job {job_id}: {e}")
 
 
56
  return None
57
 
58
  def get_training_status(self, job_id: str) -> Dict[str, Any]:
@@ -177,8 +187,10 @@ class LoRAImageTrainer:
177
 
178
  self.training_jobs[job_id]["status"] = "loading_model"
179
  self.training_jobs[job_id]["progress"] = 5
180
- self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - Carregando modelo base: {model_name}")
 
181
  self._save_job_state(job_id)
 
182
 
183
  pipeline = self.load_base_model(model_name)
184
  unet = pipeline.unet
@@ -191,9 +203,10 @@ class LoRAImageTrainer:
191
  text_encoder.requires_grad_(False)
192
  vae.requires_grad_(False)
193
 
194
- # ✅ CORREÇÃO 1: REMOVER ADAPTADOR EXISTENTE
195
  if hasattr(unet, "peft_config") and "default" in unet.peft_config:
196
  unet.delete_adapter("default")
 
197
 
198
  lora_config = LoraConfig(
199
  r=r,
@@ -206,16 +219,19 @@ class LoRAImageTrainer:
206
  unet.add_adapter(lora_config, adapter_name="default")
207
  unet.set_adapter("default")
208
 
209
- # ✅ CORREÇÃO 2: ATIVAR APENAS PARÂMETROS DO LORA
210
  unet.requires_grad_(False)
 
211
  for name, param in unet.named_parameters():
212
  if "lora_" in name:
213
  param.requires_grad = True
 
 
 
214
 
215
  unet.train()
216
  unet.to(self.device)
217
 
218
- # Otimizador só nos parâmetros que requerem gradiente
219
  optimizer = torch.optim.AdamW([p for p in unet.parameters() if p.requires_grad], lr=learning_rate)
220
 
221
  self.training_jobs[job_id]["status"] = "preparing_data"
@@ -232,8 +248,10 @@ class LoRAImageTrainer:
232
  current_step = 0
233
 
234
  self.training_jobs[job_id]["status"] = "training"
235
- self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - Iniciando treinamento real...")
 
236
  self._save_job_state(job_id)
 
237
 
238
  for epoch in range(num_epochs):
239
  for item in dataset:
@@ -271,8 +289,9 @@ class LoRAImageTrainer:
271
 
272
  if current_step % max(1, len(dataset)//2) == 0:
273
  log_msg = f"Época {epoch+1}, Step {current_step} - Loss: {loss.item():.4f}"
274
- self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - {log_msg}")
275
  self._save_job_state(job_id)
 
276
 
277
  self.training_jobs[job_id]["status"] = "saving"
278
  self.training_jobs[job_id]["progress"] = 95
@@ -281,7 +300,6 @@ class LoRAImageTrainer:
281
  output_dir = f"./lora_models/{job_id}"
282
  os.makedirs(output_dir, exist_ok=True)
283
 
284
- # ✅ SALVAR APENAS O LORA
285
  unet.save_pretrained(
286
  output_dir,
287
  safe_serialization=True,
@@ -319,10 +337,10 @@ Data: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
319
  self.training_jobs[job_id]["progress"] = 100
320
  self.training_jobs[job_id]["model_path"] = output_dir
321
  self.training_jobs[job_id]["completed_at"] = datetime.now().isoformat()
322
- self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - ✅ Treinamento concluído! LoRA salvo em {output_dir}")
 
323
  self._save_job_state(job_id)
324
-
325
- logger.info(f"Treinamento LoRA concluído para job {job_id}")
326
 
327
  except Exception as e:
328
  error_msg = f"Erro no treinamento: {str(e)}"
@@ -386,7 +404,8 @@ Data: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
386
  "r": config.get("r", "Unknown"),
387
  "created": datetime.fromtimestamp(model_dir.stat().st_mtime).isoformat()
388
  })
389
- except:
 
390
  models.append({
391
  "id": model_dir.name,
392
  "path": str(model_dir),
@@ -519,6 +538,11 @@ def create_gradio_interface():
519
  return interface
520
 
521
  if __name__ == "__main__":
 
 
 
 
 
522
  interface = create_gradio_interface()
523
  interface.launch(
524
  server_name="0.0.0.0",
 
35
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
  self.training_jobs = {}
37
  self.models_cache = {}
38
+ # ✅ Garantir que as pastas existam no diretório atual
39
+ os.makedirs("./lora_models", exist_ok=True)
40
+ os.makedirs("./jobs", exist_ok=True)
41
+ logger.info("Pastas ./lora_models e ./jobs criadas com sucesso.")
42
 
43
  def _save_job_state(self, job_id: str):
44
  """Salva o estado do job em disco."""
45
  job_file = Path(f"./jobs/{job_id}.json")
46
+ try:
47
+ with open(job_file, "w") as f:
48
+ json.dump(self.training_jobs[job_id], f, indent=2, default=str)
49
+ logger.info(f"Estado do job {job_id} salvo em disco.")
50
+ except Exception as e:
51
+ logger.error(f"Erro ao salvar job {job_id}: {e}")
52
 
53
  def _load_job_state(self, job_id: str) -> Optional[Dict]:
54
  """Carrega o estado do job do disco."""
 
56
  if job_file.exists():
57
  try:
58
  with open(job_file, "r") as f:
59
+ loaded_data = json.load(f)
60
+ logger.info(f"Estado do job {job_id} carregado do disco.")
61
+ return loaded_data
62
  except Exception as e:
63
  logger.error(f"Erro ao carregar job {job_id}: {e}")
64
+ else:
65
+ logger.warning(f"Arquivo do job {job_id} não encontrado em disco.")
66
  return None
67
 
68
  def get_training_status(self, job_id: str) -> Dict[str, Any]:
 
187
 
188
  self.training_jobs[job_id]["status"] = "loading_model"
189
  self.training_jobs[job_id]["progress"] = 5
190
+ log_msg = f"{datetime.now().strftime('%H:%M:%S')} - Carregando modelo base: {model_name}"
191
+ self.training_jobs[job_id]["logs"].append(log_msg)
192
  self._save_job_state(job_id)
193
+ logger.info(log_msg)
194
 
195
  pipeline = self.load_base_model(model_name)
196
  unet = pipeline.unet
 
203
  text_encoder.requires_grad_(False)
204
  vae.requires_grad_(False)
205
 
206
+ # ✅ Remover adaptador existente
207
  if hasattr(unet, "peft_config") and "default" in unet.peft_config:
208
  unet.delete_adapter("default")
209
+ logger.info("Adaptador 'default' removido com sucesso.")
210
 
211
  lora_config = LoraConfig(
212
  r=r,
 
219
  unet.add_adapter(lora_config, adapter_name="default")
220
  unet.set_adapter("default")
221
 
222
+ # ✅ Ativar apenas parâmetros do LoRA
223
  unet.requires_grad_(False)
224
+ trainable_params = 0
225
  for name, param in unet.named_parameters():
226
  if "lora_" in name:
227
  param.requires_grad = True
228
+ trainable_params += 1
229
+
230
+ logger.info(f"Número de parâmetros treináveis (LoRA): {trainable_params}")
231
 
232
  unet.train()
233
  unet.to(self.device)
234
 
 
235
  optimizer = torch.optim.AdamW([p for p in unet.parameters() if p.requires_grad], lr=learning_rate)
236
 
237
  self.training_jobs[job_id]["status"] = "preparing_data"
 
248
  current_step = 0
249
 
250
  self.training_jobs[job_id]["status"] = "training"
251
+ log_msg = f"{datetime.now().strftime('%H:%M:%S')} - Iniciando treinamento real..."
252
+ self.training_jobs[job_id]["logs"].append(log_msg)
253
  self._save_job_state(job_id)
254
+ logger.info(log_msg)
255
 
256
  for epoch in range(num_epochs):
257
  for item in dataset:
 
289
 
290
  if current_step % max(1, len(dataset)//2) == 0:
291
  log_msg = f"Época {epoch+1}, Step {current_step} - Loss: {loss.item():.4f}"
292
+ self.training_jobs[job_id]["logs"].append(log_msg)
293
  self._save_job_state(job_id)
294
+ logger.info(log_msg)
295
 
296
  self.training_jobs[job_id]["status"] = "saving"
297
  self.training_jobs[job_id]["progress"] = 95
 
300
  output_dir = f"./lora_models/{job_id}"
301
  os.makedirs(output_dir, exist_ok=True)
302
 
 
303
  unet.save_pretrained(
304
  output_dir,
305
  safe_serialization=True,
 
337
  self.training_jobs[job_id]["progress"] = 100
338
  self.training_jobs[job_id]["model_path"] = output_dir
339
  self.training_jobs[job_id]["completed_at"] = datetime.now().isoformat()
340
+ log_msg = f"{datetime.now().strftime('%H:%M:%S')} - ✅ Treinamento concluído! LoRA salvo em {output_dir}"
341
+ self.training_jobs[job_id]["logs"].append(log_msg)
342
  self._save_job_state(job_id)
343
+ logger.info(log_msg)
 
344
 
345
  except Exception as e:
346
  error_msg = f"Erro no treinamento: {str(e)}"
 
404
  "r": config.get("r", "Unknown"),
405
  "created": datetime.fromtimestamp(model_dir.stat().st_mtime).isoformat()
406
  })
407
+ except Exception as e:
408
+ logger.error(f"Erro ao ler config de {model_dir.name}: {e}")
409
  models.append({
410
  "id": model_dir.name,
411
  "path": str(model_dir),
 
538
  return interface
539
 
540
  if __name__ == "__main__":
541
+ # ✅ Garantir que os diretórios existam
542
+ os.makedirs("./lora_models", exist_ok=True)
543
+ os.makedirs("./jobs", exist_ok=True)
544
+ logger.info("Aplicação iniciada. Diretórios verificados.")
545
+
546
  interface = create_gradio_interface()
547
  interface.launch(
548
  server_name="0.0.0.0",