Allex21 commited on
Commit
119d0f5
·
verified ·
1 Parent(s): 8f49e8c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +755 -726
app.py CHANGED
@@ -30,136 +30,137 @@ logger = logging.getLogger(__name__)
30
 
31
  class LoRAImageTrainer:
32
  """Classe principal para treinamento de modelos LoRA para geração de imagens otimizada para baixo uso de GPU."""
33
-
34
- def __init__(self):
35
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
- self.training_jobs = {}
37
- self.models_cache = {}
38
-
39
- def get_available_models(self) -> List[str]:
40
- """Retorna lista de modelos base disponíveis para treinamento LoRA."""
41
- return [
42
- "runwayml/stable-diffusion-v1-5",
43
- "stabilityai/stable-diffusion-2-1",
44
- "stabilityai/stable-diffusion-xl-base-1.0",
45
- "CompVis/stable-diffusion-v1-4"
46
- ]
47
-
48
- def load_base_model(self, model_name: str):
49
- """Carrega modelo base de difusão com otimizações para baixo uso de GPU."""
50
- try:
51
- if model_name in self.models_cache:
52
- return self.models_cache[model_name]
53
-
54
- logger.info(f"Carregando modelo base: {model_name}")
55
-
56
- # Configurações para otimização de memória
57
- model_kwargs = {
58
- "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
59
- "use_safetensors": True,
60
- "variant": "fp16" if torch.cuda.is_available() else None,
61
- }
62
-
63
- # Carregar pipeline completo
64
- pipeline = StableDiffusionPipeline.from_pretrained(
65
- model_name,
66
- **model_kwargs
67
- )
68
-
69
- if torch.cuda.is_available():
70
- pipeline = pipeline.to(self.device)
71
- # Habilitar attention slicing para economia de memória
72
- pipeline.enable_attention_slicing()
73
- # Habilitar memory efficient attention se disponível
74
- try:
75
- pipeline.enable_xformers_memory_efficient_attention()
76
- except:
77
- logger.warning("xformers não disponível, usando attention padrão")
78
-
79
- # Cache do modelo
80
- self.models_cache[model_name] = pipeline
81
-
82
- return pipeline
83
-
84
- except Exception as e:
85
- logger.error(f"Erro ao carregar modelo {model_name}: {str(e)}")
86
- raise e
87
-
88
- def create_lora_config(self,
89
- r: int = 16,
90
- lora_alpha: int = 32,
91
- lora_dropout: float = 0.1,
92
- target_modules: Optional[List[str]] = None) -> LoraConfig:
93
- """Cria configuração LoRA otimizada para modelos de difusão."""
94
-
95
- if target_modules is None:
96
- # Módulos padrão para UNet do Stable Diffusion
97
- target_modules = [
98
- "to_k", "to_q", "to_v", "to_out.0",
99
- "proj_in", "proj_out",
100
- "ff.net.0.proj", "ff.net.2"
101
- ]
102
-
103
- return LoraConfig(
104
- r=r,
105
- lora_alpha=lora_alpha,
106
- target_modules=target_modules,
107
- lora_dropout=lora_dropout,
108
- bias="none",
109
- task_type=TaskType.DIFFUSION,
110
- )
111
-
112
- def prepare_image_dataset(self, image_files: List[str], captions: List[str], resolution: int = 512) -> List[Dict]:
113
- """Prepara dataset de imagens para treinamento."""
114
- dataset = []
115
-
116
- for img_path, caption in zip(image_files, captions):
117
- try:
118
- # Carregar e redimensionar imagem
119
- image = Image.open(img_path).convert("RGB")
120
-
121
- # Redimensionar mantendo aspect ratio
122
- image = self.resize_image(image, resolution)
123
-
124
- dataset.append({
125
- "image": image,
126
- "caption": caption,
127
- "image_path": img_path
128
- })
129
-
130
- except Exception as e:
131
- logger.error(f"Erro ao processar imagem {img_path}: {str(e)}")
132
- continue
133
-
134
- return dataset
135
-
136
- def resize_image(self, image: Image.Image, target_size: int) -> Image.Image:
137
- """Redimensiona imagem mantendo aspect ratio e fazendo crop central se necessário."""
138
- width, height = image.size
139
-
140
- # Calcular novo tamanho mantendo aspect ratio
141
- if width > height:
142
- new_width = target_size
143
- new_height = int((height * target_size) / width)
144
- else:
145
- new_height = target_size
146
- new_width = int((width * target_size) / height)
147
-
148
- # Redimensionar
149
- image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
150
-
151
- # Crop central para obter tamanho exato
152
- if new_width != target_size or new_height != target_size:
153
- left = (new_width - target_size) // 2
154
- top = (new_height - target_size) // 2
155
- right = left + target_size
156
- bottom = top + target_size
157
-
158
- image = image.crop((left, top, right, bottom))
159
-
 
160
  return image
161
-
162
- def simulate_training(self,
163
  job_id: str,
164
  model_name: str,
165
  dataset: List[Dict],
@@ -170,70 +171,117 @@ class LoRAImageTrainer:
170
  learning_rate: float = 1e-4,
171
  batch_size: int = 1,
172
  resolution: int = 512) -> None:
173
- """Simula o processo de treinamento LoRA para imagens (versão demonstrativa)."""
174
 
175
  try:
176
  # Atualizar status
177
  self.training_jobs[job_id]["status"] = "loading_model"
178
  self.training_jobs[job_id]["progress"] = 5
179
-
180
- # Simular carregamento do modelo base
181
- time.sleep(2)
182
- self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - Modelo {model_name} carregado")
183
-
184
- # Preparar configuração LoRA
185
- self.training_jobs[job_id]["status"] = "preparing_lora"
186
- self.training_jobs[job_id]["progress"] = 15
187
- time.sleep(1)
188
-
 
 
 
 
 
 
189
  lora_config = self.create_lora_config(r, lora_alpha, lora_dropout)
190
- self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - Configuração LoRA criada (r={r}, alpha={lora_alpha})")
191
-
192
- # Preparar dataset
 
 
 
 
 
193
  self.training_jobs[job_id]["status"] = "preparing_data"
194
- self.training_jobs[job_id]["progress"] = 25
195
- time.sleep(1)
196
-
197
- self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - Dataset preparado com {len(dataset)} imagens")
198
-
199
- # Simular treinamento
200
- self.training_jobs[job_id]["status"] = "training"
201
- self.training_jobs[job_id]["progress"] = 30
202
-
 
203
  total_steps = num_epochs * len(dataset)
204
  current_step = 0
205
-
 
 
 
206
  for epoch in range(num_epochs):
207
- for batch_idx in range(len(dataset)):
208
  current_step += 1
209
 
210
- # Simular tempo de processamento
211
- time.sleep(0.5)
212
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
  # Atualizar progresso
214
  progress = 30 + int((current_step / total_steps) * 60)
215
  self.training_jobs[job_id]["progress"] = min(progress, 90)
216
-
217
- # Simular loss decrescente
218
- loss = 0.8 - (current_step / total_steps) * 0.6
219
-
220
- if current_step % 5 == 0: # Log a cada 5 steps
221
- log_message = f"Época {epoch+1}/{num_epochs}, Step {current_step}/{total_steps} - Loss: {loss:.4f}"
222
- self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - {log_message}")
223
-
224
- # Salvar modelo LoRA
225
  self.training_jobs[job_id]["status"] = "saving"
226
  self.training_jobs[job_id]["progress"] = 95
227
- time.sleep(1)
228
-
229
  output_dir = f"./lora_models/{job_id}"
230
  os.makedirs(output_dir, exist_ok=True)
231
-
232
- # Criar arquivos simulados do LoRA
 
 
 
233
  lora_config_dict = {
234
  "r": r,
235
  "lora_alpha": lora_alpha,
236
- "target_modules": ["to_k", "to_q", "to_v", "to_out.0"],
237
  "lora_dropout": lora_dropout,
238
  "bias": "none",
239
  "task_type": "DIFFUSION",
@@ -246,646 +294,627 @@ class LoRAImageTrainer:
246
  "num_images": len(dataset)
247
  }
248
  }
249
-
250
  with open(f"{output_dir}/adapter_config.json", "w") as f:
251
  json.dump(lora_config_dict, f, indent=2)
252
-
253
- # Simular arquivo de pesos LoRA
254
- with open(f"{output_dir}/adapter_model.safetensors", "w") as f:
255
- f.write("# Arquivo simulado do modelo LoRA treinado para geração de imagens")
256
-
257
- # Criar arquivo README com informações do treinamento
258
  readme_content = f"""# LoRA Model - {job_id}
259
 
260
- ## Informações do Treinamento
261
 
262
- - **Modelo Base**: {model_name}
263
- - **Rank (r)**: {r}
264
- - **LoRA Alpha**: {lora_alpha}
265
- - **Dropout**: {lora_dropout}
266
- - **Épocas**: {num_epochs}
267
- - **Taxa de Aprendizado**: {learning_rate}
268
- - **Resolução**: {resolution}x{resolution}
269
- - **Número de Imagens**: {len(dataset)}
270
- - **Data de Treinamento**: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
271
 
272
- ## Como Usar
273
 
274
- 1. Baixe os arquivos `adapter_config.json` e `adapter_model.safetensors`
275
  2. Carregue em sua ferramenta de geração de imagens favorita (ComfyUI, Automatic1111, etc.)
276
  3. Use o trigger word ou estilo aprendido durante o treinamento
277
-
278
- ## Arquivos
279
-
280
- - `adapter_config.json`: Configuração do LoRA
281
- - `adapter_model.safetensors`: Pesos do modelo LoRA
282
- - `README.md`: Este arquivo com informações do treinamento
283
  """
284
-
285
  with open(f"{output_dir}/README.md", "w") as f:
286
  f.write(readme_content)
287
-
288
  # Finalizar
289
  self.training_jobs[job_id]["status"] = "completed"
290
  self.training_jobs[job_id]["progress"] = 100
291
  self.training_jobs[job_id]["model_path"] = output_dir
292
  self.training_jobs[job_id]["completed_at"] = datetime.now().isoformat()
293
- self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - Treinamento concluído! LoRA salvo em {output_dir}")
294
-
295
- logger.info(f"Treinamento LoRA concluído para job {job_id}")
296
-
297
  except Exception as e:
298
- logger.error(f"Erro no treinamento LoRA para job {job_id}: {str(e)}")
 
299
  self.training_jobs[job_id]["status"] = "error"
300
- self.training_jobs[job_id]["error"] = str(e)
301
-
302
- def start_training(self,
 
303
  model_name: str,
304
  image_files: List[str],
305
  captions: List[str],
306
  **kwargs) -> str:
307
  """Inicia treinamento LoRA assíncrono."""
308
-
309
- job_id = str(uuid.uuid4())
310
-
311
- # Preparar dataset
312
- dataset = self.prepare_image_dataset(image_files, captions, kwargs.get('resolution', 512))
313
-
314
- self.training_jobs[job_id] = {
315
- "id": job_id,
316
- "status": "queued",
317
- "progress": 0,
318
- "created_at": datetime.now().isoformat(),
319
- "model_name": model_name,
320
- "num_images": len(dataset),
321
- "logs": [],
322
- "error": None,
323
- "model_path": None,
324
- "completed_at": None
325
- }
326
-
327
- # Iniciar treinamento em thread separada
328
- thread = threading.Thread(
329
- target=self.simulate_training,
330
- args=(job_id, model_name, dataset),
331
- kwargs=kwargs
332
- )
333
- thread.daemon = True
334
- thread.start()
335
-
336
  return job_id
337
-
338
  def get_training_status(self, job_id: str) -> Dict[str, Any]:
339
  """Retorna status do treinamento."""
340
  return self.training_jobs.get(job_id, {"error": "Job não encontrado"})
341
-
342
  def list_trained_models(self) -> List[Dict[str, str]]:
343
  """Lista modelos LoRA treinados."""
344
  models = []
345
  lora_models_dir = Path("./lora_models")
346
-
347
- if lora_models_dir.exists():
348
- for model_dir in lora_models_dir.iterdir():
349
- if model_dir.is_dir():
350
- config_file = model_dir / "adapter_config.json"
351
- if config_file.exists():
352
- try:
353
- with open(config_file, 'r') as f:
354
- config = json.load(f)
355
-
356
- models.append({
357
- "id": model_dir.name,
358
- "path": str(model_dir),
359
- "base_model": config.get("base_model_name", "Unknown"),
360
- "r": config.get("r", "Unknown"),
361
- "created": datetime.fromtimestamp(model_dir.stat().st_mtime).isoformat()
362
- })
363
- except:
364
- models.append({
365
- "id": model_dir.name,
366
- "path": str(model_dir),
367
- "base_model": "Unknown",
368
- "r": "Unknown",
369
- "created": datetime.fromtimestamp(model_dir.stat().st_mtime).isoformat()
370
- })
371
-
372
  return models
373
-
374
  def create_download_zip(self, model_path: str) -> str:
375
  """Cria um arquivo ZIP com os arquivos do modelo LoRA para download."""
376
  zip_path = f"{model_path}.zip"
377
-
378
- with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
379
- model_dir = Path(model_path)
380
- for file_path in model_dir.rglob('*'):
381
- if file_path.is_file():
382
- arcname = file_path.relative_to(model_dir)
383
- zipf.write(file_path, arcname)
384
-
385
  return zip_path
386
 
 
387
  # Instância global do trainer
388
  trainer = LoRAImageTrainer()
389
 
390
  def create_gradio_interface():
391
  """Cria interface Gradio para a ferramenta LoRA de geração de imagens."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
392
 
393
- # CSS personalizado para responsividade móvel
394
- custom_css = """
395
- /* Mobile-first responsive design */
396
- @media (max-width: 768px) {
397
- .gradio-container {
398
- padding: 8px !important;
399
- margin: 0 !important;
400
- }
401
-
402
- .tab-nav {
403
- flex-wrap: wrap !important;
404
- gap: 4px !important;
405
- }
406
-
407
- .tab-nav button {
408
- font-size: 14px !important;
409
- padding: 8px 12px !important;
410
- min-width: auto !important;
411
- flex: 1 1 auto !important;
412
- }
413
-
414
- .form-container {
415
- padding: 12px !important;
416
- }
417
-
418
- .btn {
419
- width: 100% !important;
420
- padding: 12px !important;
421
- font-size: 16px !important;
422
- margin-bottom: 8px !important;
423
- min-height: 44px !important;
424
- }
425
-
426
- .textbox textarea {
427
- font-size: 16px !important;
428
- min-height: 120px !important;
429
- }
430
-
431
- .dropdown select {
432
- font-size: 16px !important;
433
- padding: 12px !important;
434
- }
435
-
436
- .output-text {
437
- font-size: 14px !important;
438
- line-height: 1.5 !important;
439
- }
440
-
441
- .column {
442
- margin-bottom: 16px !important;
443
- }
444
-
445
- .file-upload {
446
- min-height: 100px !important;
447
- }
448
- }
449
-
450
- /* Enhanced visual styles */
451
- .lora-header {
452
- background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
453
- color: white;
454
- padding: 20px;
455
- border-radius: 12px;
456
- margin-bottom: 20px;
457
- text-align: center;
458
- box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);
459
- }
460
-
461
- .status-indicator {
462
- display: inline-block;
463
- padding: 4px 8px;
464
- border-radius: 6px;
465
- font-size: 12px;
466
- font-weight: 600;
467
- text-transform: uppercase;
468
- letter-spacing: 0.5px;
469
- margin-right: 8px;
470
- }
471
-
472
- .status-queued { background-color: #fbbf24; color: #92400e; }
473
- .status-loading_model { background-color: #60a5fa; color: #1e40af; }
474
- .status-preparing_lora { background-color: #8b5cf6; color: #5b21b6; }
475
- .status-preparing_data { background-color: #06b6d4; color: #0e7490; }
476
- .status-training { background-color: #a78bfa; color: #5b21b6; }
477
- .status-saving { background-color: #f59e0b; color: #92400e; }
478
- .status-completed { background-color: #34d399; color: #065f46; }
479
- .status-error { background-color: #f87171; color: #991b1b; }
480
-
481
- /* Touch device optimizations */
482
- @media (hover: none) and (pointer: coarse) {
483
- .btn {
484
- min-height: 44px !important;
485
- min-width: 44px !important;
486
- }
487
-
488
- .tab-nav button {
489
- min-height: 44px !important;
490
- min-width: 44px !important;
491
- }
492
- }
493
- """
494
-
495
- def process_images_and_captions(files, captions_text):
496
- """Processa imagens e legendas enviadas pelo usuário."""
497
- if not files:
498
- return "❌ Erro: Nenhuma imagem foi enviada!"
499
 
500
- # Processar legendas
501
- captions = []
502
- if captions_text.strip():
503
- captions = [line.strip() for line in captions_text.split('\n') if line.strip()]
504
 
505
- # Se não há legendas suficientes, usar legendas padrão
506
- while len(captions) < len(files):
507
- captions.append(f"training image {len(captions) + 1}")
508
 
509
- # Truncar legendas se houver mais que imagens
510
- captions = captions[:len(files)]
511
 
512
  return files, captions
513
 
514
- def start_training_wrapper(model_name, files, captions_text, trigger_word, r, lora_alpha, lora_dropout,
515
- num_epochs, learning_rate, batch_size, resolution):
516
- """Wrapper para iniciar treinamento via Gradio."""
517
-
518
- if not files:
519
- return "❌ Erro: Nenhuma imagem foi enviada para treinamento!"
520
-
521
- if len(files) < 3:
522
- return "❌ Erro: Forneça pelo menos 3 imagens para treinamento!"
523
-
524
- try:
525
- # Processar imagens e legendas
526
- image_files = [f.name for f in files]
527
-
528
- # Processar legendas
529
- captions = []
530
- if captions_text.strip():
531
- captions = [line.strip() for line in captions_text.split('\n') if line.strip()]
532
-
533
- # Se não há legendas suficientes, usar trigger word + descrição padrão
534
- while len(captions) < len(files):
535
- if trigger_word.strip():
536
- captions.append(f"{trigger_word.strip()}, high quality photo")
537
- else:
538
- captions.append(f"training image {len(captions) + 1}, high quality photo")
539
-
540
- # Truncar legendas se houver mais que imagens
541
- captions = captions[:len(files)]
542
-
543
- job_id = trainer.start_training(
544
- model_name=model_name,
545
- image_files=image_files,
546
- captions=captions,
547
- r=int(r),
548
- lora_alpha=int(lora_alpha),
549
- lora_dropout=float(lora_dropout),
550
- num_epochs=int(num_epochs),
551
- learning_rate=float(learning_rate),
552
- batch_size=int(batch_size),
553
- resolution=int(resolution)
554
- )
555
-
556
- return f"✅ Treinamento iniciado! ID do Job: {job_id}\n\n📊 Imagens: {len(files)}\n🏷️ Trigger Word: {trigger_word or 'Nenhuma'}\n\nUse o ID acima para verificar o progresso na aba 'Status do Treinamento'."
557
-
558
- except Exception as e:
559
- return f"❌ Erro ao iniciar treinamento: {str(e)}"
560
 
561
- def check_status_wrapper(job_id):
562
- """Wrapper para verificar status via Gradio."""
563
- if not job_id.strip():
564
- return "❌ Erro: Forneça um ID de job válido!"
565
-
566
- status = trainer.get_training_status(job_id.strip())
567
-
568
- if "error" in status and status["error"] == "Job não encontrado":
569
- return "❌ Job não encontrado! Verifique o ID."
570
-
571
- # Criar indicador visual de status
572
- status_class = f"status-{status['status']}"
573
- status_emoji = {
574
- 'queued': '⏳',
575
- 'loading_model': '📥',
576
- 'preparing_lora': '⚙️',
577
- 'preparing_data': '📊',
578
- 'training': '🏋️',
579
- 'saving': '💾',
580
- 'completed': '✅',
581
- 'error': '❌'
582
- }.get(status['status'], '📊')
583
-
584
- # Barra de progresso visual
585
- progress = status['progress']
586
- progress_bar = f"""
587
- <div style="width: 100%; background-color: #e5e7eb; border-radius: 4px; overflow: hidden; margin: 8px 0;">
588
- <div style="width: {progress}%; height: 8px; background: linear-gradient(90deg, #3b82f6, #8b5cf6); transition: width 0.3s ease; border-radius: 4px;"></div>
589
- </div>
590
- """
591
 
592
  status_text = f"""
593
- 📊 **Status do Treinamento LoRA**
594
 
595
- 🆔 **Job ID:** {status['id']}
596
- {status_emoji} **Status:** <span class="{status_class}">{status['status'].upper().replace('_', ' ')}</span>
597
- **Progresso:** {status['progress']}%
 
 
598
 
599
  {progress_bar}
600
 
601
- 🤖 **Modelo Base:** {status['model_name']}
602
- 🖼️ **Imagens:** {status.get('num_images', 'N/A')}
603
- 📅 **Criado em:** {status['created_at']}
604
 
605
  """
606
-
607
- if status['logs']:
608
- status_text += "📝 **Logs Recentes:**\n"
609
- for log in status['logs'][-5:]: # Últimos 5 logs
610
- status_text += f"• {log}\n"
611
-
612
- if status['status'] == 'completed':
613
- status_text += f"\n✅ **Treinamento Concluído!**\n📁 **Modelo salvo em:** {status['model_path']}"
614
- status_text += f"\n⏰ **Concluído em:** {status['completed_at']}"
615
- status_text += f"\n\n💡 **Próximos passos:** Vá para a aba 'Modelos Treinados' para baixar seu LoRA!"
616
- elif status['status'] == 'error':
617
- status_text += f"\n❌ **Erro:** {status['error']}"
618
 
619
  return status_text
620
 
621
- def list_models_wrapper():
622
- """Wrapper para listar modelos via Gradio."""
623
- models = trainer.list_trained_models()
624
-
625
- if not models:
626
- return "📭 Nenhum modelo LoRA treinado encontrado."
627
-
628
- models_text = "📚 **Modelos LoRA Treinados:**\n\n"
629
- for model in models:
630
- models_text += f"🆔 **ID:** {model['id']}\n"
631
- models_text += f"🤖 **Modelo Base:** {model['base_model']}\n"
632
- models_text += f"📊 **Rank (r):** {model['r']}\n"
633
- models_text += f"📁 **Caminho:** {model['path']}\n"
634
- models_text += f"📅 **Criado:** {model['created']}\n\n"
635
- models_text += "---\n\n"
636
 
637
  return models_text
638
 
639
- def download_model_wrapper(job_id):
640
- """Wrapper para preparar download do modelo."""
641
- if not job_id.strip():
642
- return None, "❌ Erro: Forneça um ID de job válido!"
643
 
644
- status = trainer.get_training_status(job_id.strip())
645
 
646
- if "error" in status and status["error"] == "Job não encontrado":
647
- return None, "❌ Job não encontrado! Verifique o ID."
648
 
649
- if status['status'] != 'completed':
650
- return None, f"❌ Treinamento ainda não foi concluído. Status atual: {status['status']}"
651
 
652
- try:
653
- model_path = status['model_path']
654
- zip_path = trainer.create_download_zip(model_path)
655
 
656
- return zip_path, f"✅ Arquivo ZIP criado com sucesso! Clique no link acima para baixar."
657
 
658
- except Exception as e:
659
- return None, f"❌ Erro ao criar arquivo de download: {str(e)}"
660
 
661
- # Interface Gradio
662
- with gr.Blocks(
663
- title="🎨 LoRA Image Trainer - Criador e Treinador de LoRA para Imagens",
664
- theme=gr.themes.Soft(),
665
- css=custom_css
666
- ) as interface:
667
-
668
- gr.HTML("""
669
- <div class="lora-header">
670
- <h1>🎨 LoRA Image Trainer</h1>
671
- <p>Criador e Treinador de LoRA para Geração de Imagens</p>
672
- <p style="font-size: 0.9em; opacity: 0.9; margin-top: 8px;">
673
- Ferramenta otimizada para baixo uso de GPU, compatível com dispositivos móveis
674
- </p>
675
- </div>
676
- """)
677
-
678
- with gr.Tabs():
679
-
680
- # Aba de Treinamento
681
- with gr.TabItem("🎯 Treinar LoRA"):
682
- gr.Markdown("### Configurar e Iniciar Treinamento LoRA para Imagens")
683
 
684
- with gr.Row():
685
- with gr.Column(scale=2):
686
- model_dropdown = gr.Dropdown(
687
- choices=trainer.get_available_models(),
688
- value="runwayml/stable-diffusion-v1-5",
689
- label="🤖 Modelo Base",
690
-
691
- )
692
 
693
- image_files = gr.File(
694
- file_count="multiple",
695
- file_types=["image"],
696
- label="🖼️ Imagens de Treinamento",
697
-
698
- )
699
 
700
- trigger_word = gr.Textbox(
701
- label="🏷️ Trigger Word (Opcional)",
702
- placeholder="ex: meuEstilo, minhaPersonagem, etc.",
703
-
704
- )
705
 
706
- captions_text = gr.Textbox(
707
- lines=8,
708
- placeholder="Digite uma legenda por linha (opcional)...\n\nExemplo:\nmeuEstilo, retrato de uma mulher\nmeuEstilo, homem sorrindo\nmeuEstilo, paisagem urbana\n\nSe deixar vazio, usará a trigger word + 'high quality photo'",
709
- label="📝 Legendas das Imagens (Opcional)",
710
-
711
- )
712
 
713
- with gr.Column(scale=1):
714
- gr.Markdown("### ⚙️ Parâmetros LoRA")
715
 
716
- r = gr.Slider(
717
- minimum=4, maximum=128, value=16, step=4,
718
- label="r (Rank)",
719
-
720
- )
721
 
722
- lora_alpha = gr.Slider(
723
- minimum=1, maximum=128, value=32, step=1,
724
- label="LoRA Alpha",
725
-
726
- )
727
 
728
- lora_dropout = gr.Slider(
729
- minimum=0.0, maximum=0.5, value=0.1, step=0.05,
730
- label="LoRA Dropout",
731
-
732
- )
733
 
734
- gr.Markdown("### 🏋️ Parâmetros de Treinamento")
735
 
736
- num_epochs = gr.Slider(
737
- minimum=5, maximum=50, value=10, step=5,
738
- label="Épocas",
739
-
740
- )
741
 
742
- learning_rate = gr.Slider(
743
- minimum=1e-5, maximum=1e-3, value=1e-4, step=1e-5,
744
- label="Taxa de Aprendizado",
745
-
746
- )
747
 
748
- batch_size = gr.Slider(
749
- minimum=1, maximum=8, value=1, step=1,
750
- label="Batch Size",
751
-
752
- )
753
 
754
- resolution = gr.Dropdown(
755
- choices=[512, 768, 1024],
756
- value=512,
757
- label="Resolução",
758
-
759
- )
760
-
761
- train_button = gr.Button("🚀 Iniciar Treinamento LoRA", variant="primary", size="lg")
762
- train_output = gr.Textbox(label="📊 Resultado", lines=5)
763
 
764
- train_button.click(
765
- start_training_wrapper,
766
- inputs=[model_dropdown, image_files, captions_text, trigger_word, r, lora_alpha, lora_dropout,
767
- num_epochs, learning_rate, batch_size, resolution],
768
- outputs=train_output
769
- )
770
-
771
- # Aba de Status
772
- with gr.TabItem("📊 Status do Treinamento"):
773
- gr.Markdown("### Verificar Progresso do Treinamento")
774
 
775
- job_id_input = gr.Textbox(
776
- label="🆔 ID do Job",
777
- placeholder="Cole aqui o ID do job de treinamento...",
778
-
779
- )
780
 
781
- status_button = gr.Button("🔍 Verificar Status", variant="secondary")
782
- status_output = gr.Textbox(label="📈 Status", lines=12)
783
 
784
- status_button.click(
785
- check_status_wrapper,
786
- inputs=job_id_input,
787
- outputs=status_output
788
- )
789
 
790
- gr.Markdown("💡 **Dica:** Atualize o status regularmente para acompanhar o progresso do treinamento.")
791
 
792
- # Aba de Modelos e Download
793
- with gr.TabItem("📚 Modelos e Download"):
794
- gr.Markdown("### Visualizar e Baixar Modelos LoRA Treinados")
795
 
796
- with gr.Row():
797
- with gr.Column(scale=1):
798
- list_button = gr.Button("📋 Listar Modelos", variant="secondary")
799
- models_output = gr.Textbox(label="📚 Modelos Disponíveis", lines=10)
800
 
801
- list_button.click(
802
- list_models_wrapper,
803
- outputs=models_output
804
- )
805
 
806
- with gr.Column(scale=1):
807
- gr.Markdown("#### 💾 Download de Modelo")
808
 
809
- download_job_id = gr.Textbox(
810
- label="🆔 ID do Job para Download",
811
- placeholder="Cole o ID do job concluído...", )
 
812
 
813
- download_button = gr.Button("📦 Preparar Download", variant="primary")
814
- download_file = gr.File(label="📁 Arquivo para Download")
815
- download_status = gr.Textbox(label="📊 Status do Download", lines=3)
816
 
817
- download_button.click(
818
- download_model_wrapper,
819
- inputs=download_job_id,
820
- outputs=[download_file, download_status]
821
- )
822
-
823
- # Aba de Informações
824
- with gr.TabItem("ℹ️ Sobre"):
825
- gr.Markdown("""
826
- ### 🎯 Sobre o LoRA Image Trainer
827
-
828
- Esta ferramenta foi desenvolvida para democratizar o acesso ao treinamento de modelos LoRA para geração de imagens,
829
- permitindo que qualquer pessoa possa criar adaptações personalizadas de modelos de difusão (como Stable Diffusion)
830
- sem a necessidade de hardware especializado.
831
-
832
- #### ✨ Características Principais:
833
-
834
- - **🔋 Otimizado para Baixa GPU**: Utiliza técnicas como mixed precision, gradient checkpointing e configurações otimizadas
835
- - **📱 Compatível com Móveis**: Interface responsiva que funciona em smartphones e tablets
836
- - **⚡ Rápido e Eficiente**: Treinamento otimizado com bibliotecas Diffusers e PEFT do Hugging Face
837
- - **🎛️ Configurável**: Controle total sobre parâmetros LoRA e de treinamento
838
- - **☁️ Pronto para Deploy**: Facilmente implantável no Hugging Face Spaces
839
- - **🎨 Focado em Imagens**: Especificamente projetado para modelos de difusão e geração de imagens
840
-
841
- #### 🛠️ Tecnologias Utilizadas:
842
-
843
- - **Hugging Face Diffusers**: Para modelos de difusão e pipeline de treinamento
844
- - **PEFT (Parameter-Efficient Fine-Tuning)**: Para treinamento eficiente de LoRA
845
- - **PyTorch**: Framework de deep learning
846
- - **Gradio**: Interface web interativa e responsiva
847
- - **LoRA (Low-Rank Adaptation)**: Técnica de fine-tuning eficiente para modelos de difusão
848
-
849
- #### 📖 Como Usar:
850
-
851
- 1. **Prepare suas imagens**: Colete 3-50 imagens de alta qualidade do estilo/conceito que deseja treinar
852
- 2. **Escolha um modelo base** na aba "Treinar LoRA" (recomendado: Stable Diffusion 1.5)
853
- 3. **Faça upload das imagens** e defina uma trigger word (palavra-chave)
854
- 4. **Configure os parâmetros** conforme necessário (valores padrão funcionam bem)
855
- 5. **Inicie o treinamento** e anote o ID do job
856
- 6. **Acompanhe o progresso** na aba "Status do Treinamento"
857
- 7. **Baixe seu LoRA** na aba "Modelos e Download" quando concluído
858
- 8. **Use em suas ferramentas favoritas** (ComfyUI, Automatic1111, etc.)
859
-
860
- #### 💡 Dicas para Melhores Resultados:
861
-
862
- - **Qualidade > Quantidade**: 10-20 imagens de alta qualidade são melhores que 50 imagens ruins
863
- - **Consistência**: Use imagens com estilo/conceito consistente
864
- - **Resolução**: Para GPUs com pouca VRAM, use resolução 512x512
865
- - **Trigger Word**: Escolha uma palavra única e fácil de lembrar
866
- - **Legendas**: Descreva o que há nas imagens para melhor controle
867
- - **Parâmetros**: Para iniciantes, use os valores padrão
868
-
869
- #### 🎮 Compatibilidade:
870
-
871
- Os LoRAs gerados são compatíveis com:
872
- - **ComfyUI**: Carregue os arquivos .safetensors
873
- - **Automatic1111**: Coloque na pasta models/Lora
874
- - **SeaArt**: Faça upload do modelo
875
- - **Outras ferramentas**: Qualquer ferramenta que suporte LoRA para Stable Diffusion
876
-
877
- ---
878
-
879
- **Desenvolvido com ❤️ para a comunidade de IA e arte digital**
880
- """)
881
-
882
- # Footer
883
- gr.Markdown("""
884
- ---
885
- <div style="text-align: center; color: #666; font-size: 0.9em;">
886
- 🎨 LoRA Image Trainer v1.0 | Otimizado para Baixa GPU | Compatível com Dispositivos Móveis
887
- </div>
888
- """)
889
 
890
  return interface
891
 
@@ -893,15 +922,15 @@ def create_gradio_interface():
893
  if __name__ == "__main__":
894
  # Criar diretórios necessários
895
  os.makedirs("./lora_models", exist_ok=True)
896
-
897
- # Configurar interface
898
- interface = create_gradio_interface()
899
-
900
- # Lançar aplicação
901
- interface.launch(
902
- server_name="0.0.0.0",
903
- server_port=7860,
904
- share=False,
905
- show_error=True,
906
- quiet=False
907
- )
 
30
 
31
  class LoRAImageTrainer:
32
  """Classe principal para treinamento de modelos LoRA para geração de imagens otimizada para baixo uso de GPU."""
33
+
34
+ def __init__(self):
35
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
+ self.training_jobs = {}
37
+ self.models_cache = {}
38
+
39
+ def get_available_models(self) -> List[str]:
40
+ """Retorna lista de modelos base disponíveis para treinamento LoRA."""
41
+ return [
42
+ "runwayml/stable-diffusion-v1-5",
43
+ "stabilityai/stable-diffusion-2-1",
44
+ "CompVis/stable-diffusion-v1-4"
45
+ # XL removido por ser pesado demais para Spaces gratuitos
46
+ ]
47
+
48
+ def load_base_model(self, model_name: str):
49
+ """Carrega modelo base de difusão com otimizações para baixo uso de GPU."""
50
+ try:
51
+ if model_name in self.models_cache:
52
+ return self.models_cache[model_name]
53
+
54
+ logger.info(f"Carregando modelo base: {model_name}")
55
+
56
+ # Configurações para otimização de memória
57
+ model_kwargs = {
58
+ "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
59
+ "use_safetensors": True,
60
+ "variant": "fp16" if torch.cuda.is_available() else None,
61
+ "safety_checker": None, # Desativa verificador de segurança para economizar memória
62
+ }
63
+
64
+ # Carregar pipeline completo
65
+ pipeline = StableDiffusionPipeline.from_pretrained(
66
+ model_name,
67
+ **model_kwargs
68
+ )
69
+
70
+ if torch.cuda.is_available():
71
+ pipeline = pipeline.to(self.device)
72
+ # Habilitar attention slicing para economia de memória
73
+ pipeline.enable_attention_slicing()
74
+ # Habilitar memory efficient attention se disponível
75
+ try:
76
+ pipeline.enable_xformers_memory_efficient_attention()
77
+ except Exception as e:
78
+ logger.warning("xformers não disponível, usando attention padrão")
79
+
80
+ # Cache do modelo
81
+ self.models_cache[model_name] = pipeline
82
+
83
+ return pipeline
84
+
85
+ except Exception as e:
86
+ logger.error(f"Erro ao carregar modelo {model_name}: {str(e)}")
87
+ raise e
88
+
89
+ def create_lora_config(self,
90
+ r: int = 16,
91
+ lora_alpha: int = 32,
92
+ lora_dropout: float = 0.1,
93
+ target_modules: Optional[List[str]] = None) -> LoraConfig:
94
+ """Cria configuração LoRA otimizada para modelos de difusão."""
95
+
96
+ if target_modules is None:
97
+ # Módulos padrão para UNet do Stable Diffusion
98
+ target_modules = [
99
+ "to_k", "to_q", "to_v", "to_out.0",
100
+ "proj_in", "proj_out",
101
+ "ff.net.0.proj", "ff.net.2"
102
+ ]
103
+
104
+ return LoraConfig(
105
+ r=r,
106
+ lora_alpha=lora_alpha,
107
+ target_modules=target_modules,
108
+ lora_dropout=lora_dropout,
109
+ bias="none",
110
+ task_type=TaskType.DIFFUSION,
111
+ )
112
+
113
+ def prepare_image_dataset(self, image_files: List[str], captions: List[str], resolution: int = 512) -> List[Dict]:
114
+ """Prepara dataset de imagens para treinamento."""
115
+ dataset = []
116
+
117
+ for img_path, caption in zip(image_files, captions):
118
+ try:
119
+ # Carregar e redimensionar imagem
120
+ image = Image.open(img_path).convert("RGB")
121
+
122
+ # Redimensionar mantendo aspect ratio
123
+ image = self.resize_image(image, resolution)
124
+
125
+ dataset.append({
126
+ "image": image,
127
+ "caption": caption,
128
+ "image_path": img_path
129
+ })
130
+
131
+ except Exception as e:
132
+ logger.error(f"Erro ao processar imagem {img_path}: {str(e)}")
133
+ continue
134
+
135
+ return dataset
136
+
137
+ def resize_image(self, image: Image.Image, target_size: int) -> Image.Image:
138
+ """Redimensiona imagem mantendo aspect ratio e fazendo crop central se necessário."""
139
+ width, height = image.size
140
+
141
+ # Calcular novo tamanho mantendo aspect ratio
142
+ if width > height:
143
+ new_width = target_size
144
+ new_height = int((height * target_size) / width)
145
+ else:
146
+ new_height = target_size
147
+ new_width = int((width * target_size) / height)
148
+
149
+ # Redimensionar
150
+ image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
151
+
152
+ # Crop central para obter tamanho exato
153
+ if new_width != target_size or new_height != target_size:
154
+ left = (new_width - target_size) // 2
155
+ top = (new_height - target_size) // 2
156
+ right = left + target_size
157
+ bottom = top + target_size
158
+
159
+ image = image.crop((left, top, right, bottom))
160
+
161
  return image
162
+
163
+ def real_lora_training(self,
164
  job_id: str,
165
  model_name: str,
166
  dataset: List[Dict],
 
171
  learning_rate: float = 1e-4,
172
  batch_size: int = 1,
173
  resolution: int = 512) -> None:
174
+ """TREINAMENTO REAL DE LoRA PARA IMAGENS."""
175
 
176
  try:
177
  # Atualizar status
178
  self.training_jobs[job_id]["status"] = "loading_model"
179
  self.training_jobs[job_id]["progress"] = 5
180
+ self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - Carregando modelo base: {model_name}")
181
+
182
+ # Carregar modelo base
183
+ pipeline = self.load_base_model(model_name)
184
+ unet = pipeline.unet
185
+ text_encoder = pipeline.text_encoder
186
+ vae = pipeline.vae
187
+ tokenizer = pipeline.tokenizer
188
+ scheduler = pipeline.scheduler
189
+
190
+ # Congelar parâmetros
191
+ unet.requires_grad_(False)
192
+ text_encoder.requires_grad_(False)
193
+ vae.requires_grad_(False)
194
+
195
+ # Configurar LoRA no UNet
196
  lora_config = self.create_lora_config(r, lora_alpha, lora_dropout)
197
+ unet_lora = get_peft_model(unet, lora_config)
198
+ unet_lora.train()
199
+ unet_lora.to(self.device)
200
+
201
+ # Otimizador
202
+ optimizer = torch.optim.AdamW(unet_lora.parameters(), lr=learning_rate)
203
+
204
+ # Preparar scheduler para treinamento
205
  self.training_jobs[job_id]["status"] = "preparing_data"
206
+ self.training_jobs[job_id]["progress"] = 20
207
+
208
+ # Normalização de imagem
209
+ def preprocess_image(image):
210
+ image = np.array(image).astype(np.float32) / 255.0
211
+ image = image.transpose(2, 0, 1)
212
+ image = torch.from_numpy(image).unsqueeze(0)
213
+ return image
214
+
215
+ # Loop de treinamento real
216
  total_steps = num_epochs * len(dataset)
217
  current_step = 0
218
+
219
+ self.training_jobs[job_id]["status"] = "training"
220
+ self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - Iniciando treinamento real...")
221
+
222
  for epoch in range(num_epochs):
223
+ for item in dataset:
224
  current_step += 1
225
 
226
+ # Obter imagem e legenda
227
+ image = item["image"]
228
+ caption = item["caption"]
229
+
230
+ # Pré-processar imagem
231
+ image_tensor = preprocess_image(image).to(self.device)
232
+ if torch.cuda.is_available():
233
+ image_tensor = image_tensor.half()
234
+
235
+ # Codificar imagem para latentes
236
+ with torch.no_grad():
237
+ latents = vae.encode(image_tensor * 2 - 1).latent_dist.sample() * 0.18215
238
+
239
+ # Tokenizar texto
240
+ inputs = tokenizer(caption, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
241
+ input_ids = inputs.input_ids.to(self.device)
242
+
243
+ # Gerar timesteps aleatórios
244
+ timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (1,), device=self.device).long()
245
+
246
+ # Adicionar ruído aos latentes
247
+ noise = torch.randn_like(latents)
248
+ noisy_latents = scheduler.add_noise(latents, noise, timesteps)
249
+
250
+ # Forward pass
251
+ encoder_hidden_states = text_encoder(input_ids)[0]
252
+ noise_pred = unet_lora(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
253
+
254
+ # Calcular perda
255
+ loss = torch.nn.functional.mse_loss(noise_pred, noise)
256
+
257
+ # Backward pass
258
+ optimizer.zero_grad()
259
+ loss.backward()
260
+ optimizer.step()
261
+
262
  # Atualizar progresso
263
  progress = 30 + int((current_step / total_steps) * 60)
264
  self.training_jobs[job_id]["progress"] = min(progress, 90)
265
+
266
+ if current_step % max(1, len(dataset)//2) == 0:
267
+ log_msg = f"Época {epoch+1}, Step {current_step} - Loss: {loss.item():.4f}"
268
+ self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - {log_msg}")
269
+
270
+ # Salvar LoRA
 
 
 
271
  self.training_jobs[job_id]["status"] = "saving"
272
  self.training_jobs[job_id]["progress"] = 95
273
+
 
274
  output_dir = f"./lora_models/{job_id}"
275
  os.makedirs(output_dir, exist_ok=True)
276
+
277
+ # Salvar apenas os pesos LoRA do UNet
278
+ unet_lora.save_pretrained(output_dir)
279
+
280
+ # Criar adapter_config.json
281
  lora_config_dict = {
282
  "r": r,
283
  "lora_alpha": lora_alpha,
284
+ "target_modules": lora_config.target_modules,
285
  "lora_dropout": lora_dropout,
286
  "bias": "none",
287
  "task_type": "DIFFUSION",
 
294
  "num_images": len(dataset)
295
  }
296
  }
 
297
  with open(f"{output_dir}/adapter_config.json", "w") as f:
298
  json.dump(lora_config_dict, f, indent=2)
299
+
300
+ # README
 
 
 
 
301
  readme_content = f"""# LoRA Model - {job_id}
302
 
303
+ Informações do Treinamento
304
 
305
+ Modelo Base: {model_name}
306
+ Rank (r): {r}
307
+ LoRA Alpha: {lora_alpha}
308
+ Dropout: {lora_dropout}
309
+ Épocas: {num_epochs}
310
+ Taxa de Aprendizado: {learning_rate}
311
+ Resolução: {resolution}x{resolution}
312
+ Número de Imagens: {len(dataset)}
313
+ Data de Treinamento: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
314
 
315
+ Como Usar
316
 
317
+ 1. Baixe os arquivos adapter_config.json e adapter_model.safetensors
318
  2. Carregue em sua ferramenta de geração de imagens favorita (ComfyUI, Automatic1111, etc.)
319
  3. Use o trigger word ou estilo aprendido durante o treinamento
 
 
 
 
 
 
320
  """
 
321
  with open(f"{output_dir}/README.md", "w") as f:
322
  f.write(readme_content)
323
+
324
  # Finalizar
325
  self.training_jobs[job_id]["status"] = "completed"
326
  self.training_jobs[job_id]["progress"] = 100
327
  self.training_jobs[job_id]["model_path"] = output_dir
328
  self.training_jobs[job_id]["completed_at"] = datetime.now().isoformat()
329
+ self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - Treinamento REAL concluído! LoRA salvo em {output_dir}")
330
+
331
+ logger.info(f"Treinamento LoRA REAL concluído para job {job_id}")
332
+
333
  except Exception as e:
334
+ error_msg = f"Erro REAL no treinamento: {str(e)}"
335
+ logger.error(error_msg)
336
  self.training_jobs[job_id]["status"] = "error"
337
+ self.training_jobs[job_id]["error"] = error_msg
338
+ self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - ❌ {error_msg}")
339
+
340
+ def start_training(self,
341
  model_name: str,
342
  image_files: List[str],
343
  captions: List[str],
344
  **kwargs) -> str:
345
  """Inicia treinamento LoRA assíncrono."""
346
+
347
+ job_id = str(uuid.uuid4())
348
+
349
+ # Preparar dataset
350
+ dataset = self.prepare_image_dataset(image_files, captions, kwargs.get('resolution', 512))
351
+
352
+ self.training_jobs[job_id] = {
353
+ "id": job_id,
354
+ "status": "queued",
355
+ "progress": 0,
356
+ "created_at": datetime.now().isoformat(),
357
+ "model_name": model_name,
358
+ "num_images": len(dataset),
359
+ "logs": [],
360
+ "error": None,
361
+ "model_path": None,
362
+ "completed_at": None
363
+ }
364
+
365
+ # Iniciar treinamento em thread separada
366
+ thread = threading.Thread(
367
+ target=self.real_lora_training,
368
+ args=(job_id, model_name, dataset),
369
+ kwargs=kwargs
370
+ )
371
+ thread.daemon = True
372
+ thread.start()
373
+
374
  return job_id
375
+
376
  def get_training_status(self, job_id: str) -> Dict[str, Any]:
377
  """Retorna status do treinamento."""
378
  return self.training_jobs.get(job_id, {"error": "Job não encontrado"})
379
+
380
  def list_trained_models(self) -> List[Dict[str, str]]:
381
  """Lista modelos LoRA treinados."""
382
  models = []
383
  lora_models_dir = Path("./lora_models")
384
+
385
+ if lora_models_dir.exists():
386
+ for model_dir in lora_models_dir.iterdir():
387
+ if model_dir.is_dir():
388
+ config_file = model_dir / "adapter_config.json"
389
+ if config_file.exists():
390
+ try:
391
+ with open(config_file, 'r') as f:
392
+ config = json.load(f)
393
+
394
+ models.append({
395
+ "id": model_dir.name,
396
+ "path": str(model_dir),
397
+ "base_model": config.get("base_model_name", "Unknown"),
398
+ "r": config.get("r", "Unknown"),
399
+ "created": datetime.fromtimestamp(model_dir.stat().st_mtime).isoformat()
400
+ })
401
+ except Exception as e:
402
+ models.append({
403
+ "id": model_dir.name,
404
+ "path": str(model_dir),
405
+ "base_model": "Unknown",
406
+ "r": "Unknown",
407
+ "created": datetime.fromtimestamp(model_dir.stat().st_mtime).isoformat()
408
+ })
409
+
410
  return models
411
+
412
  def create_download_zip(self, model_path: str) -> str:
413
  """Cria um arquivo ZIP com os arquivos do modelo LoRA para download."""
414
  zip_path = f"{model_path}.zip"
415
+
416
+ with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
417
+ model_dir = Path(model_path)
418
+ for file_path in model_dir.rglob('*'):
419
+ if file_path.is_file():
420
+ arcname = file_path.relative_to(model_dir)
421
+ zipf.write(file_path, arcname)
422
+
423
  return zip_path
424
 
425
+
426
  # Instância global do trainer
427
  trainer = LoRAImageTrainer()
428
 
429
  def create_gradio_interface():
430
  """Cria interface Gradio para a ferramenta LoRA de geração de imagens."""
431
+
432
+ # CSS personalizado para responsividade móvel
433
+ custom_css = """
434
+ /* Mobile-first responsive design */
435
+ @media (max-width: 768px) {
436
+ .gradio-container {
437
+ padding: 8px !important;
438
+ margin: 0 !important;
439
+ }
440
+
441
+ .tab-nav {
442
+ flex-wrap: wrap !important;
443
+ gap: 4px !important;
444
+ }
445
+
446
+ .tab-nav button {
447
+ font-size: 14px !important;
448
+ padding: 8px 12px !important;
449
+ min-width: auto !important;
450
+ flex: 1 1 auto !important;
451
+ }
452
+
453
+ .form-container {
454
+ padding: 12px !important;
455
+ }
456
+
457
+ .btn {
458
+ width: 100% !important;
459
+ padding: 12px !important;
460
+ font-size: 16px !important;
461
+ margin-bottom: 8px !important;
462
+ min-height: 44px !important;
463
+ }
464
+
465
+ .textbox textarea {
466
+ font-size: 16px !important;
467
+ min-height: 120px !important;
468
+ }
469
+
470
+ .dropdown select {
471
+ font-size: 16px !important;
472
+ padding: 12px !important;
473
+ }
474
+
475
+ .output-text {
476
+ font-size: 14px !important;
477
+ line-height: 1.5 !important;
478
+ }
479
+
480
+ .column {
481
+ margin-bottom: 16px !important;
482
+ }
483
+
484
+ .file-upload {
485
+ min-height: 100px !important;
486
+ }
487
+ }
488
+
489
+ /* Enhanced visual styles */
490
+ .lora-header {
491
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
492
+ color: white;
493
+ padding: 20px;
494
+ border-radius: 12px;
495
+ margin-bottom: 20px;
496
+ text-align: center;
497
+ box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);
498
+ }
499
+
500
+ .status-indicator {
501
+ display: inline-block;
502
+ padding: 4px 8px;
503
+ border-radius: 6px;
504
+ font-size: 12px;
505
+ font-weight: 600;
506
+ text-transform: uppercase;
507
+ letter-spacing: 0.5px;
508
+ margin-right: 8px;
509
+ }
510
+
511
+ .status-queued { background-color: #fbbf24; color: #92400e; }
512
+ .status-loading_model { background-color: #60a5fa; color: #1e40af; }
513
+ .status-preparing_lora { background-color: #8b5cf6; color: #5b21b6; }
514
+ .status-preparing_data { background-color: #06b6d4; color: #0e7490; }
515
+ .status-training { background-color: #a78bfa; color: #5b21b6; }
516
+ .status-saving { background-color: #f59e0b; color: #92400e; }
517
+ .status-completed { background-color: #34d399; color: #065f46; }
518
+ .status-error { background-color: #f87171; color: #991b1b; }
519
+
520
+ /* Touch device optimizations */
521
+ @media (hover: none) and (pointer: coarse) {
522
+ .btn {
523
+ min-height: 44px !important;
524
+ min-width: 44px !important;
525
+ }
526
+
527
+ .tab-nav button {
528
+ min-height: 44px !important;
529
+ min-width: 44px !important;
530
+ }
531
+ }
532
+ """
533
 
534
+ def process_images_and_captions(files, captions_text):
535
+ """Processa imagens e legendas enviadas pelo usuário."""
536
+ if not files:
537
+ return "❌ Erro: Nenhuma imagem foi enviada!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
538
 
539
+ # Processar legendas
540
+ captions = []
541
+ if captions_text.strip():
542
+ captions = [line.strip() for line in captions_text.split('\n') if line.strip()]
543
 
544
+ # Se não há legendas suficientes, usar legendas padrão
545
+ while len(captions) < len(files):
546
+ captions.append(f"training image {len(captions) + 1}")
547
 
548
+ # Truncar legendas se houver mais que imagens
549
+ captions = captions[:len(files)]
550
 
551
  return files, captions
552
 
553
+ def start_training_wrapper(model_name, files, captions_text, trigger_word, r, lora_alpha, lora_dropout,
554
+ num_epochs, learning_rate, batch_size, resolution):
555
+ """Wrapper para iniciar treinamento via Gradio."""
556
+
557
+ if not files:
558
+ return "❌ Erro: Nenhuma imagem foi enviada para treinamento!"
559
+
560
+ if len(files) < 3:
561
+ return "❌ Erro: Forneça pelo menos 3 imagens para treinamento!"
562
+
563
+ try:
564
+ # Processar imagens e legendas
565
+ image_files = [f.name for f in files]
566
+
567
+ # Processar legendas
568
+ captions = []
569
+ if captions_text.strip():
570
+ captions = [line.strip() for line in captions_text.split('\n') if line.strip()]
571
+
572
+ # Se não há legendas suficientes, usar trigger word + descrição padrão
573
+ while len(captions) < len(files):
574
+ if trigger_word.strip():
575
+ captions.append(f"{trigger_word.strip()}, high quality photo")
576
+ else:
577
+ captions.append(f"training image {len(captions) + 1}, high quality photo")
578
+
579
+ # Truncar legendas se houver mais que imagens
580
+ captions = captions[:len(files)]
581
+
582
+ job_id = trainer.start_training(
583
+ model_name=model_name,
584
+ image_files=image_files,
585
+ captions=captions,
586
+ r=int(r),
587
+ lora_alpha=int(lora_alpha),
588
+ lora_dropout=float(lora_dropout),
589
+ num_epochs=int(num_epochs),
590
+ learning_rate=float(learning_rate),
591
+ batch_size=int(batch_size),
592
+ resolution=int(resolution)
593
+ )
594
+
595
+ return f"✅ Treinamento REAL iniciado! ID do Job: {job_id}\n\n📊 Imagens: {len(files)}\n🏷️ Trigger Word: {trigger_word or 'Nenhuma'}\n\nUse o ID acima para verificar o progresso na aba 'Status do Treinamento'."
596
+
597
+ except Exception as e:
598
+ return f"❌ Erro ao iniciar treinamento: {str(e)}"
599
 
600
+ def check_status_wrapper(job_id):
601
+ """Wrapper para verificar status via Gradio."""
602
+ if not job_id.strip():
603
+ return "❌ Erro: Forneça um ID de job válido!"
604
+
605
+ status = trainer.get_training_status(job_id.strip())
606
+
607
+ if "error" in status and status["error"] == "Job não encontrado":
608
+ return "❌ Job não encontrado! Verifique o ID."
609
+
610
+ # Criar indicador visual de status
611
+ status_class = f"status-{status['status']}"
612
+ status_emoji = {
613
+ 'queued': '⏳',
614
+ 'loading_model': '📥',
615
+ 'preparing_lora': '⚙️',
616
+ 'preparing_data': '📊',
617
+ 'training': '🏋️',
618
+ 'saving': '💾',
619
+ 'completed': '✅',
620
+ 'error': '❌'
621
+ }.get(status['status'], '📊')
622
+
623
+ # Barra de progresso visual
624
+ progress = status['progress']
625
+ progress_bar = f"""
626
+ <div style="width: 100%; background-color: #e5e7eb; border-radius: 4px; overflow: hidden; margin: 8px 0;">
627
+ <div style="width: {progress}%; height: 8px; background: linear-gradient(90deg, #3b82f6, #8b5cf6); transition: width 0.3s ease; border-radius: 4px;"></div>
628
+ </div>
629
+ """
630
 
631
  status_text = f"""
 
632
 
633
+ 📊 Status do Treinamento LoRA
634
+
635
+ 🆔 Job ID: {status['id']}
636
+ {status_emoji} Status: <span class="{status_class}">{status['status'].upper().replace('_', ' ')}</span>
637
+ ⏳ Progresso: {status['progress']}%
638
 
639
  {progress_bar}
640
 
641
+ 🤖 Modelo Base: {status['model_name']}
642
+ 🖼️ Imagens: {status.get('num_images', 'N/A')}
643
+ 📅 Criado em: {status['created_at']}
644
 
645
  """
646
+
647
+ if status['logs']:
648
+ status_text += "📝 **Logs Recentes:**\n"
649
+ for log in status['logs'][-5:]: # Últimos 5 logs
650
+ status_text += f"• {log}\n"
651
+
652
+ if status['status'] == 'completed':
653
+ status_text += f"\n✅ **Treinamento Concluído!**\n📁 **Modelo salvo em:** {status['model_path']}"
654
+ status_text += f"\n⏰ **Concluído em:** {status['completed_at']}"
655
+ status_text += f"\n\n💡 **Próximos passos:** Vá para a aba 'Modelos Treinados' para baixar seu LoRA!"
656
+ elif status['status'] == 'error':
657
+ status_text += f"\n❌ **Erro:** {status['error']}"
658
 
659
  return status_text
660
 
661
+ def list_models_wrapper():
662
+ """Wrapper para listar modelos via Gradio."""
663
+ models = trainer.list_trained_models()
664
+
665
+ if not models:
666
+ return "📭 Nenhum modelo LoRA treinado encontrado."
667
+
668
+ models_text = "📚 **Modelos LoRA Treinados:**\n\n"
669
+ for model in models:
670
+ models_text += f"🆔 **ID:** {model['id']}\n"
671
+ models_text += f"🤖 **Modelo Base:** {model['base_model']}\n"
672
+ models_text += f"📊 **Rank (r):** {model['r']}\n"
673
+ models_text += f"📁 **Caminho:** {model['path']}\n"
674
+ models_text += f"📅 **Criado:** {model['created']}\n\n"
675
+ models_text += "---\n\n"
676
 
677
  return models_text
678
 
679
+ def download_model_wrapper(job_id):
680
+ """Wrapper para preparar download do modelo."""
681
+ if not job_id.strip():
682
+ return None, "❌ Erro: Forneça um ID de job válido!"
683
 
684
+ status = trainer.get_training_status(job_id.strip())
685
 
686
+ if "error" in status and status["error"] == "Job não encontrado":
687
+ return None, "❌ Job não encontrado! Verifique o ID."
688
 
689
+ if status['status'] != 'completed':
690
+ return None, f"❌ Treinamento ainda não foi concluído. Status atual: {status['status']}"
691
 
692
+ try:
693
+ model_path = status['model_path']
694
+ zip_path = trainer.create_download_zip(model_path)
695
 
696
+ return zip_path, f"✅ Arquivo ZIP criado com sucesso! Clique no link acima para baixar."
697
 
698
+ except Exception as e:
699
+ return None, f"❌ Erro ao criar arquivo de download: {str(e)}"
700
 
701
+ # Interface Gradio
702
+ with gr.Blocks(
703
+ title="🎨 LoRA Image Trainer - Criador e Treinador de LoRA para Imagens",
704
+ theme=gr.themes.Soft(),
705
+ css=custom_css
706
+ ) as interface:
707
+
708
+ gr.HTML("""
709
+ <div class="lora-header">
710
+ <h1>🎨 LoRA Image Trainer</h1>
711
+ <p>Criador e Treinador de LoRA para Geração de Imagens</p>
712
+ <p style="font-size: 0.9em; opacity: 0.9; margin-top: 8px;">
713
+ Ferramenta otimizada para baixo uso de GPU, compatível com dispositivos móveis
714
+ </p>
715
+ </div>
716
+ """)
717
+
718
+ with gr.Tabs():
719
+
720
+ # Aba de Treinamento
721
+ with gr.TabItem("🎯 Treinar LoRA"):
722
+ gr.Markdown("### Configurar e Iniciar Treinamento LoRA para Imagens")
723
 
724
+ with gr.Row():
725
+ with gr.Column(scale=2):
726
+ model_dropdown = gr.Dropdown(
727
+ choices=trainer.get_available_models(),
728
+ value="runwayml/stable-diffusion-v1-5",
729
+ label="🤖 Modelo Base",
730
+ )
 
731
 
732
+ image_files = gr.File(
733
+ file_count="multiple",
734
+ file_types=["image"],
735
+ label="🖼️ Imagens de Treinamento",
736
+ )
 
737
 
738
+ trigger_word = gr.Textbox(
739
+ label="🏷️ Trigger Word (Opcional)",
740
+ placeholder="ex: meuEstilo, minhaPersonagem, etc.",
741
+ )
 
742
 
743
+ captions_text = gr.Textbox(
744
+ lines=8,
745
+ placeholder="Digite uma legenda por linha (opcional)...\n\nExemplo:\nmeuEstilo, retrato de uma mulher\nmeuEstilo, homem sorrindo\nmeuEstilo, paisagem urbana\n\nSe deixar vazio, usará a trigger word + 'high quality photo'",
746
+ label="📝 Legendas das Imagens (Opcional)",
747
+ )
 
748
 
749
+ with gr.Column(scale=1):
750
+ gr.Markdown("### ⚙️ Parâmetros LoRA")
751
 
752
+ r = gr.Slider(
753
+ minimum=4, maximum=64, value=8, step=4, # reduzido max para 64
754
+ label="r (Rank)",
755
+ )
 
756
 
757
+ lora_alpha = gr.Slider(
758
+ minimum=1, maximum=64, value=16, step=1, # reduzido max para 64
759
+ label="LoRA Alpha",
760
+ )
 
761
 
762
+ lora_dropout = gr.Slider(
763
+ minimum=0.0, maximum=0.5, value=0.0, step=0.05, # dropout 0 para mais estabilidade
764
+ label="LoRA Dropout",
765
+ )
 
766
 
767
+ gr.Markdown("### 🏋️ Parâmetros de Treinamento")
768
 
769
+ num_epochs = gr.Slider(
770
+ minimum=5, maximum=20, value=10, step=5, # reduzido max para 20
771
+ label="Épocas",
772
+ )
 
773
 
774
+ learning_rate = gr.Slider(
775
+ minimum=1e-5, maximum=5e-4, value=1e-4, step=1e-5, # reduzido max
776
+ label="Taxa de Aprendizado",
777
+ )
 
778
 
779
+ batch_size = gr.Slider(
780
+ minimum=1, maximum=1, value=1, step=1, # fixado em 1 para Spaces
781
+ label="Batch Size",
782
+ )
 
783
 
784
+ resolution = gr.Dropdown(
785
+ choices=[512], # fixado em 512 para garantir funcionamento em GPU limitada
786
+ value=512,
787
+ label="Resolução",
788
+ )
789
+
790
+ train_button = gr.Button("🚀 Iniciar Treinamento LoRA", variant="primary", size="lg")
791
+ train_output = gr.Textbox(label="📊 Resultado", lines=5)
 
792
 
793
+ train_button.click(
794
+ start_training_wrapper,
795
+ inputs=[model_dropdown, image_files, captions_text, trigger_word, r, lora_alpha, lora_dropout,
796
+ num_epochs, learning_rate, batch_size, resolution],
797
+ outputs=train_output
798
+ )
799
+
800
+ # Aba de Status
801
+ with gr.TabItem("📊 Status do Treinamento"):
802
+ gr.Markdown("### Verificar Progresso do Treinamento")
803
 
804
+ job_id_input = gr.Textbox(
805
+ label="🆔 ID do Job",
806
+ placeholder="Cole aqui o ID do job de treinamento...",
807
+ )
 
808
 
809
+ status_button = gr.Button("🔍 Verificar Status", variant="secondary")
810
+ status_output = gr.Textbox(label="📈 Status", lines=12)
811
 
812
+ status_button.click(
813
+ check_status_wrapper,
814
+ inputs=job_id_input,
815
+ outputs=status_output
816
+ )
817
 
818
+ gr.Markdown("💡 **Dica:** Atualize o status regularmente para acompanhar o progresso do treinamento.")
819
 
820
+ # Aba de Modelos e Download
821
+ with gr.TabItem("📚 Modelos e Download"):
822
+ gr.Markdown("### Visualizar e Baixar Modelos LoRA Treinados")
823
 
824
+ with gr.Row():
825
+ with gr.Column(scale=1):
826
+ list_button = gr.Button("📋 Listar Modelos", variant="secondary")
827
+ models_output = gr.Textbox(label="📚 Modelos Disponíveis", lines=10)
828
 
829
+ list_button.click(
830
+ list_models_wrapper,
831
+ outputs=models_output
832
+ )
833
 
834
+ with gr.Column(scale=1):
835
+ gr.Markdown("#### 💾 Download de Modelo")
836
 
837
+ download_job_id = gr.Textbox(
838
+ label="🆔 ID do Job para Download",
839
+ placeholder="Cole o ID do job concluído...",
840
+ )
841
 
842
+ download_button = gr.Button("📦 Preparar Download", variant="primary")
843
+ download_file = gr.File(label="📁 Arquivo para Download")
844
+ download_status = gr.Textbox(label="📊 Status do Download", lines=3)
845
 
846
+ download_button.click(
847
+ download_model_wrapper,
848
+ inputs=download_job_id,
849
+ outputs=[download_file, download_status]
850
+ )
851
+
852
+ # Aba de Informa��ões
853
+ with gr.TabItem("ℹ️ Sobre"):
854
+ gr.Markdown("""
855
+ ### 🎯 Sobre o LoRA Image Trainer
856
+
857
+ Esta ferramenta foi desenvolvida para democratizar o acesso ao treinamento de modelos LoRA para geração de imagens,
858
+ permitindo que qualquer pessoa possa criar adaptações personalizadas de modelos de difusão (como Stable Diffusion)
859
+ sem a necessidade de hardware especializado.
860
+
861
+ #### ✨ Características Principais:
862
+
863
+ - **🔋 Otimizado para Baixa GPU**: Utiliza técnicas como mixed precision, gradient checkpointing e configurações otimizadas
864
+ - **📱 Compatível com Móveis**: Interface responsiva que funciona em smartphones e tablets
865
+ - **⚡ Rápido e Eficiente**: Treinamento otimizado com bibliotecas Diffusers e PEFT do Hugging Face
866
+ - **🎛️ Configurável**: Controle total sobre parâmetros LoRA e de treinamento
867
+ - **☁️ Pronto para Deploy**: Facilmente implantável no Hugging Face Spaces
868
+ - **🎨 Focado em Imagens**: Especificamente projetado para modelos de difusão e geração de imagens
869
+
870
+ #### 🛠️ Tecnologias Utilizadas:
871
+
872
+ - **Hugging Face Diffusers**: Para modelos de difusão e pipeline de treinamento
873
+ - **PEFT (Parameter-Efficient Fine-Tuning)**: Para treinamento eficiente de LoRA
874
+ - **PyTorch**: Framework de deep learning
875
+ - **Gradio**: Interface web interativa e responsiva
876
+ - **LoRA (Low-Rank Adaptation)**: Técnica de fine-tuning eficiente para modelos de difusão
877
+
878
+ #### 📖 Como Usar:
879
+
880
+ 1. **Prepare suas imagens**: Colete 3-50 imagens de alta qualidade do estilo/conceito que deseja treinar
881
+ 2. **Escolha um modelo base** na aba "Treinar LoRA" (recomendado: Stable Diffusion 1.5)
882
+ 3. **Faça upload das imagens** e defina uma trigger word (palavra-chave)
883
+ 4. **Configure os parâmetros** conforme necessário (valores padrão funcionam bem)
884
+ 5. **Inicie o treinamento** e anote o ID do job
885
+ 6. **Acompanhe o progresso** na aba "Status do Treinamento"
886
+ 7. **Baixe seu LoRA** na aba "Modelos e Download" quando concluído
887
+ 8. **Use em suas ferramentas favoritas** (ComfyUI, Automatic1111, etc.)
888
+
889
+ #### 💡 Dicas para Melhores Resultados:
890
+
891
+ - **Qualidade > Quantidade**: 10-20 imagens de alta qualidade são melhores que 50 imagens ruins
892
+ - **Consistência**: Use imagens com estilo/conceito consistente
893
+ - **Resolução**: Para GPUs com pouca VRAM, use resolução 512x512
894
+ - **Trigger Word**: Escolha uma palavra única e fácil de lembrar
895
+ - **Legendas**: Descreva o que há nas imagens para melhor controle
896
+ - **Parâmetros**: Para iniciantes, use os valores padrão
897
+
898
+ #### 🎮 Compatibilidade:
899
+
900
+ Os LoRAs gerados são compatíveis com:
901
+ - **ComfyUI**: Carregue os arquivos .safetensors
902
+ - **Automatic1111**: Coloque na pasta models/Lora
903
+ - **SeaArt**: Faça upload do modelo
904
+ - **Outras ferramentas**: Qualquer ferramenta que suporte LoRA para Stable Diffusion
905
+
906
+ ---
907
+
908
+ **Desenvolvido com ❤️ para a comunidade de IA e arte digital**
909
+ """)
910
+
911
+ # Footer
912
+ gr.Markdown("""
913
+ ---
914
+ <div style="text-align: center; color: #666; font-size: 0.9em;">
915
+ 🎨 LoRA Image Trainer v1.0 | Otimizado para Baixa GPU | Compatível com Dispositivos Móveis
916
+ </div>
917
+ """)
918
 
919
  return interface
920
 
 
922
  if __name__ == "__main__":
923
  # Criar diretórios necessários
924
  os.makedirs("./lora_models", exist_ok=True)
925
+
926
+ # Configurar interface
927
+ interface = create_gradio_interface()
928
+
929
+ # Lançar aplicação
930
+ interface.launch(
931
+ server_name="0.0.0.0",
932
+ server_port=7860,
933
+ share=False,
934
+ show_error=True,
935
+ quiet=False
936
+ )