Allex21 commited on
Commit
5e10fc1
·
verified ·
1 Parent(s): f8c89ea

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -36
app.py CHANGED
@@ -35,7 +35,37 @@ class LoRAImageTrainer:
35
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
  self.training_jobs = {}
37
  self.models_cache = {}
 
 
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  def get_available_models(self) -> List[str]:
40
  """Retorna lista de modelos base disponíveis para treinamento LoRA."""
41
  return [
@@ -52,7 +82,6 @@ class LoRAImageTrainer:
52
 
53
  logger.info(f"Carregando modelo base: {model_name}")
54
 
55
- # Configurações para otimização de memória
56
  model_kwargs = {
57
  "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
58
  "use_safetensors": True,
@@ -60,7 +89,6 @@ class LoRAImageTrainer:
60
  "safety_checker": None,
61
  }
62
 
63
- # Carregar pipeline completo
64
  pipeline = StableDiffusionPipeline.from_pretrained(
65
  model_name,
66
  **model_kwargs
@@ -73,7 +101,6 @@ class LoRAImageTrainer:
73
  pipeline.enable_xformers_memory_efficient_attention()
74
  except:
75
  logger.warning("xformers não disponível")
76
- # ✅ ATIVAÇÃO DO GRADIENT CHECKPOINTING — REDUZ MEMÓRIA EM ATÉ 60%
77
  pipeline.unet.enable_gradient_checkpointing()
78
 
79
  self.models_cache[model_name] = pipeline
@@ -84,7 +111,6 @@ class LoRAImageTrainer:
84
  raise e
85
 
86
  def prepare_image_dataset(self, image_files: List[str], captions: List[str], resolution: int = 512) -> List[Dict]:
87
- """Prepara dataset de imagens para treinamento."""
88
  dataset = []
89
 
90
  for img_path, caption in zip(image_files, captions):
@@ -105,7 +131,6 @@ class LoRAImageTrainer:
105
  return dataset
106
 
107
  def resize_image(self, image: Image.Image, target_size: int) -> Image.Image:
108
- """Redimensiona imagem mantendo aspect ratio e fazendo crop central."""
109
  width, height = image.size
110
 
111
  if width > height:
@@ -137,14 +162,28 @@ class LoRAImageTrainer:
137
  learning_rate: float = 1e-4,
138
  batch_size: int = 1,
139
  resolution: int = 512) -> None:
140
- """TREINAMENTO REAL DE LoRA PARA IMAGENS - VERSÃO FINAL OTIMIZADA."""
141
-
142
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  # Atualizar status
144
  self.training_jobs[job_id]["status"] = "loading_model"
145
  self.training_jobs[job_id]["progress"] = 5
146
  self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - Carregando modelo base: {model_name}")
147
-
 
148
  # Carregar modelo base
149
  pipeline = self.load_base_model(model_name)
150
  unet = pipeline.unet
@@ -153,12 +192,10 @@ class LoRAImageTrainer:
153
  tokenizer = pipeline.tokenizer
154
  scheduler = pipeline.scheduler
155
 
156
- # Congelar parâmetros
157
  unet.requires_grad_(False)
158
  text_encoder.requires_grad_(False)
159
  vae.requires_grad_(False)
160
 
161
- # Criar configuração LoRA
162
  lora_config = LoraConfig(
163
  r=r,
164
  lora_alpha=lora_alpha,
@@ -167,18 +204,16 @@ class LoRAImageTrainer:
167
  bias="none"
168
  )
169
 
170
- # Aplicar LoRA ao UNet
171
  unet.add_adapter(lora_config, adapter_name="default")
172
  unet.set_adapter("default")
173
  unet.train()
174
  unet.to(self.device)
175
 
176
- # Otimizador
177
  optimizer = torch.optim.AdamW(unet.parameters(), lr=learning_rate)
178
 
179
- # Preparar dados
180
  self.training_jobs[job_id]["status"] = "preparing_data"
181
  self.training_jobs[job_id]["progress"] = 20
 
182
 
183
  def preprocess_image(image):
184
  image = np.array(image).astype(np.float32) / 255.0
@@ -186,74 +221,65 @@ class LoRAImageTrainer:
186
  image = torch.from_numpy(image).unsqueeze(0)
187
  return image
188
 
189
- # Loop de treinamento
190
  total_steps = num_epochs * len(dataset)
191
  current_step = 0
192
 
193
  self.training_jobs[job_id]["status"] = "training"
194
  self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - Iniciando treinamento real...")
 
195
 
196
  for epoch in range(num_epochs):
197
  for item in dataset:
198
  current_step += 1
199
 
200
- # Pré-processar imagem
201
  image = item["image"]
202
  caption = item["caption"]
203
  image_tensor = preprocess_image(image).to(self.device)
204
  if torch.cuda.is_available():
205
  image_tensor = image_tensor.half()
206
 
207
- # Codificar para latentes
208
  with torch.no_grad():
209
  latents = vae.encode(image_tensor * 2 - 1).latent_dist.sample() * 0.18215
210
 
211
- # Tokenizar texto
212
  inputs = tokenizer(caption, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
213
  input_ids = inputs.input_ids.to(self.device)
214
 
215
- # Gerar timesteps
216
  timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (1,), device=self.device).long()
217
  noise = torch.randn_like(latents)
218
  noisy_latents = scheduler.add_noise(latents, noise, timesteps)
219
 
220
- # Forward pass
221
  encoder_hidden_states = text_encoder(input_ids)[0]
222
  noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
223
 
224
- # Calcular e propagar perda
225
  loss = torch.nn.functional.mse_loss(noise_pred, noise)
226
  optimizer.zero_grad()
227
  loss.backward()
228
  optimizer.step()
229
 
230
- # ✅ LIMPEZA DE MEMÓRIA A CADA STEP
231
  if torch.cuda.is_available():
232
  torch.cuda.empty_cache()
233
 
234
- # Atualizar progresso
235
  progress = 30 + int((current_step / total_steps) * 60)
236
  self.training_jobs[job_id]["progress"] = min(progress, 90)
237
 
238
  if current_step % max(1, len(dataset)//2) == 0:
239
  log_msg = f"Época {epoch+1}, Step {current_step} - Loss: {loss.item():.4f}"
240
  self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - {log_msg}")
 
241
 
242
- # ✅ SALVAR APENAS OS ADAPTADORES (NÃO O MODELO INTEIRO)
243
  self.training_jobs[job_id]["status"] = "saving"
244
  self.training_jobs[job_id]["progress"] = 95
 
245
 
246
  output_dir = f"./lora_models/{job_id}"
247
  os.makedirs(output_dir, exist_ok=True)
248
 
249
- # 👇👇👇 CORREÇÃO FINAL: SALVA SÓ O LORA 👇👇👇
250
  unet.save_pretrained(
251
  output_dir,
252
  safe_serialization=True,
253
  selected_adapters=["default"]
254
  )
255
 
256
- # Criar adapter_config.json
257
  lora_config_dict = {
258
  "r": r,
259
  "lora_alpha": lora_alpha,
@@ -273,7 +299,6 @@ class LoRAImageTrainer:
273
  with open(f"{output_dir}/adapter_config.json", "w") as f:
274
  json.dump(lora_config_dict, f, indent=2)
275
 
276
- # README
277
  readme_content = f"""# LoRA Model - {job_id}
278
  Treinado com sucesso!
279
  Modelo Base: {model_name}
@@ -282,21 +307,23 @@ Data: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
282
  with open(f"{output_dir}/README.md", "w") as f:
283
  f.write(readme_content)
284
 
285
- # Finalizar
286
  self.training_jobs[job_id]["status"] = "completed"
287
  self.training_jobs[job_id]["progress"] = 100
288
  self.training_jobs[job_id]["model_path"] = output_dir
289
  self.training_jobs[job_id]["completed_at"] = datetime.now().isoformat()
290
  self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - ✅ Treinamento concluído! LoRA salvo em {output_dir}")
 
291
 
292
  logger.info(f"Treinamento LoRA concluído para job {job_id}")
293
 
294
  except Exception as e:
295
  error_msg = f"Erro no treinamento: {str(e)}"
296
  logger.error(error_msg)
297
- self.training_jobs[job_id]["status"] = "error"
298
- self.training_jobs[job_id]["error"] = error_msg
299
- self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - ❌ {error_msg}")
 
 
300
 
301
  def start_training(self,
302
  model_name: str,
@@ -304,6 +331,7 @@ Data: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
304
  captions: List[str],
305
  **kwargs) -> str:
306
  job_id = str(uuid.uuid4())
 
307
  dataset = self.prepare_image_dataset(image_files, captions, kwargs.get('resolution', 512))
308
 
309
  self.training_jobs[job_id] = {
@@ -319,6 +347,9 @@ Data: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
319
  "completed_at": None
320
  }
321
 
 
 
 
322
  thread = threading.Thread(
323
  target=self.real_lora_training,
324
  args=(job_id, model_name, dataset),
@@ -329,9 +360,6 @@ Data: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
329
 
330
  return job_id
331
 
332
- def get_training_status(self, job_id: str) -> Dict[str, Any]:
333
- return self.training_jobs.get(job_id, {"error": "Job não encontrado"})
334
-
335
  def list_trained_models(self) -> List[Dict[str, str]]:
336
  models = []
337
  lora_models_dir = Path("./lora_models")
@@ -400,7 +428,6 @@ def create_gradio_interface():
400
  captions.append(f"{trigger_word.strip() or 'training image'}, high quality photo" if trigger_word.strip() else f"training image {len(captions) + 1}, high quality photo")
401
  captions = captions[:len(files)]
402
 
403
- # ✅ VALORES FIXOS DEFINIDOS AQUI DENTRO
404
  job_id = trainer.start_training(
405
  model_name=model_name,
406
  image_files=image_files,
@@ -465,7 +492,6 @@ def create_gradio_interface():
465
  learning_rate = gr.Slider(1e-5, 1e-3, 1e-4, step=1e-5, label="Taxa de Aprendizado")
466
  train_button = gr.Button("🚀 Iniciar Treinamento", variant="primary")
467
  train_output = gr.Textbox(label="📊 Resultado")
468
- # ✅ APENAS COMPONENTES GRADIO SÃO PASSADOS
469
  train_button.click(start_training_wrapper, [model_dropdown, image_files, captions_text, trigger_word, r, lora_alpha, num_epochs, learning_rate], train_output)
470
 
471
  with gr.TabItem("📊 Status"):
@@ -486,6 +512,17 @@ def create_gradio_interface():
486
  return interface
487
 
488
  if __name__ == "__main__":
 
489
  os.makedirs("./lora_models", exist_ok=True)
 
 
 
490
  interface = create_gradio_interface()
491
- interface.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
 
 
 
 
 
 
 
 
35
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
  self.training_jobs = {}
37
  self.models_cache = {}
38
+ # ✅ Criar pasta para persistência de jobs
39
+ Path("./jobs").mkdir(exist_ok=True)
40
 
41
+ def _save_job_state(self, job_id: str):
42
+ """Salva o estado do job em disco."""
43
+ job_file = Path(f"./jobs/{job_id}.json")
44
+ with open(job_file, "w") as f:
45
+ json.dump(self.training_jobs[job_id], f, indent=2, default=str)
46
+
47
+ def _load_job_state(self, job_id: str) -> Optional[Dict]:
48
+ """Carrega o estado do job do disco."""
49
+ job_file = Path(f"./jobs/{job_id}.json")
50
+ if job_file.exists():
51
+ try:
52
+ with open(job_file, "r") as f:
53
+ return json.load(f)
54
+ except Exception as e:
55
+ logger.error(f"Erro ao carregar job {job_id}: {e}")
56
+ return None
57
+
58
+ def get_training_status(self, job_id: str) -> Dict[str, Any]:
59
+ """Retorna status do treinamento, carregando do disco se necessário."""
60
+ if job_id in self.training_jobs:
61
+ return self.training_jobs[job_id]
62
+ else:
63
+ loaded_job = self._load_job_state(job_id)
64
+ if loaded_job:
65
+ self.training_jobs[job_id] = loaded_job
66
+ return loaded_job
67
+ return {"error": "Job não encontrado"}
68
+
69
  def get_available_models(self) -> List[str]:
70
  """Retorna lista de modelos base disponíveis para treinamento LoRA."""
71
  return [
 
82
 
83
  logger.info(f"Carregando modelo base: {model_name}")
84
 
 
85
  model_kwargs = {
86
  "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
87
  "use_safetensors": True,
 
89
  "safety_checker": None,
90
  }
91
 
 
92
  pipeline = StableDiffusionPipeline.from_pretrained(
93
  model_name,
94
  **model_kwargs
 
101
  pipeline.enable_xformers_memory_efficient_attention()
102
  except:
103
  logger.warning("xformers não disponível")
 
104
  pipeline.unet.enable_gradient_checkpointing()
105
 
106
  self.models_cache[model_name] = pipeline
 
111
  raise e
112
 
113
  def prepare_image_dataset(self, image_files: List[str], captions: List[str], resolution: int = 512) -> List[Dict]:
 
114
  dataset = []
115
 
116
  for img_path, caption in zip(image_files, captions):
 
131
  return dataset
132
 
133
  def resize_image(self, image: Image.Image, target_size: int) -> Image.Image:
 
134
  width, height = image.size
135
 
136
  if width > height:
 
162
  learning_rate: float = 1e-4,
163
  batch_size: int = 1,
164
  resolution: int = 512) -> None:
 
 
165
  try:
166
+ # Inicializar job se não existir
167
+ if job_id not in self.training_jobs:
168
+ self.training_jobs[job_id] = {
169
+ "id": job_id,
170
+ "status": "queued",
171
+ "progress": 0,
172
+ "created_at": datetime.now().isoformat(),
173
+ "model_name": model_name,
174
+ "num_images": len(dataset),
175
+ "logs": [],
176
+ "error": None,
177
+ "model_path": None,
178
+ "completed_at": None
179
+ }
180
+
181
  # Atualizar status
182
  self.training_jobs[job_id]["status"] = "loading_model"
183
  self.training_jobs[job_id]["progress"] = 5
184
  self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - Carregando modelo base: {model_name}")
185
+ self._save_job_state(job_id) # ✅ SALVAR ESTADO
186
+
187
  # Carregar modelo base
188
  pipeline = self.load_base_model(model_name)
189
  unet = pipeline.unet
 
192
  tokenizer = pipeline.tokenizer
193
  scheduler = pipeline.scheduler
194
 
 
195
  unet.requires_grad_(False)
196
  text_encoder.requires_grad_(False)
197
  vae.requires_grad_(False)
198
 
 
199
  lora_config = LoraConfig(
200
  r=r,
201
  lora_alpha=lora_alpha,
 
204
  bias="none"
205
  )
206
 
 
207
  unet.add_adapter(lora_config, adapter_name="default")
208
  unet.set_adapter("default")
209
  unet.train()
210
  unet.to(self.device)
211
 
 
212
  optimizer = torch.optim.AdamW(unet.parameters(), lr=learning_rate)
213
 
 
214
  self.training_jobs[job_id]["status"] = "preparing_data"
215
  self.training_jobs[job_id]["progress"] = 20
216
+ self._save_job_state(job_id) # ✅ SALVAR ESTADO
217
 
218
  def preprocess_image(image):
219
  image = np.array(image).astype(np.float32) / 255.0
 
221
  image = torch.from_numpy(image).unsqueeze(0)
222
  return image
223
 
 
224
  total_steps = num_epochs * len(dataset)
225
  current_step = 0
226
 
227
  self.training_jobs[job_id]["status"] = "training"
228
  self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - Iniciando treinamento real...")
229
+ self._save_job_state(job_id) # ✅ SALVAR ESTADO
230
 
231
  for epoch in range(num_epochs):
232
  for item in dataset:
233
  current_step += 1
234
 
 
235
  image = item["image"]
236
  caption = item["caption"]
237
  image_tensor = preprocess_image(image).to(self.device)
238
  if torch.cuda.is_available():
239
  image_tensor = image_tensor.half()
240
 
 
241
  with torch.no_grad():
242
  latents = vae.encode(image_tensor * 2 - 1).latent_dist.sample() * 0.18215
243
 
 
244
  inputs = tokenizer(caption, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
245
  input_ids = inputs.input_ids.to(self.device)
246
 
 
247
  timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (1,), device=self.device).long()
248
  noise = torch.randn_like(latents)
249
  noisy_latents = scheduler.add_noise(latents, noise, timesteps)
250
 
 
251
  encoder_hidden_states = text_encoder(input_ids)[0]
252
  noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
253
 
 
254
  loss = torch.nn.functional.mse_loss(noise_pred, noise)
255
  optimizer.zero_grad()
256
  loss.backward()
257
  optimizer.step()
258
 
 
259
  if torch.cuda.is_available():
260
  torch.cuda.empty_cache()
261
 
 
262
  progress = 30 + int((current_step / total_steps) * 60)
263
  self.training_jobs[job_id]["progress"] = min(progress, 90)
264
 
265
  if current_step % max(1, len(dataset)//2) == 0:
266
  log_msg = f"Época {epoch+1}, Step {current_step} - Loss: {loss.item():.4f}"
267
  self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - {log_msg}")
268
+ self._save_job_state(job_id) # ✅ SALVAR ESTADO (opcional aqui para logs)
269
 
 
270
  self.training_jobs[job_id]["status"] = "saving"
271
  self.training_jobs[job_id]["progress"] = 95
272
+ self._save_job_state(job_id) # ✅ SALVAR ESTADO
273
 
274
  output_dir = f"./lora_models/{job_id}"
275
  os.makedirs(output_dir, exist_ok=True)
276
 
 
277
  unet.save_pretrained(
278
  output_dir,
279
  safe_serialization=True,
280
  selected_adapters=["default"]
281
  )
282
 
 
283
  lora_config_dict = {
284
  "r": r,
285
  "lora_alpha": lora_alpha,
 
299
  with open(f"{output_dir}/adapter_config.json", "w") as f:
300
  json.dump(lora_config_dict, f, indent=2)
301
 
 
302
  readme_content = f"""# LoRA Model - {job_id}
303
  Treinado com sucesso!
304
  Modelo Base: {model_name}
 
307
  with open(f"{output_dir}/README.md", "w") as f:
308
  f.write(readme_content)
309
 
 
310
  self.training_jobs[job_id]["status"] = "completed"
311
  self.training_jobs[job_id]["progress"] = 100
312
  self.training_jobs[job_id]["model_path"] = output_dir
313
  self.training_jobs[job_id]["completed_at"] = datetime.now().isoformat()
314
  self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - ✅ Treinamento concluído! LoRA salvo em {output_dir}")
315
+ self._save_job_state(job_id) # ✅ SALVAR ESTADO FINAL
316
 
317
  logger.info(f"Treinamento LoRA concluído para job {job_id}")
318
 
319
  except Exception as e:
320
  error_msg = f"Erro no treinamento: {str(e)}"
321
  logger.error(error_msg)
322
+ if job_id in self.training_jobs:
323
+ self.training_jobs[job_id]["status"] = "error"
324
+ self.training_jobs[job_id]["error"] = error_msg
325
+ self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - ❌ {error_msg}")
326
+ self._save_job_state(job_id) # ✅ SALVAR ESTADO DE ERRO
327
 
328
  def start_training(self,
329
  model_name: str,
 
331
  captions: List[str],
332
  **kwargs) -> str:
333
  job_id = str(uuid.uuid4())
334
+
335
  dataset = self.prepare_image_dataset(image_files, captions, kwargs.get('resolution', 512))
336
 
337
  self.training_jobs[job_id] = {
 
347
  "completed_at": None
348
  }
349
 
350
+ # ✅ Salvar estado inicial
351
+ self._save_job_state(job_id)
352
+
353
  thread = threading.Thread(
354
  target=self.real_lora_training,
355
  args=(job_id, model_name, dataset),
 
360
 
361
  return job_id
362
 
 
 
 
363
  def list_trained_models(self) -> List[Dict[str, str]]:
364
  models = []
365
  lora_models_dir = Path("./lora_models")
 
428
  captions.append(f"{trigger_word.strip() or 'training image'}, high quality photo" if trigger_word.strip() else f"training image {len(captions) + 1}, high quality photo")
429
  captions = captions[:len(files)]
430
 
 
431
  job_id = trainer.start_training(
432
  model_name=model_name,
433
  image_files=image_files,
 
492
  learning_rate = gr.Slider(1e-5, 1e-3, 1e-4, step=1e-5, label="Taxa de Aprendizado")
493
  train_button = gr.Button("🚀 Iniciar Treinamento", variant="primary")
494
  train_output = gr.Textbox(label="📊 Resultado")
 
495
  train_button.click(start_training_wrapper, [model_dropdown, image_files, captions_text, trigger_word, r, lora_alpha, num_epochs, learning_rate], train_output)
496
 
497
  with gr.TabItem("📊 Status"):
 
512
  return interface
513
 
514
  if __name__ == "__main__":
515
+ # ✅ Criar diretórios necessários
516
  os.makedirs("./lora_models", exist_ok=True)
517
+ os.makedirs("./jobs", exist_ok=True) # Pasta para persistência de jobs
518
+
519
+ # Configurar interface
520
  interface = create_gradio_interface()
521
+
522
+ # Lançar aplicação
523
+ interface.launch(
524
+ server_name="0.0.0.0",
525
+ server_port=7860,
526
+ show_error=True,
527
+ quiet=False
528
+ )