Allex21 committed
Commit 3d1e8cc · verified · 1 Parent(s): 5e10fc1

Update app.py

Files changed (1)
  1. app.py +22 -22
app.py CHANGED
@@ -35,8 +35,8 @@ class LoRAImageTrainer:
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.training_jobs = {}
         self.models_cache = {}
-        # ✅ Create folder for job persistence
         Path("./jobs").mkdir(exist_ok=True)
+        Path("./lora_models").mkdir(exist_ok=True)
 
     def _save_job_state(self, job_id: str):
         """Save the job state to disk."""
@@ -67,7 +67,6 @@ class LoRAImageTrainer:
         return {"error": "Job not found"}
 
     def get_available_models(self) -> List[str]:
-        """Return the list of base models available for LoRA training."""
         return [
             "runwayml/stable-diffusion-v1-5",
             "stabilityai/stable-diffusion-2-1",
@@ -75,7 +74,6 @@ class LoRAImageTrainer:
         ]
 
     def load_base_model(self, model_name: str):
-        """Load the base diffusion model with optimizations for low GPU usage."""
        try:
             if model_name in self.models_cache:
                 return self.models_cache[model_name]
@@ -163,7 +161,6 @@ class LoRAImageTrainer:
                        batch_size: int = 1,
                        resolution: int = 512) -> None:
         try:
-            # Initialize the job if it does not exist
             if job_id not in self.training_jobs:
                 self.training_jobs[job_id] = {
                     "id": job_id,
@@ -178,13 +175,11 @@ class LoRAImageTrainer:
                     "completed_at": None
                 }
 
-            # Update status
             self.training_jobs[job_id]["status"] = "loading_model"
             self.training_jobs[job_id]["progress"] = 5
             self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - Loading base model: {model_name}")
-            self._save_job_state(job_id)  # ✅ SAVE STATE
+            self._save_job_state(job_id)
 
-            # Load the base model
             pipeline = self.load_base_model(model_name)
             unet = pipeline.unet
             text_encoder = pipeline.text_encoder
@@ -196,6 +191,10 @@ class LoRAImageTrainer:
             text_encoder.requires_grad_(False)
             vae.requires_grad_(False)
 
+            # ✅ FIX 1: REMOVE EXISTING ADAPTER
+            if hasattr(unet, "peft_config") and "default" in unet.peft_config:
+                unet.delete_adapter("default")
+
             lora_config = LoraConfig(
                 r=r,
                 lora_alpha=lora_alpha,
@@ -206,14 +205,22 @@ class LoRAImageTrainer:
 
             unet.add_adapter(lora_config, adapter_name="default")
             unet.set_adapter("default")
+
+            # ✅ FIX 2: ENABLE ONLY THE LORA PARAMETERS
+            unet.requires_grad_(False)
+            for name, param in unet.named_parameters():
+                if "lora_" in name:
+                    param.requires_grad = True
+
             unet.train()
             unet.to(self.device)
 
-            optimizer = torch.optim.AdamW(unet.parameters(), lr=learning_rate)
+            # Optimizer over the parameters that require gradients
+            optimizer = torch.optim.AdamW([p for p in unet.parameters() if p.requires_grad], lr=learning_rate)
 
             self.training_jobs[job_id]["status"] = "preparing_data"
             self.training_jobs[job_id]["progress"] = 20
-            self._save_job_state(job_id)  # ✅ SAVE STATE
+            self._save_job_state(job_id)
 
             def preprocess_image(image):
                 image = np.array(image).astype(np.float32) / 255.0
@@ -226,7 +233,7 @@ class LoRAImageTrainer:
 
             self.training_jobs[job_id]["status"] = "training"
             self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - Starting actual training...")
-            self._save_job_state(job_id)  # ✅ SAVE STATE
+            self._save_job_state(job_id)
 
             for epoch in range(num_epochs):
                 for item in dataset:
@@ -265,15 +272,16 @@ class LoRAImageTrainer:
                     if current_step % max(1, len(dataset)//2) == 0:
                         log_msg = f"Epoch {epoch+1}, Step {current_step} - Loss: {loss.item():.4f}"
                         self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - {log_msg}")
-                        self._save_job_state(job_id)  # ✅ SAVE STATE (optional here, for logs)
+                        self._save_job_state(job_id)
 
             self.training_jobs[job_id]["status"] = "saving"
             self.training_jobs[job_id]["progress"] = 95
-            self._save_job_state(job_id)  # ✅ SAVE STATE
+            self._save_job_state(job_id)
 
             output_dir = f"./lora_models/{job_id}"
             os.makedirs(output_dir, exist_ok=True)
 
+            # ✅ SAVE ONLY THE LORA
             unet.save_pretrained(
                 output_dir,
                 safe_serialization=True,
@@ -312,7 +320,7 @@ Data: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
             self.training_jobs[job_id]["model_path"] = output_dir
             self.training_jobs[job_id]["completed_at"] = datetime.now().isoformat()
             self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - ✅ Training complete! LoRA saved to {output_dir}")
-            self._save_job_state(job_id)  # ✅ SAVE FINAL STATE
+            self._save_job_state(job_id)
 
             logger.info(f"LoRA training complete for job {job_id}")
 
@@ -323,7 +331,7 @@ Data: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
             self.training_jobs[job_id]["status"] = "error"
             self.training_jobs[job_id]["error"] = error_msg
             self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - ❌ {error_msg}")
-            self._save_job_state(job_id)  # ✅ SAVE ERROR STATE
+            self._save_job_state(job_id)
 
     def start_training(self,
                        model_name: str,
@@ -347,7 +355,6 @@ Data: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
             "completed_at": None
         }
 
-        # ✅ Save initial state
         self._save_job_state(job_id)
 
         thread = threading.Thread(
@@ -512,14 +519,7 @@ def create_gradio_interface():
     return interface
 
 if __name__ == "__main__":
-    # ✅ Create required directories
-    os.makedirs("./lora_models", exist_ok=True)
-    os.makedirs("./jobs", exist_ok=True)  # Folder for job persistence
-
-    # Configure the interface
     interface = create_gradio_interface()
-
-    # Launch the application
     interface.launch(
         server_name="0.0.0.0",
         server_port=7860,
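The hunks above call self._save_job_state after every status transition, but the method body sits outside the diff; only its docstring ("Save the job state to disk") is visible. A minimal sketch of what such a helper could look like, assuming one JSON file per job under the ./jobs folder created in __init__ (the filename pattern and JSON layout are assumptions, not the commit's actual code):

import json
from pathlib import Path

def _save_job_state(self, job_id: str):
    """Persist the in-memory job dict so progress survives a restart (sketch)."""
    job = self.training_jobs.get(job_id)
    if job is None:
        return
    # Hypothetical layout: ./jobs/<job_id>.json, one file per job
    (Path("./jobs") / f"{job_id}.json").write_text(
        json.dumps(job, ensure_ascii=False, indent=2)
    )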
 
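FIX 1 matters because load_base_model caches pipelines in self.models_cache: a second training job reuses the same UNet object, and calling add_adapter with adapter_name="default" on a model that already carries a "default" adapter raises an error in PEFT. Deleting the stale adapter first makes retraining on a cached pipeline repeatable. A standalone sketch of the same guard (the target_modules list is an assumption, naming the usual attention projections in the Stable Diffusion UNet):

from peft import LoraConfig

def reset_default_adapter(unet, r=8, lora_alpha=16):
    # A previous run on the cached pipeline may have left a "default" adapter behind
    if hasattr(unet, "peft_config") and "default" in unet.peft_config:
        unet.delete_adapter("default")
    config = LoraConfig(
        r=r,
        lora_alpha=lora_alpha,
        # Typical attention projections in the SD UNet; adjust for other models
        target_modules=["to_q", "to_k", "to_v", "to_out.0"],
    )
    unet.add_adapter(config, adapter_name="default")
    unet.set_adapter("default")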
 
 
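FIX 2 freezes the whole UNet and re-enables gradients only on parameters whose names contain "lora_", so the AdamW optimizer built from [p for p in unet.parameters() if p.requires_grad] touches just the adapter weights. A quick sanity check is to count trainable parameters right before creating the optimizer; with a small rank the LoRA share should be a small fraction of the UNet total (the helper below is illustrative):

def count_trainable(module):
    # Compare trainable vs. total parameters after the LoRA-only freeze
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    total = sum(p.numel() for p in module.parameters())
    return trainable, total

trainable, total = count_trainable(unet)
print(f"trainable: {trainable:,} of {total:,} ({100 * trainable / total:.2f}%)")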
 
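After the save step, the adapter ends up under ./lora_models/<job_id>. A hedged sketch of reusing it for inference via diffusers' load_lora_weights, assuming the saved directory contains LoRA weights in a layout that loader accepts (the base model, dtype, and prompt are placeholders):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
).to("cuda")
# Path produced by the training job; <job_id> is a placeholder
pipe.load_lora_weights("./lora_models/<job_id>")
image = pipe("a photo in the trained style").images[0]
image.save("sample.png")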
 