Allex21 commited on
Commit
a82fdb3
·
verified ·
1 Parent(s): 10a7187

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +133 -556
app.py CHANGED
@@ -35,33 +35,31 @@ class LoRAImageTrainer:
35
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
  self.training_jobs = {}
37
  self.models_cache = {}
38
-
 
 
 
39
  def get_available_models(self) -> List[str]:
40
- """Retorna lista de modelos base disponíveis para treinamento LoRA."""
41
  return [
42
  "runwayml/stable-diffusion-v1-5",
43
  "stabilityai/stable-diffusion-2-1",
44
  "CompVis/stable-diffusion-v1-4"
45
- # XL removido por ser pesado demais para Spaces gratuitos
46
  ]
47
 
48
  def load_base_model(self, model_name: str):
49
- """Carrega modelo base de difusão com otimizações para baixo uso de GPU."""
50
  try:
51
  if model_name in self.models_cache:
52
  return self.models_cache[model_name]
53
 
54
  logger.info(f"Carregando modelo base: {model_name}")
55
 
56
- # Configurações para otimização de memória
57
  model_kwargs = {
58
  "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
59
  "use_safetensors": True,
60
  "variant": "fp16" if torch.cuda.is_available() else None,
61
- "safety_checker": None, # Desativa verificador de segurança para economizar memória
62
  }
63
 
64
- # Carregar pipeline completo
65
  pipeline = StableDiffusionPipeline.from_pretrained(
66
  model_name,
67
  **model_kwargs
@@ -69,17 +67,14 @@ class LoRAImageTrainer:
69
 
70
  if torch.cuda.is_available():
71
  pipeline = pipeline.to(self.device)
72
- # Habilitar attention slicing para economia de memória
73
  pipeline.enable_attention_slicing()
74
- # Habilitar memory efficient attention se disponível
75
  try:
76
  pipeline.enable_xformers_memory_efficient_attention()
77
- except Exception as e:
78
- logger.warning("xformers não disponível, usando attention padrão")
 
79
 
80
- # Cache do modelo
81
  self.models_cache[model_name] = pipeline
82
-
83
  return pipeline
84
 
85
  except Exception as e:
@@ -87,15 +82,11 @@ class LoRAImageTrainer:
87
  raise e
88
 
89
  def prepare_image_dataset(self, image_files: List[str], captions: List[str], resolution: int = 512) -> List[Dict]:
90
- """Prepara dataset de imagens para treinamento."""
91
  dataset = []
92
 
93
  for img_path, caption in zip(image_files, captions):
94
  try:
95
- # Carregar e redimensionar imagem
96
  image = Image.open(img_path).convert("RGB")
97
-
98
- # Redimensionar mantendo aspect ratio
99
  image = self.resize_image(image, resolution)
100
 
101
  dataset.append({
@@ -111,10 +102,8 @@ class LoRAImageTrainer:
111
  return dataset
112
 
113
  def resize_image(self, image: Image.Image, target_size: int) -> Image.Image:
114
- """Redimensiona imagem mantendo aspect ratio e fazendo crop central se necessário."""
115
  width, height = image.size
116
 
117
- # Calcular novo tamanho mantendo aspect ratio
118
  if width > height:
119
  new_width = target_size
120
  new_height = int((height * target_size) / width)
@@ -122,16 +111,13 @@ class LoRAImageTrainer:
122
  new_height = target_size
123
  new_width = int((width * target_size) / height)
124
 
125
- # Redimensionar
126
  image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
127
 
128
- # Crop central para obter tamanho exato
129
  if new_width != target_size or new_height != target_size:
130
  left = (new_width - target_size) // 2
131
  top = (new_height - target_size) // 2
132
  right = left + target_size
133
  bottom = top + target_size
134
-
135
  image = image.crop((left, top, right, bottom))
136
 
137
  return image
@@ -140,22 +126,35 @@ class LoRAImageTrainer:
140
  job_id: str,
141
  model_name: str,
142
  dataset: List[Dict],
143
- r: int = 16,
144
- lora_alpha: int = 32,
145
- lora_dropout: float = 0.1,
146
- num_epochs: int = 10,
147
  learning_rate: float = 1e-4,
148
  batch_size: int = 1,
149
  resolution: int = 512) -> None:
150
- """TREINAMENTO REAL DE LoRA PARA IMAGENS - CORRIGIDO PARA DIFFUSERS + PEFT."""
151
-
152
  try:
153
- # Atualizar status
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  self.training_jobs[job_id]["status"] = "loading_model"
155
  self.training_jobs[job_id]["progress"] = 5
156
- self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - Carregando modelo base: {model_name}")
157
-
158
- # Carregar modelo base
 
159
  pipeline = self.load_base_model(model_name)
160
  unet = pipeline.unet
161
  text_encoder = pipeline.text_encoder
@@ -163,12 +162,14 @@ class LoRAImageTrainer:
163
  tokenizer = pipeline.tokenizer
164
  scheduler = pipeline.scheduler
165
 
166
- # Congelar parâmetros
167
  unet.requires_grad_(False)
168
  text_encoder.requires_grad_(False)
169
  vae.requires_grad_(False)
170
 
171
- # Criar configuração LoRA
 
 
 
172
  lora_config = LoraConfig(
173
  r=r,
174
  lora_alpha=lora_alpha,
@@ -177,101 +178,96 @@ class LoRAImageTrainer:
177
  bias="none"
178
  )
179
 
180
- # Aplicar LoRA ao UNet manualmente, sem usar get_peft_model diretamente
181
  unet.add_adapter(lora_config, adapter_name="default")
182
-
183
- # Ativar o adaptador
184
  unet.set_adapter("default")
 
 
 
 
 
 
 
185
  unet.train()
186
  unet.to(self.device)
187
 
188
- # Otimizador
189
- optimizer = torch.optim.AdamW(unet.parameters(), lr=learning_rate)
190
 
191
- # Preparar scheduler para treinamento
192
  self.training_jobs[job_id]["status"] = "preparing_data"
193
  self.training_jobs[job_id]["progress"] = 20
194
 
195
- # Normalização de imagem
196
  def preprocess_image(image):
197
  image = np.array(image).astype(np.float32) / 255.0
198
  image = image.transpose(2, 0, 1)
199
  image = torch.from_numpy(image).unsqueeze(0)
200
  return image
201
 
202
- # Loop de treinamento real
203
  total_steps = num_epochs * len(dataset)
204
  current_step = 0
205
 
206
  self.training_jobs[job_id]["status"] = "training"
207
- self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - Iniciando treinamento real...")
 
 
208
 
209
  for epoch in range(num_epochs):
210
  for item in dataset:
211
  current_step += 1
212
 
213
- # Obter imagem e legenda
214
  image = item["image"]
215
  caption = item["caption"]
216
-
217
- # Pré-processar imagem
218
  image_tensor = preprocess_image(image).to(self.device)
219
  if torch.cuda.is_available():
220
  image_tensor = image_tensor.half()
221
 
222
- # Codificar imagem para latentes
223
  with torch.no_grad():
224
  latents = vae.encode(image_tensor * 2 - 1).latent_dist.sample() * 0.18215
225
 
226
- # Tokenizar texto
227
  inputs = tokenizer(caption, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
228
  input_ids = inputs.input_ids.to(self.device)
229
 
230
- # Gerar timesteps aleatórios
231
  timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (1,), device=self.device).long()
232
-
233
- # Adicionar ruído aos latentes
234
  noise = torch.randn_like(latents)
235
  noisy_latents = scheduler.add_noise(latents, noise, timesteps)
236
 
237
- # Forward pass
238
  encoder_hidden_states = text_encoder(input_ids)[0]
239
  noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
240
 
241
- # Calcular perda
242
  loss = torch.nn.functional.mse_loss(noise_pred, noise)
243
-
244
- # Backward pass
245
  optimizer.zero_grad()
246
  loss.backward()
247
  optimizer.step()
248
 
249
- # Atualizar progresso
 
 
250
  progress = 30 + int((current_step / total_steps) * 60)
251
  self.training_jobs[job_id]["progress"] = min(progress, 90)
252
 
253
  if current_step % max(1, len(dataset)//2) == 0:
254
  log_msg = f"Época {epoch+1}, Step {current_step} - Loss: {loss.item():.4f}"
255
- self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - {log_msg}")
 
256
 
257
- # Salvar LoRA
258
  self.training_jobs[job_id]["status"] = "saving"
259
  self.training_jobs[job_id]["progress"] = 95
260
 
261
  output_dir = f"./lora_models/{job_id}"
262
  os.makedirs(output_dir, exist_ok=True)
263
 
264
- # Salvar apenas os adaptadores LoRA
265
- unet.save_pretrained(output_dir)
 
 
 
 
266
 
267
- # Criar adapter_config.json
268
  lora_config_dict = {
269
  "r": r,
270
  "lora_alpha": lora_alpha,
271
  "target_modules": ["to_k", "to_q", "to_v", "to_out.0"],
272
  "lora_dropout": lora_dropout,
273
  "bias": "none",
274
- "task_type": "CAUSAL_LM", # Mantido por compatibilidade, mas não é usado
275
  "base_model_name": model_name,
276
  "training_info": {
277
  "num_epochs": num_epochs,
@@ -284,56 +280,37 @@ class LoRAImageTrainer:
284
  with open(f"{output_dir}/adapter_config.json", "w") as f:
285
  json.dump(lora_config_dict, f, indent=2)
286
 
287
- # README
288
  readme_content = f"""# LoRA Model - {job_id}
289
-
290
- Informações do Treinamento
291
-
292
  Modelo Base: {model_name}
293
- Rank (r): {r}
294
- LoRA Alpha: {lora_alpha}
295
- Dropout: {lora_dropout}
296
- Épocas: {num_epochs}
297
- Taxa de Aprendizado: {learning_rate}
298
- Resolução: {resolution}x{resolution}
299
- Número de Imagens: {len(dataset)}
300
- Data de Treinamento: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
301
-
302
- Como Usar
303
-
304
- 1. Baixe os arquivos adapter_config.json e adapter_model.safetensors
305
- 2. Carregue em sua ferramenta de geração de imagens favorita (ComfyUI, Automatic1111, etc.)
306
- 3. Use o trigger word ou estilo aprendido durante o treinamento
307
  """
308
  with open(f"{output_dir}/README.md", "w") as f:
309
  f.write(readme_content)
310
 
311
- # Finalizar
312
  self.training_jobs[job_id]["status"] = "completed"
313
  self.training_jobs[job_id]["progress"] = 100
314
  self.training_jobs[job_id]["model_path"] = output_dir
315
  self.training_jobs[job_id]["completed_at"] = datetime.now().isoformat()
316
- self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - ✅ Treinamento REAL concluído! LoRA salvo em {output_dir}")
317
-
318
- logger.info(f"Treinamento LoRA REAL concluído para job {job_id}")
319
 
320
  except Exception as e:
321
- error_msg = f"Erro REAL no treinamento: {str(e)}"
322
  logger.error(error_msg)
323
- self.training_jobs[job_id]["status"] = "error"
324
- self.training_jobs[job_id]["error"] = error_msg
325
- self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - ❌ {error_msg}")
 
326
 
327
  def start_training(self,
328
  model_name: str,
329
  image_files: List[str],
330
  captions: List[str],
331
  **kwargs) -> str:
332
- """Inicia treinamento LoRA assíncrono."""
333
-
334
  job_id = str(uuid.uuid4())
335
 
336
- # Preparar dataset
337
  dataset = self.prepare_image_dataset(image_files, captions, kwargs.get('resolution', 512))
338
 
339
  self.training_jobs[job_id] = {
@@ -349,7 +326,6 @@ Como Usar
349
  "completed_at": None
350
  }
351
 
352
- # Iniciar treinamento em thread separada
353
  thread = threading.Thread(
354
  target=self.real_lora_training,
355
  args=(job_id, model_name, dataset),
@@ -361,11 +337,9 @@ Como Usar
361
  return job_id
362
 
363
  def get_training_status(self, job_id: str) -> Dict[str, Any]:
364
- """Retorna status do treinamento."""
365
  return self.training_jobs.get(job_id, {"error": "Job não encontrado"})
366
 
367
  def list_trained_models(self) -> List[Dict[str, str]]:
368
- """Lista modelos LoRA treinados."""
369
  models = []
370
  lora_models_dir = Path("./lora_models")
371
 
@@ -377,7 +351,6 @@ Como Usar
377
  try:
378
  with open(config_file, 'r') as f:
379
  config = json.load(f)
380
-
381
  models.append({
382
  "id": model_dir.name,
383
  "path": str(model_dir),
@@ -385,7 +358,7 @@ Como Usar
385
  "r": config.get("r", "Unknown"),
386
  "created": datetime.fromtimestamp(model_dir.stat().st_mtime).isoformat()
387
  })
388
- except Exception as e:
389
  models.append({
390
  "id": model_dir.name,
391
  "path": str(model_dir),
@@ -393,531 +366,135 @@ Como Usar
393
  "r": "Unknown",
394
  "created": datetime.fromtimestamp(model_dir.stat().st_mtime).isoformat()
395
  })
396
-
397
  return models
398
 
399
  def create_download_zip(self, model_path: str) -> str:
400
- """Cria um arquivo ZIP com os arquivos do modelo LoRA para download."""
401
  zip_path = f"{model_path}.zip"
402
-
403
  with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
404
  model_dir = Path(model_path)
405
  for file_path in model_dir.rglob('*'):
406
  if file_path.is_file():
407
  arcname = file_path.relative_to(model_dir)
408
  zipf.write(file_path, arcname)
409
-
410
  return zip_path
411
 
412
 
413
- # Instância global do trainer
414
  trainer = LoRAImageTrainer()
415
 
416
  def create_gradio_interface():
417
- """Cria interface Gradio para a ferramenta LoRA de geração de imagens."""
418
-
419
- # CSS personalizado para responsividade móvel
420
  custom_css = """
421
- /* Mobile-first responsive design */
422
  @media (max-width: 768px) {
423
- .gradio-container {
424
- padding: 8px !important;
425
- margin: 0 !important;
426
- }
427
-
428
- .tab-nav {
429
- flex-wrap: wrap !important;
430
- gap: 4px !important;
431
- }
432
-
433
- .tab-nav button {
434
- font-size: 14px !important;
435
- padding: 8px 12px !important;
436
- min-width: auto !important;
437
- flex: 1 1 auto !important;
438
- }
439
-
440
- .form-container {
441
- padding: 12px !important;
442
- }
443
-
444
- .btn {
445
- width: 100% !important;
446
- padding: 12px !important;
447
- font-size: 16px !important;
448
- margin-bottom: 8px !important;
449
- min-height: 44px !important;
450
- }
451
-
452
- .textbox textarea {
453
- font-size: 16px !important;
454
- min-height: 120px !important;
455
- }
456
-
457
- .dropdown select {
458
- font-size: 16px !important;
459
- padding: 12px !important;
460
- }
461
-
462
- .output-text {
463
- font-size: 14px !important;
464
- line-height: 1.5 !important;
465
- }
466
-
467
- .column {
468
- margin-bottom: 16px !important;
469
- }
470
-
471
- .file-upload {
472
- min-height: 100px !important;
473
- }
474
  }
475
-
476
- /* Enhanced visual styles */
477
  .lora-header {
478
  background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
479
- color: white;
480
- padding: 20px;
481
- border-radius: 12px;
482
- margin-bottom: 20px;
483
- text-align: center;
484
- box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);
485
  }
486
-
487
- .status-indicator {
488
- display: inline-block;
489
- padding: 4px 8px;
490
- border-radius: 6px;
491
- font-size: 12px;
492
- font-weight: 600;
493
- text-transform: uppercase;
494
- letter-spacing: 0.5px;
495
- margin-right: 8px;
496
- }
497
-
498
- .status-queued { background-color: #fbbf24; color: #92400e; }
499
- .status-loading_model { background-color: #60a5fa; color: #1e40af; }
500
- .status-preparing_lora { background-color: #8b5cf6; color: #5b21b6; }
501
- .status-preparing_data { background-color: #06b6d4; color: #0e7490; }
502
- .status-training { background-color: #a78bfa; color: #5b21b6; }
503
- .status-saving { background-color: #f59e0b; color: #92400e; }
504
  .status-completed { background-color: #34d399; color: #065f46; }
505
  .status-error { background-color: #f87171; color: #991b1b; }
506
-
507
- /* Touch device optimizations */
508
- @media (hover: none) and (pointer: coarse) {
509
- .btn {
510
- min-height: 44px !important;
511
- min-width: 44px !important;
512
- }
513
-
514
- .tab-nav button {
515
- min-height: 44px !important;
516
- min-width: 44px !important;
517
- }
518
- }
519
  """
520
 
521
- def process_images_and_captions(files, captions_text):
522
- """Processa imagens e legendas enviadas pelo usuário."""
523
- if not files:
524
- return "❌ Erro: Nenhuma imagem foi enviada!"
525
-
526
- # Processar legendas
527
- captions = []
528
- if captions_text.strip():
529
- captions = [line.strip() for line in captions_text.split('\n') if line.strip()]
530
-
531
- # Se não há legendas suficientes, usar legendas padrão
532
- while len(captions) < len(files):
533
- captions.append(f"training image {len(captions) + 1}")
534
-
535
- # Truncar legendas se houver mais que imagens
536
- captions = captions[:len(files)]
537
-
538
- return files, captions
539
-
540
- def start_training_wrapper(model_name, files, captions_text, trigger_word, r, lora_alpha, lora_dropout,
541
- num_epochs, learning_rate, batch_size, resolution):
542
- """Wrapper para iniciar treinamento via Gradio."""
543
-
544
- if not files:
545
- return "❌ Erro: Nenhuma imagem foi enviada para treinamento!"
546
-
547
- if len(files) < 3:
548
- return "❌ Erro: Forneça pelo menos 3 imagens para treinamento!"
549
 
550
  try:
551
- # Processar imagens e legendas
552
  image_files = [f.name for f in files]
553
-
554
- # Processar legendas
555
- captions = []
556
- if captions_text.strip():
557
- captions = [line.strip() for line in captions_text.split('\n') if line.strip()]
558
-
559
- # Se não há legendas suficientes, usar trigger word + descrição padrão
560
  while len(captions) < len(files):
561
- if trigger_word.strip():
562
- captions.append(f"{trigger_word.strip()}, high quality photo")
563
- else:
564
- captions.append(f"training image {len(captions) + 1}, high quality photo")
565
-
566
- # Truncar legendas se houver mais que imagens
567
  captions = captions[:len(files)]
568
 
569
  job_id = trainer.start_training(
570
- model_name=model_name,
571
- image_files=image_files,
572
  captions=captions,
573
- r=int(r),
574
- lora_alpha=int(lora_alpha),
575
- lora_dropout=float(lora_dropout),
576
- num_epochs=int(num_epochs),
577
  learning_rate=float(learning_rate),
578
- batch_size=int(batch_size),
579
- resolution=int(resolution)
580
  )
581
-
582
- return f"✅ Treinamento REAL iniciado! ID do Job: {job_id}\n\n📊 Imagens: {len(files)}\n🏷️ Trigger Word: {trigger_word or 'Nenhuma'}\n\nUse o ID acima para verificar o progresso na aba 'Status do Treinamento'."
583
-
584
  except Exception as e:
585
- return f"❌ Erro ao iniciar treinamento: {str(e)}"
586
 
587
  def check_status_wrapper(job_id):
588
- """Wrapper para verificar status via Gradio."""
589
- if not job_id.strip():
590
- return "❌ Erro: Forneça um ID de job válido!"
591
-
592
  status = trainer.get_training_status(job_id.strip())
 
593
 
594
- if "error" in status and status["error"] == "Job não encontrado":
595
- return "❌ Job não encontrado! Verifique o ID."
596
-
597
- # Criar indicador visual de status
598
- status_class = f"status-{status['status']}"
599
- status_emoji = {
600
- 'queued': '⏳',
601
- 'loading_model': '📥',
602
- 'preparing_lora': '⚙️',
603
- 'preparing_data': '📊',
604
- 'training': '🏋️',
605
- 'saving': '💾',
606
- 'completed': '✅',
607
- 'error': '❌'
608
- }.get(status['status'], '📊')
609
-
610
- # Barra de progresso visual
611
- progress = status['progress']
612
- progress_bar = f"""
613
- <div style="width: 100%; background-color: #e5e7eb; border-radius: 4px; overflow: hidden; margin: 8px 0;">
614
- <div style="width: {progress}%; height: 8px; background: linear-gradient(90deg, #3b82f6, #8b5cf6); transition: width 0.3s ease; border-radius: 4px;"></div>
615
- </div>
616
- """
617
-
618
- status_text = f"""
619
-
620
- 📊 Status do Treinamento LoRA
621
-
622
- 🆔 Job ID: {status['id']}
623
- {status_emoji} Status: <span class="{status_class}">{status['status'].upper().replace('_', ' ')}</span>
624
- ⏳ Progresso: {status['progress']}%
625
-
626
- {progress_bar}
627
-
628
- 🤖 Modelo Base: {status['model_name']}
629
- 🖼️ Imagens: {status.get('num_images', 'N/A')}
630
- 📅 Criado em: {status['created_at']}
631
-
632
- """
633
-
634
- if status['logs']:
635
- status_text += "📝 **Logs Recentes:**\n"
636
- for log in status['logs'][-5:]: # Últimos 5 logs
637
- status_text += f"• {log}\n"
638
-
639
- if status['status'] == 'completed':
640
- status_text += f"\n✅ **Treinamento Concluído!**\n📁 **Modelo salvo em:** {status['model_path']}"
641
- status_text += f"\n⏰ **Concluído em:** {status['completed_at']}"
642
- status_text += f"\n\n💡 **Próximos passos:** Vá para a aba 'Modelos Treinados' para baixar seu LoRA!"
643
- elif status['status'] == 'error':
644
- status_text += f"\n❌ **Erro:** {status['error']}"
645
 
 
 
 
 
646
  return status_text
647
 
648
  def list_models_wrapper():
649
- """Wrapper para listar modelos via Gradio."""
650
  models = trainer.list_trained_models()
651
-
652
- if not models:
653
- return "📭 Nenhum modelo LoRA treinado encontrado."
654
-
655
- models_text = "📚 **Modelos LoRA Treinados:**\n\n"
656
- for model in models:
657
- models_text += f"🆔 **ID:** {model['id']}\n"
658
- models_text += f"🤖 **Modelo Base:** {model['base_model']}\n"
659
- models_text += f"📊 **Rank (r):** {model['r']}\n"
660
- models_text += f"📁 **Caminho:** {model['path']}\n"
661
- models_text += f"📅 **Criado:** {model['created']}\n\n"
662
- models_text += "---\n\n"
663
-
664
- return models_text
665
 
666
  def download_model_wrapper(job_id):
667
- """Wrapper para preparar download do modelo."""
668
- if not job_id.strip():
669
- return None, "❌ Erro: Forneça um ID de job válido!"
670
-
671
  status = trainer.get_training_status(job_id.strip())
672
-
673
- if "error" in status and status["error"] == "Job não encontrado":
674
- return None, "❌ Job não encontrado! Verifique o ID."
675
-
676
- if status['status'] != 'completed':
677
- return None, f"❌ Treinamento ainda não foi concluído. Status atual: {status['status']}"
678
-
679
  try:
680
- model_path = status['model_path']
681
- zip_path = trainer.create_download_zip(model_path)
682
-
683
- return zip_path, f"✅ Arquivo ZIP criado com sucesso! Clique no link acima para baixar."
684
-
685
  except Exception as e:
686
- return None, f"❌ Erro ao criar arquivo de download: {str(e)}"
687
 
688
- # Interface Gradio
689
- with gr.Blocks(
690
- title="🎨 LoRA Image Trainer - Criador e Treinador de LoRA para Imagens",
691
- theme=gr.themes.Soft(),
692
- css=custom_css
693
- ) as interface:
694
-
695
- gr.HTML("""
696
- <div class="lora-header">
697
- <h1>🎨 LoRA Image Trainer</h1>
698
- <p>Criador e Treinador de LoRA para Geração de Imagens</p>
699
- <p style="font-size: 0.9em; opacity: 0.9; margin-top: 8px;">
700
- Ferramenta otimizada para baixo uso de GPU, compatível com dispositivos móveis
701
- </p>
702
- </div>
703
- """)
704
 
705
  with gr.Tabs():
706
-
707
- # Aba de Treinamento
708
- with gr.TabItem("🎯 Treinar LoRA"):
709
- gr.Markdown("### Configurar e Iniciar Treinamento LoRA para Imagens")
710
-
711
  with gr.Row():
712
- with gr.Column(scale=2):
713
- model_dropdown = gr.Dropdown(
714
- choices=trainer.get_available_models(),
715
- value="runwayml/stable-diffusion-v1-5",
716
- label="🤖 Modelo Base",
717
- )
718
-
719
- image_files = gr.File(
720
- file_count="multiple",
721
- file_types=["image"],
722
- label="🖼️ Imagens de Treinamento",
723
- )
724
-
725
- trigger_word = gr.Textbox(
726
- label="🏷️ Trigger Word (Opcional)",
727
- placeholder="ex: meuEstilo, minhaPersonagem, etc.",
728
- )
729
-
730
- captions_text = gr.Textbox(
731
- lines=8,
732
- placeholder="Digite uma legenda por linha (opcional)...\n\nExemplo:\nmeuEstilo, retrato de uma mulher\nmeuEstilo, homem sorrindo\nmeuEstilo, paisagem urbana\n\nSe deixar vazio, usará a trigger word + 'high quality photo'",
733
- label="📝 Legendas das Imagens (Opcional)",
734
- )
735
-
736
- with gr.Column(scale=1):
737
- gr.Markdown("### ⚙️ Parâmetros LoRA")
738
-
739
- r = gr.Slider(
740
- minimum=4, maximum=64, value=8, step=4, # reduzido max para 64
741
- label="r (Rank)",
742
- )
743
-
744
- lora_alpha = gr.Slider(
745
- minimum=1, maximum=64, value=16, step=1, # reduzido max para 64
746
- label="LoRA Alpha",
747
- )
748
-
749
- lora_dropout = gr.Slider(
750
- minimum=0.0, maximum=0.5, value=0.0, step=0.05, # dropout 0 para mais estabilidade
751
- label="LoRA Dropout",
752
- )
753
-
754
- gr.Markdown("### 🏋️ Parâmetros de Treinamento")
755
-
756
- num_epochs = gr.Slider(
757
- minimum=5, maximum=20, value=10, step=5, # reduzido max para 20
758
- label="Épocas",
759
- )
760
-
761
- learning_rate = gr.Slider(
762
- minimum=1e-5, maximum=5e-4, value=1e-4, step=1e-5, # reduzido max
763
- label="Taxa de Aprendizado",
764
- )
765
-
766
- batch_size = gr.Slider(
767
- minimum=1, maximum=1, value=1, step=1, # fixado em 1 para Spaces
768
- label="Batch Size",
769
- )
770
-
771
- resolution = gr.Dropdown(
772
- choices=[512], # fixado em 512 para garantir funcionamento em GPU limitada
773
- value=512,
774
- label="Resolução",
775
- )
776
 
777
- train_button = gr.Button("🚀 Iniciar Treinamento LoRA", variant="primary", size="lg")
778
- train_output = gr.Textbox(label="📊 Resultado", lines=5)
779
-
780
- train_button.click(
781
- start_training_wrapper,
782
- inputs=[model_dropdown, image_files, captions_text, trigger_word, r, lora_alpha, lora_dropout,
783
- num_epochs, learning_rate, batch_size, resolution],
784
- outputs=train_output
785
- )
786
 
787
- # Aba de Status
788
- with gr.TabItem("📊 Status do Treinamento"):
789
- gr.Markdown("### Verificar Progresso do Treinamento")
790
-
791
- job_id_input = gr.Textbox(
792
- label="🆔 ID do Job",
793
- placeholder="Cole aqui o ID do job de treinamento...",
794
- )
795
-
796
- status_button = gr.Button("🔍 Verificar Status", variant="secondary")
797
- status_output = gr.Textbox(label="📈 Status", lines=12)
798
-
799
- status_button.click(
800
- check_status_wrapper,
801
- inputs=job_id_input,
802
- outputs=status_output
803
- )
804
-
805
- gr.Markdown("💡 **Dica:** Atualize o status regularmente para acompanhar o progresso do treinamento.")
806
-
807
- # Aba de Modelos e Download
808
- with gr.TabItem("📚 Modelos e Download"):
809
- gr.Markdown("### Visualizar e Baixar Modelos LoRA Treinados")
810
-
811
- with gr.Row():
812
- with gr.Column(scale=1):
813
- list_button = gr.Button("📋 Listar Modelos", variant="secondary")
814
- models_output = gr.Textbox(label="📚 Modelos Disponíveis", lines=10)
815
-
816
- list_button.click(
817
- list_models_wrapper,
818
- outputs=models_output
819
- )
820
-
821
- with gr.Column(scale=1):
822
- gr.Markdown("#### 💾 Download de Modelo")
823
-
824
- download_job_id = gr.Textbox(
825
- label="🆔 ID do Job para Download",
826
- placeholder="Cole o ID do job concluído...",
827
- )
828
-
829
- download_button = gr.Button("📦 Preparar Download", variant="primary")
830
- download_file = gr.File(label="📁 Arquivo para Download")
831
- download_status = gr.Textbox(label="📊 Status do Download", lines=3)
832
-
833
- download_button.click(
834
- download_model_wrapper,
835
- inputs=download_job_id,
836
- outputs=[download_file, download_status]
837
- )
838
-
839
- # Aba de Informações
840
- with gr.TabItem("ℹ️ Sobre"):
841
- gr.Markdown("""
842
- ### 🎯 Sobre o LoRA Image Trainer
843
-
844
- Esta ferramenta foi desenvolvida para democratizar o acesso ao treinamento de modelos LoRA para geração de imagens,
845
- permitindo que qualquer pessoa possa criar adaptações personalizadas de modelos de difusão (como Stable Diffusion)
846
- sem a necessidade de hardware especializado.
847
-
848
- #### ✨ Características Principais:
849
-
850
- - **🔋 Otimizado para Baixa GPU**: Utiliza técnicas como mixed precision, gradient checkpointing e configurações otimizadas
851
- - **📱 Compatível com Móveis**: Interface responsiva que funciona em smartphones e tablets
852
- - **⚡ Rápido e Eficiente**: Treinamento otimizado com bibliotecas Diffusers e PEFT do Hugging Face
853
- - **🎛️ Configurável**: Controle total sobre parâmetros LoRA e de treinamento
854
- - **☁️ Pronto para Deploy**: Facilmente implantável no Hugging Face Spaces
855
- - **🎨 Focado em Imagens**: Especificamente projetado para modelos de difusão e geração de imagens
856
-
857
- #### 🛠️ Tecnologias Utilizadas:
858
-
859
- - **Hugging Face Diffusers**: Para modelos de difusão e pipeline de treinamento
860
- - **PEFT (Parameter-Efficient Fine-Tuning)**: Para treinamento eficiente de LoRA
861
- - **PyTorch**: Framework de deep learning
862
- - **Gradio**: Interface web interativa e responsiva
863
- - **LoRA (Low-Rank Adaptation)**: Técnica de fine-tuning eficiente para modelos de difusão
864
-
865
- #### 📖 Como Usar:
866
-
867
- 1. **Prepare suas imagens**: Colete 3-50 imagens de alta qualidade do estilo/conceito que deseja treinar
868
- 2. **Escolha um modelo base** na aba "Treinar LoRA" (recomendado: Stable Diffusion 1.5)
869
- 3. **Faça upload das imagens** e defina uma trigger word (palavra-chave)
870
- 4. **Configure os parâmetros** conforme necessário (valores padrão funcionam bem)
871
- 5. **Inicie o treinamento** e anote o ID do job
872
- 6. **Acompanhe o progresso** na aba "Status do Treinamento"
873
- 7. **Baixe seu LoRA** na aba "Modelos e Download" quando concluído
874
- 8. **Use em suas ferramentas favoritas** (ComfyUI, Automatic1111, etc.)
875
-
876
- #### 💡 Dicas para Melhores Resultados:
877
-
878
- - **Qualidade > Quantidade**: 10-20 imagens de alta qualidade são melhores que 50 imagens ruins
879
- - **Consistência**: Use imagens com estilo/conceito consistente
880
- - **Resolução**: Para GPUs com pouca VRAM, use resolução 512x512
881
- - **Trigger Word**: Escolha uma palavra única e fácil de lembrar
882
- - **Legendas**: Descreva o que há nas imagens para melhor controle
883
- - **Parâmetros**: Para iniciantes, use os valores padrão
884
-
885
- #### 🎮 Compatibilidade:
886
-
887
- Os LoRAs gerados são compatíveis com:
888
- - **ComfyUI**: Carregue os arquivos .safetensors
889
- - **Automatic1111**: Coloque na pasta models/Lora
890
- - **SeaArt**: Faça upload do modelo
891
- - **Outras ferramentas**: Qualquer ferramenta que suporte LoRA para Stable Diffusion
892
-
893
- ---
894
-
895
- **Desenvolvido com ❤️ para a comunidade de IA e arte digital**
896
- """)
897
-
898
- # Footer
899
- gr.Markdown("""
900
- ---
901
- <div style="text-align: center; color: #666; font-size: 0.9em;">
902
- 🎨 LoRA Image Trainer v1.0 | Otimizado para Baixa GPU | Compatível com Dispositivos Móveis
903
- </div>
904
- """)
905
 
906
  return interface
907
 
908
- # Criar e configurar interface
909
  if __name__ == "__main__":
910
- # Criar diretórios necessários
911
- os.makedirs("./lora_models", exist_ok=True)
912
-
913
- # Configurar interface
914
  interface = create_gradio_interface()
915
-
916
- # Lançar aplicação
917
  interface.launch(
918
  server_name="0.0.0.0",
919
  server_port=7860,
920
- share=False,
921
  show_error=True,
922
  quiet=False
923
  )
 
35
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
  self.training_jobs = {}
37
  self.models_cache = {}
38
+ # Criar diretórios
39
+ os.makedirs("./lora_models", exist_ok=True)
40
+ logger.info("Diretório ./lora_models criado.")
41
+
42
  def get_available_models(self) -> List[str]:
 
43
  return [
44
  "runwayml/stable-diffusion-v1-5",
45
  "stabilityai/stable-diffusion-2-1",
46
  "CompVis/stable-diffusion-v1-4"
 
47
  ]
48
 
49
  def load_base_model(self, model_name: str):
 
50
  try:
51
  if model_name in self.models_cache:
52
  return self.models_cache[model_name]
53
 
54
  logger.info(f"Carregando modelo base: {model_name}")
55
 
 
56
  model_kwargs = {
57
  "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
58
  "use_safetensors": True,
59
  "variant": "fp16" if torch.cuda.is_available() else None,
60
+ "safety_checker": None,
61
  }
62
 
 
63
  pipeline = StableDiffusionPipeline.from_pretrained(
64
  model_name,
65
  **model_kwargs
 
67
 
68
  if torch.cuda.is_available():
69
  pipeline = pipeline.to(self.device)
 
70
  pipeline.enable_attention_slicing()
 
71
  try:
72
  pipeline.enable_xformers_memory_efficient_attention()
73
+ except:
74
+ logger.warning("xformers não disponível")
75
+ pipeline.unet.enable_gradient_checkpointing()
76
 
 
77
  self.models_cache[model_name] = pipeline
 
78
  return pipeline
79
 
80
  except Exception as e:
 
82
  raise e
83
 
84
  def prepare_image_dataset(self, image_files: List[str], captions: List[str], resolution: int = 512) -> List[Dict]:
 
85
  dataset = []
86
 
87
  for img_path, caption in zip(image_files, captions):
88
  try:
 
89
  image = Image.open(img_path).convert("RGB")
 
 
90
  image = self.resize_image(image, resolution)
91
 
92
  dataset.append({
 
102
  return dataset
103
 
104
  def resize_image(self, image: Image.Image, target_size: int) -> Image.Image:
 
105
  width, height = image.size
106
 
 
107
  if width > height:
108
  new_width = target_size
109
  new_height = int((height * target_size) / width)
 
111
  new_height = target_size
112
  new_width = int((width * target_size) / height)
113
 
 
114
  image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
115
 
 
116
  if new_width != target_size or new_height != target_size:
117
  left = (new_width - target_size) // 2
118
  top = (new_height - target_size) // 2
119
  right = left + target_size
120
  bottom = top + target_size
 
121
  image = image.crop((left, top, right, bottom))
122
 
123
  return image
 
126
  job_id: str,
127
  model_name: str,
128
  dataset: List[Dict],
129
+ r: int = 8,
130
+ lora_alpha: int = 16,
131
+ lora_dropout: float = 0.0,
132
+ num_epochs: int = 5,
133
  learning_rate: float = 1e-4,
134
  batch_size: int = 1,
135
  resolution: int = 512) -> None:
 
 
136
  try:
137
+ # Inicializar job
138
+ if job_id not in self.training_jobs:
139
+ self.training_jobs[job_id] = {
140
+ "id": job_id,
141
+ "status": "queued",
142
+ "progress": 0,
143
+ "created_at": datetime.now().isoformat(),
144
+ "model_name": model_name,
145
+ "num_images": len(dataset),
146
+ "logs": [],
147
+ "error": None,
148
+ "model_path": None,
149
+ "completed_at": None
150
+ }
151
+
152
  self.training_jobs[job_id]["status"] = "loading_model"
153
  self.training_jobs[job_id]["progress"] = 5
154
+ log_msg = f"{datetime.now().strftime('%H:%M:%S')} - Carregando modelo base: {model_name}"
155
+ self.training_jobs[job_id]["logs"].append(log_msg)
156
+ logger.info(log_msg)
157
+
158
  pipeline = self.load_base_model(model_name)
159
  unet = pipeline.unet
160
  text_encoder = pipeline.text_encoder
 
162
  tokenizer = pipeline.tokenizer
163
  scheduler = pipeline.scheduler
164
 
 
165
  unet.requires_grad_(False)
166
  text_encoder.requires_grad_(False)
167
  vae.requires_grad_(False)
168
 
169
+ # Remover adaptador se existir
170
+ if hasattr(unet, "peft_config") and "default" in unet.peft_config:
171
+ unet.delete_adapter("default")
172
+
173
  lora_config = LoraConfig(
174
  r=r,
175
  lora_alpha=lora_alpha,
 
178
  bias="none"
179
  )
180
 
 
181
  unet.add_adapter(lora_config, adapter_name="default")
 
 
182
  unet.set_adapter("default")
183
+
184
+ # Ativar apenas parâmetros do LoRA
185
+ unet.requires_grad_(False)
186
+ for name, param in unet.named_parameters():
187
+ if "lora_" in name:
188
+ param.requires_grad = True
189
+
190
  unet.train()
191
  unet.to(self.device)
192
 
193
+ optimizer = torch.optim.AdamW([p for p in unet.parameters() if p.requires_grad], lr=learning_rate)
 
194
 
 
195
  self.training_jobs[job_id]["status"] = "preparing_data"
196
  self.training_jobs[job_id]["progress"] = 20
197
 
 
198
  def preprocess_image(image):
199
  image = np.array(image).astype(np.float32) / 255.0
200
  image = image.transpose(2, 0, 1)
201
  image = torch.from_numpy(image).unsqueeze(0)
202
  return image
203
 
 
204
  total_steps = num_epochs * len(dataset)
205
  current_step = 0
206
 
207
  self.training_jobs[job_id]["status"] = "training"
208
+ log_msg = f"{datetime.now().strftime('%H:%M:%S')} - Iniciando treinamento real..."
209
+ self.training_jobs[job_id]["logs"].append(log_msg)
210
+ logger.info(log_msg)
211
 
212
  for epoch in range(num_epochs):
213
  for item in dataset:
214
  current_step += 1
215
 
 
216
  image = item["image"]
217
  caption = item["caption"]
 
 
218
  image_tensor = preprocess_image(image).to(self.device)
219
  if torch.cuda.is_available():
220
  image_tensor = image_tensor.half()
221
 
 
222
  with torch.no_grad():
223
  latents = vae.encode(image_tensor * 2 - 1).latent_dist.sample() * 0.18215
224
 
 
225
  inputs = tokenizer(caption, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
226
  input_ids = inputs.input_ids.to(self.device)
227
 
 
228
  timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (1,), device=self.device).long()
 
 
229
  noise = torch.randn_like(latents)
230
  noisy_latents = scheduler.add_noise(latents, noise, timesteps)
231
 
 
232
  encoder_hidden_states = text_encoder(input_ids)[0]
233
  noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
234
 
 
235
  loss = torch.nn.functional.mse_loss(noise_pred, noise)
 
 
236
  optimizer.zero_grad()
237
  loss.backward()
238
  optimizer.step()
239
 
240
+ if torch.cuda.is_available():
241
+ torch.cuda.empty_cache()
242
+
243
  progress = 30 + int((current_step / total_steps) * 60)
244
  self.training_jobs[job_id]["progress"] = min(progress, 90)
245
 
246
  if current_step % max(1, len(dataset)//2) == 0:
247
  log_msg = f"Época {epoch+1}, Step {current_step} - Loss: {loss.item():.4f}"
248
+ self.training_jobs[job_id]["logs"].append(log_msg)
249
+ logger.info(log_msg)
250
 
 
251
  self.training_jobs[job_id]["status"] = "saving"
252
  self.training_jobs[job_id]["progress"] = 95
253
 
254
  output_dir = f"./lora_models/{job_id}"
255
  os.makedirs(output_dir, exist_ok=True)
256
 
257
+ # CORREÇÃO PRINCIPAL: SALVAR APENAS O ADAPTADOR LORA
258
+ unet.save_pretrained(
259
+ output_dir,
260
+ safe_serialization=True,
261
+ selected_adapters=["default"]
262
+ )
263
 
 
264
  lora_config_dict = {
265
  "r": r,
266
  "lora_alpha": lora_alpha,
267
  "target_modules": ["to_k", "to_q", "to_v", "to_out.0"],
268
  "lora_dropout": lora_dropout,
269
  "bias": "none",
270
+ "task_type": "CAUSAL_LM",
271
  "base_model_name": model_name,
272
  "training_info": {
273
  "num_epochs": num_epochs,
 
280
  with open(f"{output_dir}/adapter_config.json", "w") as f:
281
  json.dump(lora_config_dict, f, indent=2)
282
 
 
283
  readme_content = f"""# LoRA Model - {job_id}
284
+ Treinado com sucesso!
 
 
285
  Modelo Base: {model_name}
286
+ Data: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
 
 
 
 
 
 
 
 
 
 
 
 
 
287
  """
288
  with open(f"{output_dir}/README.md", "w") as f:
289
  f.write(readme_content)
290
 
 
291
  self.training_jobs[job_id]["status"] = "completed"
292
  self.training_jobs[job_id]["progress"] = 100
293
  self.training_jobs[job_id]["model_path"] = output_dir
294
  self.training_jobs[job_id]["completed_at"] = datetime.now().isoformat()
295
+ log_msg = f"{datetime.now().strftime('%H:%M:%S')} - ✅ Treinamento concluído! LoRA salvo em {output_dir}"
296
+ self.training_jobs[job_id]["logs"].append(log_msg)
297
+ logger.info(log_msg)
298
 
299
  except Exception as e:
300
+ error_msg = f"Erro no treinamento: {str(e)}"
301
  logger.error(error_msg)
302
+ if job_id in self.training_jobs:
303
+ self.training_jobs[job_id]["status"] = "error"
304
+ self.training_jobs[job_id]["error"] = error_msg
305
+ self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - ❌ {error_msg}")
306
 
307
  def start_training(self,
308
  model_name: str,
309
  image_files: List[str],
310
  captions: List[str],
311
  **kwargs) -> str:
 
 
312
  job_id = str(uuid.uuid4())
313
 
 
314
  dataset = self.prepare_image_dataset(image_files, captions, kwargs.get('resolution', 512))
315
 
316
  self.training_jobs[job_id] = {
 
326
  "completed_at": None
327
  }
328
 
 
329
  thread = threading.Thread(
330
  target=self.real_lora_training,
331
  args=(job_id, model_name, dataset),
 
337
  return job_id
338
 
339
  def get_training_status(self, job_id: str) -> Dict[str, Any]:
 
340
  return self.training_jobs.get(job_id, {"error": "Job não encontrado"})
341
 
342
  def list_trained_models(self) -> List[Dict[str, str]]:
 
343
  models = []
344
  lora_models_dir = Path("./lora_models")
345
 
 
351
  try:
352
  with open(config_file, 'r') as f:
353
  config = json.load(f)
 
354
  models.append({
355
  "id": model_dir.name,
356
  "path": str(model_dir),
 
358
  "r": config.get("r", "Unknown"),
359
  "created": datetime.fromtimestamp(model_dir.stat().st_mtime).isoformat()
360
  })
361
+ except:
362
  models.append({
363
  "id": model_dir.name,
364
  "path": str(model_dir),
 
366
  "r": "Unknown",
367
  "created": datetime.fromtimestamp(model_dir.stat().st_mtime).isoformat()
368
  })
 
369
  return models
370
 
371
  def create_download_zip(self, model_path: str) -> str:
 
372
  zip_path = f"{model_path}.zip"
 
373
  with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
374
  model_dir = Path(model_path)
375
  for file_path in model_dir.rglob('*'):
376
  if file_path.is_file():
377
  arcname = file_path.relative_to(model_dir)
378
  zipf.write(file_path, arcname)
 
379
  return zip_path
380
 
381
 
 
382
  trainer = LoRAImageTrainer()
383
 
384
  def create_gradio_interface():
 
 
 
385
  custom_css = """
 
386
  @media (max-width: 768px) {
387
+ .gradio-container { padding: 8px !important; }
388
+ .btn { width: 100% !important; padding: 12px !important; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
389
  }
 
 
390
  .lora-header {
391
  background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
392
+ color: white; padding: 20px; border-radius: 12px; margin-bottom: 20px;
 
 
 
 
 
393
  }
394
+ .status-indicator { padding: 4px 8px; border-radius: 6px; font-size: 12px; font-weight: 600; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
395
  .status-completed { background-color: #34d399; color: #065f46; }
396
  .status-error { background-color: #f87171; color: #991b1b; }
 
 
 
 
 
 
 
 
 
 
 
 
 
397
  """
398
 
399
+ def start_training_wrapper(model_name, files, captions_text, trigger_word, r, lora_alpha, num_epochs, learning_rate):
400
+ if not files: return "❌ Erro: Nenhuma imagem enviada!"
401
+ if len(files) < 3: return "❌ Forneça pelo menos 3 imagens!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
402
 
403
  try:
 
404
  image_files = [f.name for f in files]
405
+ captions = [line.strip() for line in captions_text.split('\n') if line.strip()] if captions_text.strip() else []
 
 
 
 
 
 
406
  while len(captions) < len(files):
407
+ captions.append(f"{trigger_word.strip() or 'training image'}, high quality photo" if trigger_word.strip() else f"training image {len(captions) + 1}, high quality photo")
 
 
 
 
 
408
  captions = captions[:len(files)]
409
 
410
  job_id = trainer.start_training(
411
+ model_name=model_name,
412
+ image_files=image_files,
413
  captions=captions,
414
+ r=int(r),
415
+ lora_alpha=int(lora_alpha),
416
+ lora_dropout=0.0, # Fixado
417
+ num_epochs=int(num_epochs),
418
  learning_rate=float(learning_rate),
419
+ batch_size=1, # Fixado
420
+ resolution=512 # Fixado
421
  )
422
+ return f"✅ Treinamento iniciado! ID: {job_id}\n📊 Imagens: {len(files)}\nUse este ID para acompanhar o progresso."
 
 
423
  except Exception as e:
424
+ return f"❌ Erro: {str(e)}"
425
 
426
  def check_status_wrapper(job_id):
427
+ if not job_id.strip(): return " Forneça um ID válido!"
 
 
 
428
  status = trainer.get_training_status(job_id.strip())
429
+ if "error" in status: return "❌ Job não encontrado!"
430
 
431
+ status_emoji = {'completed': '✅', 'error': '❌', 'training': '🏋️', 'queued': '⏳'}.get(status['status'], '📊')
432
+ progress = status['progress']
433
+ progress_bar = f'<div style="width:100%;background:#e5e7eb;border-radius:4px;overflow:hidden;margin:8px 0;"><div style="width:{progress}%;height:8px;background:linear-gradient(90deg,#3b82f6,#8b5cf6);border-radius:4px;"></div></div>'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
434
 
435
+ status_text = f"🆔 Job ID: {status['id']}\n{status_emoji} Status: {status['status'].upper()}\n⏳ Progresso: {progress}%\n{progress_bar}\n🤖 Modelo: {status['model_name']}\n🖼️ Imagens: {status.get('num_images','N/A')}\n📅 Criado: {status['created_at']}\n"
436
+ if status['logs']: status_text += "\n📝 Logs:\n" + "\n".join([f"• {log}" for log in status['logs'][-5:]])
437
+ if status['status'] == 'completed': status_text += f"\n✅ Concluído! Modelo salvo em: {status['model_path']}"
438
+ elif status['status'] == 'error': status_text += f"\n❌ Erro: {status['error']}"
439
  return status_text
440
 
441
  def list_models_wrapper():
 
442
  models = trainer.list_trained_models()
443
+ if not models: return "📭 Nenhum modelo encontrado."
444
+ return "\n\n".join([f"🆔 {m['id']}\n🤖 {m['base_model']}\n📅 {m['created']}" for m in models])
 
 
 
 
 
 
 
 
 
 
 
 
445
 
446
  def download_model_wrapper(job_id):
447
+ if not job_id.strip(): return None, "❌ ID inválido!"
 
 
 
448
  status = trainer.get_training_status(job_id.strip())
449
+ if status.get("status") != "completed": return None, "❌ Treinamento não concluído!"
 
 
 
 
 
 
450
  try:
451
+ zip_path = trainer.create_download_zip(status['model_path'])
452
+ return zip_path, "✅ ZIP criado! Clique acima para baixar."
 
 
 
453
  except Exception as e:
454
+ return None, f"❌ Erro: {str(e)}"
455
 
456
+ with gr.Blocks(title="🎨 LoRA Image Trainer", theme=gr.themes.Soft(), css=custom_css) as interface:
457
+ gr.HTML('<div class="lora-header"><h1>🎨 LoRA Image Trainer</h1><p>Treine seu LoRA para imagens</p></div>')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
458
 
459
  with gr.Tabs():
460
+ with gr.TabItem("🎯 Treinar"):
 
 
 
 
461
  with gr.Row():
462
+ with gr.Column():
463
+ model_dropdown = gr.Dropdown(trainer.get_available_models(), value="runwayml/stable-diffusion-v1-5", label="🤖 Modelo Base")
464
+ image_files = gr.File(file_count="multiple", file_types=["image"], label="🖼️ Imagens")
465
+ trigger_word = gr.Textbox(label="🏷️ Trigger Word", placeholder="ex: meuEstilo")
466
+ captions_text = gr.Textbox(lines=4, placeholder="Legenda por linha...", label="📝 Legendas (Opcional)")
467
+ with gr.Column():
468
+ r = gr.Slider(4, 32, 8, step=4, label="r (Rank)")
469
+ lora_alpha = gr.Slider(1, 32, 16, step=1, label="LoRA Alpha")
470
+ num_epochs = gr.Slider(1, 10, 5, step=1, label="Épocas")
471
+ learning_rate = gr.Slider(1e-5, 1e-3, 1e-4, step=1e-5, label="Taxa de Aprendizado")
472
+ train_button = gr.Button("🚀 Iniciar Treinamento", variant="primary")
473
+ train_output = gr.Textbox(label="📊 Resultado")
474
+ train_button.click(start_training_wrapper, [model_dropdown, image_files, captions_text, trigger_word, r, lora_alpha, num_epochs, learning_rate], train_output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
475
 
476
+ with gr.TabItem("📊 Status"):
477
+ job_id_input = gr.Textbox(label="🆔 ID do Job")
478
+ status_button = gr.Button("🔍 Verificar Status")
479
+ status_output = gr.Textbox(label="📈 Status", lines=10)
480
+ status_button.click(check_status_wrapper, job_id_input, status_output)
 
 
 
 
481
 
482
+ with gr.TabItem("📚 Download"):
483
+ download_job_id = gr.Textbox(label="🆔 ID do Job")
484
+ download_button = gr.Button("📦 Baixar Modelo")
485
+ download_file = gr.File(label="📁 Arquivo ZIP")
486
+ download_status = gr.Textbox(label="📊 Status")
487
+ download_button.click(download_model_wrapper, download_job_id, [download_file, download_status])
488
+
489
+ gr.Markdown("---\n<center>🎨 LoRA Image Trainer v1.0 | Treinamento Real de LoRA</center>")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
490
 
491
  return interface
492
 
 
493
  if __name__ == "__main__":
 
 
 
 
494
  interface = create_gradio_interface()
 
 
495
  interface.launch(
496
  server_name="0.0.0.0",
497
  server_port=7860,
 
498
  show_error=True,
499
  quiet=False
500
  )