Allex21 commited on
Commit
10a7187
·
verified ·
1 Parent(s): c4646a0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +555 -184
app.py CHANGED
@@ -35,68 +35,33 @@ class LoRAImageTrainer:
35
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
  self.training_jobs = {}
37
  self.models_cache = {}
38
- # ✅ Garantir que as pastas existam no diretório atual
39
- os.makedirs("./lora_models", exist_ok=True)
40
- os.makedirs("./jobs", exist_ok=True)
41
- logger.info("Pastas ./lora_models e ./jobs criadas com sucesso.")
42
 
43
- def _save_job_state(self, job_id: str):
44
- """Salva o estado do job em disco."""
45
- job_file = Path(f"./jobs/{job_id}.json")
46
- try:
47
- with open(job_file, "w") as f:
48
- json.dump(self.training_jobs[job_id], f, indent=2, default=str)
49
- logger.info(f"Estado do job {job_id} salvo em disco.")
50
- except Exception as e:
51
- logger.error(f"Erro ao salvar job {job_id}: {e}")
52
-
53
- def _load_job_state(self, job_id: str) -> Optional[Dict]:
54
- """Carrega o estado do job do disco."""
55
- job_file = Path(f"./jobs/{job_id}.json")
56
- if job_file.exists():
57
- try:
58
- with open(job_file, "r") as f:
59
- loaded_data = json.load(f)
60
- logger.info(f"Estado do job {job_id} carregado do disco.")
61
- return loaded_data
62
- except Exception as e:
63
- logger.error(f"Erro ao carregar job {job_id}: {e}")
64
- else:
65
- logger.warning(f"Arquivo do job {job_id} não encontrado em disco.")
66
- return None
67
-
68
- def get_training_status(self, job_id: str) -> Dict[str, Any]:
69
- """Retorna status do treinamento, carregando do disco se necessário."""
70
- if job_id in self.training_jobs:
71
- return self.training_jobs[job_id]
72
- else:
73
- loaded_job = self._load_job_state(job_id)
74
- if loaded_job:
75
- self.training_jobs[job_id] = loaded_job
76
- return loaded_job
77
- return {"error": "Job não encontrado"}
78
-
79
  def get_available_models(self) -> List[str]:
 
80
  return [
81
  "runwayml/stable-diffusion-v1-5",
82
  "stabilityai/stable-diffusion-2-1",
83
  "CompVis/stable-diffusion-v1-4"
 
84
  ]
85
 
86
  def load_base_model(self, model_name: str):
 
87
  try:
88
  if model_name in self.models_cache:
89
  return self.models_cache[model_name]
90
 
91
  logger.info(f"Carregando modelo base: {model_name}")
92
 
 
93
  model_kwargs = {
94
  "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
95
  "use_safetensors": True,
96
  "variant": "fp16" if torch.cuda.is_available() else None,
97
- "safety_checker": None,
98
  }
99
 
 
100
  pipeline = StableDiffusionPipeline.from_pretrained(
101
  model_name,
102
  **model_kwargs
@@ -104,14 +69,17 @@ class LoRAImageTrainer:
104
 
105
  if torch.cuda.is_available():
106
  pipeline = pipeline.to(self.device)
 
107
  pipeline.enable_attention_slicing()
 
108
  try:
109
  pipeline.enable_xformers_memory_efficient_attention()
110
- except:
111
- logger.warning("xformers não disponível")
112
- pipeline.unet.enable_gradient_checkpointing()
113
 
 
114
  self.models_cache[model_name] = pipeline
 
115
  return pipeline
116
 
117
  except Exception as e:
@@ -119,11 +87,15 @@ class LoRAImageTrainer:
119
  raise e
120
 
121
  def prepare_image_dataset(self, image_files: List[str], captions: List[str], resolution: int = 512) -> List[Dict]:
 
122
  dataset = []
123
 
124
  for img_path, caption in zip(image_files, captions):
125
  try:
 
126
  image = Image.open(img_path).convert("RGB")
 
 
127
  image = self.resize_image(image, resolution)
128
 
129
  dataset.append({
@@ -139,8 +111,10 @@ class LoRAImageTrainer:
139
  return dataset
140
 
141
  def resize_image(self, image: Image.Image, target_size: int) -> Image.Image:
 
142
  width, height = image.size
143
 
 
144
  if width > height:
145
  new_width = target_size
146
  new_height = int((height * target_size) / width)
@@ -148,13 +122,16 @@ class LoRAImageTrainer:
148
  new_height = target_size
149
  new_width = int((width * target_size) / height)
150
 
 
151
  image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
152
 
 
153
  if new_width != target_size or new_height != target_size:
154
  left = (new_width - target_size) // 2
155
  top = (new_height - target_size) // 2
156
  right = left + target_size
157
  bottom = top + target_size
 
158
  image = image.crop((left, top, right, bottom))
159
 
160
  return image
@@ -163,35 +140,22 @@ class LoRAImageTrainer:
163
  job_id: str,
164
  model_name: str,
165
  dataset: List[Dict],
166
- r: int = 8,
167
- lora_alpha: int = 16,
168
- lora_dropout: float = 0.0,
169
- num_epochs: int = 5,
170
  learning_rate: float = 1e-4,
171
  batch_size: int = 1,
172
  resolution: int = 512) -> None:
 
 
173
  try:
174
- if job_id not in self.training_jobs:
175
- self.training_jobs[job_id] = {
176
- "id": job_id,
177
- "status": "queued",
178
- "progress": 0,
179
- "created_at": datetime.now().isoformat(),
180
- "model_name": model_name,
181
- "num_images": len(dataset),
182
- "logs": [],
183
- "error": None,
184
- "model_path": None,
185
- "completed_at": None
186
- }
187
-
188
  self.training_jobs[job_id]["status"] = "loading_model"
189
  self.training_jobs[job_id]["progress"] = 5
190
- log_msg = f"{datetime.now().strftime('%H:%M:%S')} - Carregando modelo base: {model_name}"
191
- self.training_jobs[job_id]["logs"].append(log_msg)
192
- self._save_job_state(job_id)
193
- logger.info(log_msg)
194
-
195
  pipeline = self.load_base_model(model_name)
196
  unet = pipeline.unet
197
  text_encoder = pipeline.text_encoder
@@ -199,15 +163,12 @@ class LoRAImageTrainer:
199
  tokenizer = pipeline.tokenizer
200
  scheduler = pipeline.scheduler
201
 
 
202
  unet.requires_grad_(False)
203
  text_encoder.requires_grad_(False)
204
  vae.requires_grad_(False)
205
 
206
- # Remover adaptador existente
207
- if hasattr(unet, "peft_config") and "default" in unet.peft_config:
208
- unet.delete_adapter("default")
209
- logger.info("Adaptador 'default' removido com sucesso.")
210
-
211
  lora_config = LoraConfig(
212
  r=r,
213
  lora_alpha=lora_alpha,
@@ -216,103 +177,101 @@ class LoRAImageTrainer:
216
  bias="none"
217
  )
218
 
 
219
  unet.add_adapter(lora_config, adapter_name="default")
220
- unet.set_adapter("default")
221
-
222
- # ✅ Ativar apenas parâmetros do LoRA
223
- unet.requires_grad_(False)
224
- trainable_params = 0
225
- for name, param in unet.named_parameters():
226
- if "lora_" in name:
227
- param.requires_grad = True
228
- trainable_params += 1
229
-
230
- logger.info(f"Número de parâmetros treináveis (LoRA): {trainable_params}")
231
 
 
 
232
  unet.train()
233
  unet.to(self.device)
234
 
235
- optimizer = torch.optim.AdamW([p for p in unet.parameters() if p.requires_grad], lr=learning_rate)
 
236
 
 
237
  self.training_jobs[job_id]["status"] = "preparing_data"
238
  self.training_jobs[job_id]["progress"] = 20
239
- self._save_job_state(job_id)
240
 
 
241
  def preprocess_image(image):
242
  image = np.array(image).astype(np.float32) / 255.0
243
  image = image.transpose(2, 0, 1)
244
  image = torch.from_numpy(image).unsqueeze(0)
245
  return image
246
 
 
247
  total_steps = num_epochs * len(dataset)
248
  current_step = 0
249
 
250
  self.training_jobs[job_id]["status"] = "training"
251
- log_msg = f"{datetime.now().strftime('%H:%M:%S')} - Iniciando treinamento real..."
252
- self.training_jobs[job_id]["logs"].append(log_msg)
253
- self._save_job_state(job_id)
254
- logger.info(log_msg)
255
 
256
  for epoch in range(num_epochs):
257
  for item in dataset:
258
  current_step += 1
259
 
 
260
  image = item["image"]
261
  caption = item["caption"]
 
 
262
  image_tensor = preprocess_image(image).to(self.device)
263
  if torch.cuda.is_available():
264
  image_tensor = image_tensor.half()
265
 
 
266
  with torch.no_grad():
267
  latents = vae.encode(image_tensor * 2 - 1).latent_dist.sample() * 0.18215
268
 
 
269
  inputs = tokenizer(caption, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
270
  input_ids = inputs.input_ids.to(self.device)
271
 
 
272
  timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (1,), device=self.device).long()
 
 
273
  noise = torch.randn_like(latents)
274
  noisy_latents = scheduler.add_noise(latents, noise, timesteps)
275
 
 
276
  encoder_hidden_states = text_encoder(input_ids)[0]
277
  noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
278
 
 
279
  loss = torch.nn.functional.mse_loss(noise_pred, noise)
 
 
280
  optimizer.zero_grad()
281
  loss.backward()
282
  optimizer.step()
283
 
284
- if torch.cuda.is_available():
285
- torch.cuda.empty_cache()
286
-
287
  progress = 30 + int((current_step / total_steps) * 60)
288
  self.training_jobs[job_id]["progress"] = min(progress, 90)
289
 
290
  if current_step % max(1, len(dataset)//2) == 0:
291
  log_msg = f"Época {epoch+1}, Step {current_step} - Loss: {loss.item():.4f}"
292
- self.training_jobs[job_id]["logs"].append(log_msg)
293
- self._save_job_state(job_id)
294
- logger.info(log_msg)
295
 
 
296
  self.training_jobs[job_id]["status"] = "saving"
297
  self.training_jobs[job_id]["progress"] = 95
298
- self._save_job_state(job_id)
299
 
300
  output_dir = f"./lora_models/{job_id}"
301
  os.makedirs(output_dir, exist_ok=True)
302
 
303
- unet.save_pretrained(
304
- output_dir,
305
- safe_serialization=True,
306
- selected_adapters=["default"]
307
- )
308
 
 
309
  lora_config_dict = {
310
  "r": r,
311
  "lora_alpha": lora_alpha,
312
  "target_modules": ["to_k", "to_q", "to_v", "to_out.0"],
313
  "lora_dropout": lora_dropout,
314
  "bias": "none",
315
- "task_type": "CAUSAL_LM",
316
  "base_model_name": model_name,
317
  "training_info": {
318
  "num_epochs": num_epochs,
@@ -325,39 +284,56 @@ class LoRAImageTrainer:
325
  with open(f"{output_dir}/adapter_config.json", "w") as f:
326
  json.dump(lora_config_dict, f, indent=2)
327
 
 
328
  readme_content = f"""# LoRA Model - {job_id}
329
- Treinado com sucesso!
 
 
330
  Modelo Base: {model_name}
331
- Data: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
 
 
 
 
 
 
 
 
 
 
 
 
 
332
  """
333
  with open(f"{output_dir}/README.md", "w") as f:
334
  f.write(readme_content)
335
 
 
336
  self.training_jobs[job_id]["status"] = "completed"
337
  self.training_jobs[job_id]["progress"] = 100
338
  self.training_jobs[job_id]["model_path"] = output_dir
339
  self.training_jobs[job_id]["completed_at"] = datetime.now().isoformat()
340
- log_msg = f"{datetime.now().strftime('%H:%M:%S')} - ✅ Treinamento concluído! LoRA salvo em {output_dir}"
341
- self.training_jobs[job_id]["logs"].append(log_msg)
342
- self._save_job_state(job_id)
343
- logger.info(log_msg)
344
 
345
  except Exception as e:
346
- error_msg = f"Erro no treinamento: {str(e)}"
347
  logger.error(error_msg)
348
- if job_id in self.training_jobs:
349
- self.training_jobs[job_id]["status"] = "error"
350
- self.training_jobs[job_id]["error"] = error_msg
351
- self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - ❌ {error_msg}")
352
- self._save_job_state(job_id)
353
 
354
  def start_training(self,
355
  model_name: str,
356
  image_files: List[str],
357
  captions: List[str],
358
  **kwargs) -> str:
 
 
359
  job_id = str(uuid.uuid4())
360
 
 
361
  dataset = self.prepare_image_dataset(image_files, captions, kwargs.get('resolution', 512))
362
 
363
  self.training_jobs[job_id] = {
@@ -373,8 +349,7 @@ Data: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
373
  "completed_at": None
374
  }
375
 
376
- self._save_job_state(job_id)
377
-
378
  thread = threading.Thread(
379
  target=self.real_lora_training,
380
  args=(job_id, model_name, dataset),
@@ -385,7 +360,12 @@ Data: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
385
 
386
  return job_id
387
 
 
 
 
 
388
  def list_trained_models(self) -> List[Dict[str, str]]:
 
389
  models = []
390
  lora_models_dir = Path("./lora_models")
391
 
@@ -397,6 +377,7 @@ Data: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
397
  try:
398
  with open(config_file, 'r') as f:
399
  config = json.load(f)
 
400
  models.append({
401
  "id": model_dir.name,
402
  "path": str(model_dir),
@@ -405,7 +386,6 @@ Data: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
405
  "created": datetime.fromtimestamp(model_dir.stat().st_mtime).isoformat()
406
  })
407
  except Exception as e:
408
- logger.error(f"Erro ao ler config de {model_dir.name}: {e}")
409
  models.append({
410
  "id": model_dir.name,
411
  "path": str(model_dir),
@@ -413,140 +393,531 @@ Data: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
413
  "r": "Unknown",
414
  "created": datetime.fromtimestamp(model_dir.stat().st_mtime).isoformat()
415
  })
 
416
  return models
417
 
418
  def create_download_zip(self, model_path: str) -> str:
 
419
  zip_path = f"{model_path}.zip"
 
420
  with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
421
  model_dir = Path(model_path)
422
  for file_path in model_dir.rglob('*'):
423
  if file_path.is_file():
424
  arcname = file_path.relative_to(model_dir)
425
  zipf.write(file_path, arcname)
 
426
  return zip_path
427
 
428
 
 
429
  trainer = LoRAImageTrainer()
430
 
431
  def create_gradio_interface():
 
 
 
432
  custom_css = """
 
433
  @media (max-width: 768px) {
434
- .gradio-container { padding: 8px !important; }
435
- .btn { width: 100% !important; padding: 12px !important; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
436
  }
 
 
437
  .lora-header {
438
  background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
439
- color: white; padding: 20px; border-radius: 12px; margin-bottom: 20px;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
440
  }
441
- .status-indicator { padding: 4px 8px; border-radius: 6px; font-size: 12px; font-weight: 600; }
 
 
 
 
 
 
442
  .status-completed { background-color: #34d399; color: #065f46; }
443
  .status-error { background-color: #f87171; color: #991b1b; }
 
 
 
 
 
 
 
 
 
 
 
 
 
444
  """
445
 
446
- def start_training_wrapper(model_name, files, captions_text, trigger_word, r, lora_alpha, num_epochs, learning_rate):
447
- if not files: return "❌ Erro: Nenhuma imagem enviada!"
448
- if len(files) < 3: return "❌ Forneça pelo menos 3 imagens!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
449
 
450
  try:
 
451
  image_files = [f.name for f in files]
452
- captions = [line.strip() for line in captions_text.split('\n') if line.strip()] if captions_text.strip() else []
 
 
 
 
 
 
453
  while len(captions) < len(files):
454
- captions.append(f"{trigger_word.strip() or 'training image'}, high quality photo" if trigger_word.strip() else f"training image {len(captions) + 1}, high quality photo")
 
 
 
 
 
455
  captions = captions[:len(files)]
456
 
457
  job_id = trainer.start_training(
458
- model_name=model_name,
459
- image_files=image_files,
460
  captions=captions,
461
- r=int(r),
462
- lora_alpha=int(lora_alpha),
463
- lora_dropout=0.0, # Fixado
464
- num_epochs=int(num_epochs),
465
  learning_rate=float(learning_rate),
466
- batch_size=1, # Fixado
467
- resolution=512 # Fixado
468
  )
469
- return f"✅ Treinamento iniciado! ID: {job_id}\n📊 Imagens: {len(files)}\nUse este ID para acompanhar o progresso."
 
 
470
  except Exception as e:
471
- return f"❌ Erro: {str(e)}"
472
 
473
  def check_status_wrapper(job_id):
474
- if not job_id.strip(): return "❌ Forneça um ID válido!"
 
 
 
475
  status = trainer.get_training_status(job_id.strip())
476
- if "error" in status: return "❌ Job não encontrado!"
477
 
478
- status_emoji = {'completed': '✅', 'error': '❌', 'training': '🏋️', 'queued': '⏳'}.get(status['status'], '📊')
479
- progress = status['progress']
480
- progress_bar = f'<div style="width:100%;background:#e5e7eb;border-radius:4px;overflow:hidden;margin:8px 0;"><div style="width:{progress}%;height:8px;background:linear-gradient(90deg,#3b82f6,#8b5cf6);border-radius:4px;"></div></div>'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
481
 
482
- status_text = f"🆔 Job ID: {status['id']}\n{status_emoji} Status: {status['status'].upper()}\n⏳ Progresso: {progress}%\n{progress_bar}\n🤖 Modelo: {status['model_name']}\n🖼️ Imagens: {status.get('num_images','N/A')}\n📅 Criado: {status['created_at']}\n"
483
- if status['logs']: status_text += "\n📝 Logs:\n" + "\n".join([f"• {log}" for log in status['logs'][-5:]])
484
- if status['status'] == 'completed': status_text += f"\n✅ Concluído! Modelo salvo em: {status['model_path']}"
485
- elif status['status'] == 'error': status_text += f"\n❌ Erro: {status['error']}"
486
  return status_text
487
 
488
  def list_models_wrapper():
 
489
  models = trainer.list_trained_models()
490
- if not models: return "📭 Nenhum modelo encontrado."
491
- return "\n\n".join([f"🆔 {m['id']}\n🤖 {m['base_model']}\n📅 {m['created']}" for m in models])
 
 
 
 
 
 
 
 
 
 
 
 
492
 
493
  def download_model_wrapper(job_id):
494
- if not job_id.strip(): return None, "❌ ID inválido!"
 
 
 
495
  status = trainer.get_training_status(job_id.strip())
496
- if status.get("status") != "completed": return None, "❌ Treinamento não concluído!"
 
 
 
 
 
 
497
  try:
498
- zip_path = trainer.create_download_zip(status['model_path'])
499
- return zip_path, "✅ ZIP criado! Clique acima para baixar."
 
 
 
500
  except Exception as e:
501
- return None, f"❌ Erro: {str(e)}"
502
 
503
- with gr.Blocks(title="🎨 LoRA Image Trainer", theme=gr.themes.Soft(), css=custom_css) as interface:
504
- gr.HTML('<div class="lora-header"><h1>🎨 LoRA Image Trainer</h1><p>Treine seu LoRA para imagens</p></div>')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
505
 
506
  with gr.Tabs():
507
- with gr.TabItem("🎯 Treinar"):
 
 
 
 
508
  with gr.Row():
509
- with gr.Column():
510
- model_dropdown = gr.Dropdown(trainer.get_available_models(), value="runwayml/stable-diffusion-v1-5", label="🤖 Modelo Base")
511
- image_files = gr.File(file_count="multiple", file_types=["image"], label="🖼️ Imagens")
512
- trigger_word = gr.Textbox(label="🏷️ Trigger Word", placeholder="ex: meuEstilo")
513
- captions_text = gr.Textbox(lines=4, placeholder="Legenda por linha...", label="📝 Legendas (Opcional)")
514
- with gr.Column():
515
- r = gr.Slider(4, 32, 8, step=4, label="r (Rank)")
516
- lora_alpha = gr.Slider(1, 32, 16, step=1, label="LoRA Alpha")
517
- num_epochs = gr.Slider(1, 10, 5, step=1, label="Épocas")
518
- learning_rate = gr.Slider(1e-5, 1e-3, 1e-4, step=1e-5, label="Taxa de Aprendizado")
519
- train_button = gr.Button("🚀 Iniciar Treinamento", variant="primary")
520
- train_output = gr.Textbox(label="📊 Resultado")
521
- train_button.click(start_training_wrapper, [model_dropdown, image_files, captions_text, trigger_word, r, lora_alpha, num_epochs, learning_rate], train_output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
522
 
523
- with gr.TabItem("📊 Status"):
524
- job_id_input = gr.Textbox(label="🆔 ID do Job")
525
- status_button = gr.Button("🔍 Verificar Status")
526
- status_output = gr.Textbox(label="📈 Status", lines=10)
527
- status_button.click(check_status_wrapper, job_id_input, status_output)
 
 
 
 
528
 
529
- with gr.TabItem("📚 Download"):
530
- download_job_id = gr.Textbox(label="🆔 ID do Job")
531
- download_button = gr.Button("📦 Baixar Modelo")
532
- download_file = gr.File(label="📁 Arquivo ZIP")
533
- download_status = gr.Textbox(label="📊 Status")
534
- download_button.click(download_model_wrapper, download_job_id, [download_file, download_status])
535
-
536
- gr.Markdown("---\n<center>🎨 LoRA Image Trainer v1.0 | Treinamento Real de LoRA</center>")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
537
 
538
  return interface
539
 
 
540
  if __name__ == "__main__":
541
- # Garantir que os diretórios existam
542
  os.makedirs("./lora_models", exist_ok=True)
543
- os.makedirs("./jobs", exist_ok=True)
544
- logger.info("Aplicação iniciada. Diretórios verificados.")
545
 
 
546
  interface = create_gradio_interface()
 
 
547
  interface.launch(
548
  server_name="0.0.0.0",
549
  server_port=7860,
 
550
  show_error=True,
551
  quiet=False
552
  )
 
35
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
  self.training_jobs = {}
37
  self.models_cache = {}
 
 
 
 
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  def get_available_models(self) -> List[str]:
40
+ """Retorna lista de modelos base disponíveis para treinamento LoRA."""
41
  return [
42
  "runwayml/stable-diffusion-v1-5",
43
  "stabilityai/stable-diffusion-2-1",
44
  "CompVis/stable-diffusion-v1-4"
45
+ # XL removido por ser pesado demais para Spaces gratuitos
46
  ]
47
 
48
  def load_base_model(self, model_name: str):
49
+ """Carrega modelo base de difusão com otimizações para baixo uso de GPU."""
50
  try:
51
  if model_name in self.models_cache:
52
  return self.models_cache[model_name]
53
 
54
  logger.info(f"Carregando modelo base: {model_name}")
55
 
56
+ # Configurações para otimização de memória
57
  model_kwargs = {
58
  "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
59
  "use_safetensors": True,
60
  "variant": "fp16" if torch.cuda.is_available() else None,
61
+ "safety_checker": None, # Desativa verificador de segurança para economizar memória
62
  }
63
 
64
+ # Carregar pipeline completo
65
  pipeline = StableDiffusionPipeline.from_pretrained(
66
  model_name,
67
  **model_kwargs
 
69
 
70
  if torch.cuda.is_available():
71
  pipeline = pipeline.to(self.device)
72
+ # Habilitar attention slicing para economia de memória
73
  pipeline.enable_attention_slicing()
74
+ # Habilitar memory efficient attention se disponível
75
  try:
76
  pipeline.enable_xformers_memory_efficient_attention()
77
+ except Exception as e:
78
+ logger.warning("xformers não disponível, usando attention padrão")
 
79
 
80
+ # Cache do modelo
81
  self.models_cache[model_name] = pipeline
82
+
83
  return pipeline
84
 
85
  except Exception as e:
 
87
  raise e
88
 
89
  def prepare_image_dataset(self, image_files: List[str], captions: List[str], resolution: int = 512) -> List[Dict]:
90
+ """Prepara dataset de imagens para treinamento."""
91
  dataset = []
92
 
93
  for img_path, caption in zip(image_files, captions):
94
  try:
95
+ # Carregar e redimensionar imagem
96
  image = Image.open(img_path).convert("RGB")
97
+
98
+ # Redimensionar mantendo aspect ratio
99
  image = self.resize_image(image, resolution)
100
 
101
  dataset.append({
 
111
  return dataset
112
 
113
  def resize_image(self, image: Image.Image, target_size: int) -> Image.Image:
114
+ """Redimensiona imagem mantendo aspect ratio e fazendo crop central se necessário."""
115
  width, height = image.size
116
 
117
+ # Calcular novo tamanho mantendo aspect ratio
118
  if width > height:
119
  new_width = target_size
120
  new_height = int((height * target_size) / width)
 
122
  new_height = target_size
123
  new_width = int((width * target_size) / height)
124
 
125
+ # Redimensionar
126
  image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
127
 
128
+ # Crop central para obter tamanho exato
129
  if new_width != target_size or new_height != target_size:
130
  left = (new_width - target_size) // 2
131
  top = (new_height - target_size) // 2
132
  right = left + target_size
133
  bottom = top + target_size
134
+
135
  image = image.crop((left, top, right, bottom))
136
 
137
  return image
 
140
  job_id: str,
141
  model_name: str,
142
  dataset: List[Dict],
143
+ r: int = 16,
144
+ lora_alpha: int = 32,
145
+ lora_dropout: float = 0.1,
146
+ num_epochs: int = 10,
147
  learning_rate: float = 1e-4,
148
  batch_size: int = 1,
149
  resolution: int = 512) -> None:
150
+ """TREINAMENTO REAL DE LoRA PARA IMAGENS - CORRIGIDO PARA DIFFUSERS + PEFT."""
151
+
152
  try:
153
+ # Atualizar status
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  self.training_jobs[job_id]["status"] = "loading_model"
155
  self.training_jobs[job_id]["progress"] = 5
156
+ self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - Carregando modelo base: {model_name}")
157
+
158
+ # Carregar modelo base
 
 
159
  pipeline = self.load_base_model(model_name)
160
  unet = pipeline.unet
161
  text_encoder = pipeline.text_encoder
 
163
  tokenizer = pipeline.tokenizer
164
  scheduler = pipeline.scheduler
165
 
166
+ # Congelar parâmetros
167
  unet.requires_grad_(False)
168
  text_encoder.requires_grad_(False)
169
  vae.requires_grad_(False)
170
 
171
+ # Criar configuração LoRA
 
 
 
 
172
  lora_config = LoraConfig(
173
  r=r,
174
  lora_alpha=lora_alpha,
 
177
  bias="none"
178
  )
179
 
180
+ # Aplicar LoRA ao UNet manualmente, sem usar get_peft_model diretamente
181
  unet.add_adapter(lora_config, adapter_name="default")
 
 
 
 
 
 
 
 
 
 
 
182
 
183
+ # Ativar o adaptador
184
+ unet.set_adapter("default")
185
  unet.train()
186
  unet.to(self.device)
187
 
188
+ # Otimizador
189
+ optimizer = torch.optim.AdamW(unet.parameters(), lr=learning_rate)
190
 
191
+ # Preparar scheduler para treinamento
192
  self.training_jobs[job_id]["status"] = "preparing_data"
193
  self.training_jobs[job_id]["progress"] = 20
 
194
 
195
+ # Normalização de imagem
196
  def preprocess_image(image):
197
  image = np.array(image).astype(np.float32) / 255.0
198
  image = image.transpose(2, 0, 1)
199
  image = torch.from_numpy(image).unsqueeze(0)
200
  return image
201
 
202
+ # Loop de treinamento real
203
  total_steps = num_epochs * len(dataset)
204
  current_step = 0
205
 
206
  self.training_jobs[job_id]["status"] = "training"
207
+ self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - Iniciando treinamento real...")
 
 
 
208
 
209
  for epoch in range(num_epochs):
210
  for item in dataset:
211
  current_step += 1
212
 
213
+ # Obter imagem e legenda
214
  image = item["image"]
215
  caption = item["caption"]
216
+
217
+ # Pré-processar imagem
218
  image_tensor = preprocess_image(image).to(self.device)
219
  if torch.cuda.is_available():
220
  image_tensor = image_tensor.half()
221
 
222
+ # Codificar imagem para latentes
223
  with torch.no_grad():
224
  latents = vae.encode(image_tensor * 2 - 1).latent_dist.sample() * 0.18215
225
 
226
+ # Tokenizar texto
227
  inputs = tokenizer(caption, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
228
  input_ids = inputs.input_ids.to(self.device)
229
 
230
+ # Gerar timesteps aleatórios
231
  timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (1,), device=self.device).long()
232
+
233
+ # Adicionar ruído aos latentes
234
  noise = torch.randn_like(latents)
235
  noisy_latents = scheduler.add_noise(latents, noise, timesteps)
236
 
237
+ # Forward pass
238
  encoder_hidden_states = text_encoder(input_ids)[0]
239
  noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
240
 
241
+ # Calcular perda
242
  loss = torch.nn.functional.mse_loss(noise_pred, noise)
243
+
244
+ # Backward pass
245
  optimizer.zero_grad()
246
  loss.backward()
247
  optimizer.step()
248
 
249
+ # Atualizar progresso
 
 
250
  progress = 30 + int((current_step / total_steps) * 60)
251
  self.training_jobs[job_id]["progress"] = min(progress, 90)
252
 
253
  if current_step % max(1, len(dataset)//2) == 0:
254
  log_msg = f"Época {epoch+1}, Step {current_step} - Loss: {loss.item():.4f}"
255
+ self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - {log_msg}")
 
 
256
 
257
+ # Salvar LoRA
258
  self.training_jobs[job_id]["status"] = "saving"
259
  self.training_jobs[job_id]["progress"] = 95
 
260
 
261
  output_dir = f"./lora_models/{job_id}"
262
  os.makedirs(output_dir, exist_ok=True)
263
 
264
+ # Salvar apenas os adaptadores LoRA
265
+ unet.save_pretrained(output_dir)
 
 
 
266
 
267
+ # Criar adapter_config.json
268
  lora_config_dict = {
269
  "r": r,
270
  "lora_alpha": lora_alpha,
271
  "target_modules": ["to_k", "to_q", "to_v", "to_out.0"],
272
  "lora_dropout": lora_dropout,
273
  "bias": "none",
274
+ "task_type": "CAUSAL_LM", # Mantido por compatibilidade, mas não é usado
275
  "base_model_name": model_name,
276
  "training_info": {
277
  "num_epochs": num_epochs,
 
284
  with open(f"{output_dir}/adapter_config.json", "w") as f:
285
  json.dump(lora_config_dict, f, indent=2)
286
 
287
+ # README
288
  readme_content = f"""# LoRA Model - {job_id}
289
+
290
+ Informações do Treinamento
291
+
292
  Modelo Base: {model_name}
293
+ Rank (r): {r}
294
+ LoRA Alpha: {lora_alpha}
295
+ Dropout: {lora_dropout}
296
+ Épocas: {num_epochs}
297
+ Taxa de Aprendizado: {learning_rate}
298
+ Resolução: {resolution}x{resolution}
299
+ Número de Imagens: {len(dataset)}
300
+ Data de Treinamento: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
301
+
302
+ Como Usar
303
+
304
+ 1. Baixe os arquivos adapter_config.json e adapter_model.safetensors
305
+ 2. Carregue em sua ferramenta de geração de imagens favorita (ComfyUI, Automatic1111, etc.)
306
+ 3. Use o trigger word ou estilo aprendido durante o treinamento
307
  """
308
  with open(f"{output_dir}/README.md", "w") as f:
309
  f.write(readme_content)
310
 
311
+ # Finalizar
312
  self.training_jobs[job_id]["status"] = "completed"
313
  self.training_jobs[job_id]["progress"] = 100
314
  self.training_jobs[job_id]["model_path"] = output_dir
315
  self.training_jobs[job_id]["completed_at"] = datetime.now().isoformat()
316
+ self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - ✅ Treinamento REAL concluído! LoRA salvo em {output_dir}")
317
+
318
+ logger.info(f"Treinamento LoRA REAL concluído para job {job_id}")
 
319
 
320
  except Exception as e:
321
+ error_msg = f"Erro REAL no treinamento: {str(e)}"
322
  logger.error(error_msg)
323
+ self.training_jobs[job_id]["status"] = "error"
324
+ self.training_jobs[job_id]["error"] = error_msg
325
+ self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - ❌ {error_msg}")
 
 
326
 
327
  def start_training(self,
328
  model_name: str,
329
  image_files: List[str],
330
  captions: List[str],
331
  **kwargs) -> str:
332
+ """Inicia treinamento LoRA assíncrono."""
333
+
334
  job_id = str(uuid.uuid4())
335
 
336
+ # Preparar dataset
337
  dataset = self.prepare_image_dataset(image_files, captions, kwargs.get('resolution', 512))
338
 
339
  self.training_jobs[job_id] = {
 
349
  "completed_at": None
350
  }
351
 
352
+ # Iniciar treinamento em thread separada
 
353
  thread = threading.Thread(
354
  target=self.real_lora_training,
355
  args=(job_id, model_name, dataset),
 
360
 
361
  return job_id
362
 
363
+ def get_training_status(self, job_id: str) -> Dict[str, Any]:
364
+ """Retorna status do treinamento."""
365
+ return self.training_jobs.get(job_id, {"error": "Job não encontrado"})
366
+
367
  def list_trained_models(self) -> List[Dict[str, str]]:
368
+ """Lista modelos LoRA treinados."""
369
  models = []
370
  lora_models_dir = Path("./lora_models")
371
 
 
377
  try:
378
  with open(config_file, 'r') as f:
379
  config = json.load(f)
380
+
381
  models.append({
382
  "id": model_dir.name,
383
  "path": str(model_dir),
 
386
  "created": datetime.fromtimestamp(model_dir.stat().st_mtime).isoformat()
387
  })
388
  except Exception as e:
 
389
  models.append({
390
  "id": model_dir.name,
391
  "path": str(model_dir),
 
393
  "r": "Unknown",
394
  "created": datetime.fromtimestamp(model_dir.stat().st_mtime).isoformat()
395
  })
396
+
397
  return models
398
 
399
  def create_download_zip(self, model_path: str) -> str:
400
+ """Cria um arquivo ZIP com os arquivos do modelo LoRA para download."""
401
  zip_path = f"{model_path}.zip"
402
+
403
  with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
404
  model_dir = Path(model_path)
405
  for file_path in model_dir.rglob('*'):
406
  if file_path.is_file():
407
  arcname = file_path.relative_to(model_dir)
408
  zipf.write(file_path, arcname)
409
+
410
  return zip_path
411
 
412
 
413
+ # Instância global do trainer
414
  trainer = LoRAImageTrainer()
415
 
416
  def create_gradio_interface():
417
+ """Cria interface Gradio para a ferramenta LoRA de geração de imagens."""
418
+
419
+ # CSS personalizado para responsividade móvel
420
  custom_css = """
421
+ /* Mobile-first responsive design */
422
  @media (max-width: 768px) {
423
+ .gradio-container {
424
+ padding: 8px !important;
425
+ margin: 0 !important;
426
+ }
427
+
428
+ .tab-nav {
429
+ flex-wrap: wrap !important;
430
+ gap: 4px !important;
431
+ }
432
+
433
+ .tab-nav button {
434
+ font-size: 14px !important;
435
+ padding: 8px 12px !important;
436
+ min-width: auto !important;
437
+ flex: 1 1 auto !important;
438
+ }
439
+
440
+ .form-container {
441
+ padding: 12px !important;
442
+ }
443
+
444
+ .btn {
445
+ width: 100% !important;
446
+ padding: 12px !important;
447
+ font-size: 16px !important;
448
+ margin-bottom: 8px !important;
449
+ min-height: 44px !important;
450
+ }
451
+
452
+ .textbox textarea {
453
+ font-size: 16px !important;
454
+ min-height: 120px !important;
455
+ }
456
+
457
+ .dropdown select {
458
+ font-size: 16px !important;
459
+ padding: 12px !important;
460
+ }
461
+
462
+ .output-text {
463
+ font-size: 14px !important;
464
+ line-height: 1.5 !important;
465
+ }
466
+
467
+ .column {
468
+ margin-bottom: 16px !important;
469
+ }
470
+
471
+ .file-upload {
472
+ min-height: 100px !important;
473
+ }
474
  }
475
+
476
+ /* Enhanced visual styles */
477
  .lora-header {
478
  background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
479
+ color: white;
480
+ padding: 20px;
481
+ border-radius: 12px;
482
+ margin-bottom: 20px;
483
+ text-align: center;
484
+ box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);
485
+ }
486
+
487
+ .status-indicator {
488
+ display: inline-block;
489
+ padding: 4px 8px;
490
+ border-radius: 6px;
491
+ font-size: 12px;
492
+ font-weight: 600;
493
+ text-transform: uppercase;
494
+ letter-spacing: 0.5px;
495
+ margin-right: 8px;
496
  }
497
+
498
+ .status-queued { background-color: #fbbf24; color: #92400e; }
499
+ .status-loading_model { background-color: #60a5fa; color: #1e40af; }
500
+ .status-preparing_lora { background-color: #8b5cf6; color: #5b21b6; }
501
+ .status-preparing_data { background-color: #06b6d4; color: #0e7490; }
502
+ .status-training { background-color: #a78bfa; color: #5b21b6; }
503
+ .status-saving { background-color: #f59e0b; color: #92400e; }
504
  .status-completed { background-color: #34d399; color: #065f46; }
505
  .status-error { background-color: #f87171; color: #991b1b; }
506
+
507
+ /* Touch device optimizations */
508
+ @media (hover: none) and (pointer: coarse) {
509
+ .btn {
510
+ min-height: 44px !important;
511
+ min-width: 44px !important;
512
+ }
513
+
514
+ .tab-nav button {
515
+ min-height: 44px !important;
516
+ min-width: 44px !important;
517
+ }
518
+ }
519
  """
520
 
521
+ def process_images_and_captions(files, captions_text):
522
+ """Processa imagens e legendas enviadas pelo usuário."""
523
+ if not files:
524
+ return "❌ Erro: Nenhuma imagem foi enviada!"
525
+
526
+ # Processar legendas
527
+ captions = []
528
+ if captions_text.strip():
529
+ captions = [line.strip() for line in captions_text.split('\n') if line.strip()]
530
+
531
+ # Se não há legendas suficientes, usar legendas padrão
532
+ while len(captions) < len(files):
533
+ captions.append(f"training image {len(captions) + 1}")
534
+
535
+ # Truncar legendas se houver mais que imagens
536
+ captions = captions[:len(files)]
537
+
538
+ return files, captions
539
+
540
+ def start_training_wrapper(model_name, files, captions_text, trigger_word, r, lora_alpha, lora_dropout,
541
+ num_epochs, learning_rate, batch_size, resolution):
542
+ """Wrapper para iniciar treinamento via Gradio."""
543
+
544
+ if not files:
545
+ return "❌ Erro: Nenhuma imagem foi enviada para treinamento!"
546
+
547
+ if len(files) < 3:
548
+ return "❌ Erro: Forneça pelo menos 3 imagens para treinamento!"
549
 
550
  try:
551
+ # Processar imagens e legendas
552
  image_files = [f.name for f in files]
553
+
554
+ # Processar legendas
555
+ captions = []
556
+ if captions_text.strip():
557
+ captions = [line.strip() for line in captions_text.split('\n') if line.strip()]
558
+
559
+ # Se não há legendas suficientes, usar trigger word + descrição padrão
560
  while len(captions) < len(files):
561
+ if trigger_word.strip():
562
+ captions.append(f"{trigger_word.strip()}, high quality photo")
563
+ else:
564
+ captions.append(f"training image {len(captions) + 1}, high quality photo")
565
+
566
+ # Truncar legendas se houver mais que imagens
567
  captions = captions[:len(files)]
568
 
569
  job_id = trainer.start_training(
570
+ model_name=model_name,
571
+ image_files=image_files,
572
  captions=captions,
573
+ r=int(r),
574
+ lora_alpha=int(lora_alpha),
575
+ lora_dropout=float(lora_dropout),
576
+ num_epochs=int(num_epochs),
577
  learning_rate=float(learning_rate),
578
+ batch_size=int(batch_size),
579
+ resolution=int(resolution)
580
  )
581
+
582
+ return f"✅ Treinamento REAL iniciado! ID do Job: {job_id}\n\n📊 Imagens: {len(files)}\n🏷️ Trigger Word: {trigger_word or 'Nenhuma'}\n\nUse o ID acima para verificar o progresso na aba 'Status do Treinamento'."
583
+
584
  except Exception as e:
585
+ return f"❌ Erro ao iniciar treinamento: {str(e)}"
586
 
587
  def check_status_wrapper(job_id):
588
+ """Wrapper para verificar status via Gradio."""
589
+ if not job_id.strip():
590
+ return "❌ Erro: Forneça um ID de job válido!"
591
+
592
  status = trainer.get_training_status(job_id.strip())
 
593
 
594
+ if "error" in status and status["error"] == "Job não encontrado":
595
+ return "❌ Job não encontrado! Verifique o ID."
596
+
597
+ # Criar indicador visual de status
598
+ status_class = f"status-{status['status']}"
599
+ status_emoji = {
600
+ 'queued': '⏳',
601
+ 'loading_model': '📥',
602
+ 'preparing_lora': '⚙️',
603
+ 'preparing_data': '📊',
604
+ 'training': '🏋️',
605
+ 'saving': '💾',
606
+ 'completed': '✅',
607
+ 'error': '❌'
608
+ }.get(status['status'], '📊')
609
+
610
+ # Barra de progresso visual
611
+ progress = status['progress']
612
+ progress_bar = f"""
613
+ <div style="width: 100%; background-color: #e5e7eb; border-radius: 4px; overflow: hidden; margin: 8px 0;">
614
+ <div style="width: {progress}%; height: 8px; background: linear-gradient(90deg, #3b82f6, #8b5cf6); transition: width 0.3s ease; border-radius: 4px;"></div>
615
+ </div>
616
+ """
617
+
618
+ status_text = f"""
619
+
620
+ 📊 Status do Treinamento LoRA
621
+
622
+ 🆔 Job ID: {status['id']}
623
+ {status_emoji} Status: <span class="{status_class}">{status['status'].upper().replace('_', ' ')}</span>
624
+ ⏳ Progresso: {status['progress']}%
625
+
626
+ {progress_bar}
627
+
628
+ 🤖 Modelo Base: {status['model_name']}
629
+ 🖼️ Imagens: {status.get('num_images', 'N/A')}
630
+ 📅 Criado em: {status['created_at']}
631
+
632
+ """
633
+
634
+ if status['logs']:
635
+ status_text += "📝 **Logs Recentes:**\n"
636
+ for log in status['logs'][-5:]: # Últimos 5 logs
637
+ status_text += f"• {log}\n"
638
+
639
+ if status['status'] == 'completed':
640
+ status_text += f"\n✅ **Treinamento Concluído!**\n📁 **Modelo salvo em:** {status['model_path']}"
641
+ status_text += f"\n⏰ **Concluído em:** {status['completed_at']}"
642
+ status_text += f"\n\n💡 **Próximos passos:** Vá para a aba 'Modelos Treinados' para baixar seu LoRA!"
643
+ elif status['status'] == 'error':
644
+ status_text += f"\n❌ **Erro:** {status['error']}"
645
 
 
 
 
 
646
  return status_text
647
 
648
  def list_models_wrapper():
649
+ """Wrapper para listar modelos via Gradio."""
650
  models = trainer.list_trained_models()
651
+
652
+ if not models:
653
+ return "📭 Nenhum modelo LoRA treinado encontrado."
654
+
655
+ models_text = "📚 **Modelos LoRA Treinados:**\n\n"
656
+ for model in models:
657
+ models_text += f"🆔 **ID:** {model['id']}\n"
658
+ models_text += f"🤖 **Modelo Base:** {model['base_model']}\n"
659
+ models_text += f"📊 **Rank (r):** {model['r']}\n"
660
+ models_text += f"📁 **Caminho:** {model['path']}\n"
661
+ models_text += f"📅 **Criado:** {model['created']}\n\n"
662
+ models_text += "---\n\n"
663
+
664
+ return models_text
665
 
666
  def download_model_wrapper(job_id):
667
+ """Wrapper para preparar download do modelo."""
668
+ if not job_id.strip():
669
+ return None, "❌ Erro: Forneça um ID de job válido!"
670
+
671
  status = trainer.get_training_status(job_id.strip())
672
+
673
+ if "error" in status and status["error"] == "Job não encontrado":
674
+ return None, "❌ Job não encontrado! Verifique o ID."
675
+
676
+ if status['status'] != 'completed':
677
+ return None, f"❌ Treinamento ainda não foi concluído. Status atual: {status['status']}"
678
+
679
  try:
680
+ model_path = status['model_path']
681
+ zip_path = trainer.create_download_zip(model_path)
682
+
683
+ return zip_path, f"✅ Arquivo ZIP criado com sucesso! Clique no link acima para baixar."
684
+
685
  except Exception as e:
686
+ return None, f"❌ Erro ao criar arquivo de download: {str(e)}"
687
 
688
+ # Interface Gradio
689
+ with gr.Blocks(
690
+ title="🎨 LoRA Image Trainer - Criador e Treinador de LoRA para Imagens",
691
+ theme=gr.themes.Soft(),
692
+ css=custom_css
693
+ ) as interface:
694
+
695
+ gr.HTML("""
696
+ <div class="lora-header">
697
+ <h1>🎨 LoRA Image Trainer</h1>
698
+ <p>Criador e Treinador de LoRA para Geração de Imagens</p>
699
+ <p style="font-size: 0.9em; opacity: 0.9; margin-top: 8px;">
700
+ Ferramenta otimizada para baixo uso de GPU, compatível com dispositivos móveis
701
+ </p>
702
+ </div>
703
+ """)
704
 
705
  with gr.Tabs():
706
+
707
+ # Aba de Treinamento
708
+ with gr.TabItem("🎯 Treinar LoRA"):
709
+ gr.Markdown("### Configurar e Iniciar Treinamento LoRA para Imagens")
710
+
711
  with gr.Row():
712
+ with gr.Column(scale=2):
713
+ model_dropdown = gr.Dropdown(
714
+ choices=trainer.get_available_models(),
715
+ value="runwayml/stable-diffusion-v1-5",
716
+ label="🤖 Modelo Base",
717
+ )
718
+
719
+ image_files = gr.File(
720
+ file_count="multiple",
721
+ file_types=["image"],
722
+ label="🖼️ Imagens de Treinamento",
723
+ )
724
+
725
+ trigger_word = gr.Textbox(
726
+ label="🏷️ Trigger Word (Opcional)",
727
+ placeholder="ex: meuEstilo, minhaPersonagem, etc.",
728
+ )
729
+
730
+ captions_text = gr.Textbox(
731
+ lines=8,
732
+ placeholder="Digite uma legenda por linha (opcional)...\n\nExemplo:\nmeuEstilo, retrato de uma mulher\nmeuEstilo, homem sorrindo\nmeuEstilo, paisagem urbana\n\nSe deixar vazio, usará a trigger word + 'high quality photo'",
733
+ label="📝 Legendas das Imagens (Opcional)",
734
+ )
735
+
736
+ with gr.Column(scale=1):
737
+ gr.Markdown("### ⚙️ Parâmetros LoRA")
738
+
739
+ r = gr.Slider(
740
+ minimum=4, maximum=64, value=8, step=4, # reduzido max para 64
741
+ label="r (Rank)",
742
+ )
743
+
744
+ lora_alpha = gr.Slider(
745
+ minimum=1, maximum=64, value=16, step=1, # reduzido max para 64
746
+ label="LoRA Alpha",
747
+ )
748
+
749
+ lora_dropout = gr.Slider(
750
+ minimum=0.0, maximum=0.5, value=0.0, step=0.05, # dropout 0 para mais estabilidade
751
+ label="LoRA Dropout",
752
+ )
753
+
754
+ gr.Markdown("### 🏋️ Parâmetros de Treinamento")
755
+
756
+ num_epochs = gr.Slider(
757
+ minimum=5, maximum=20, value=10, step=5, # reduzido max para 20
758
+ label="Épocas",
759
+ )
760
+
761
+ learning_rate = gr.Slider(
762
+ minimum=1e-5, maximum=5e-4, value=1e-4, step=1e-5, # reduzido max
763
+ label="Taxa de Aprendizado",
764
+ )
765
+
766
+ batch_size = gr.Slider(
767
+ minimum=1, maximum=1, value=1, step=1, # fixado em 1 para Spaces
768
+ label="Batch Size",
769
+ )
770
+
771
+ resolution = gr.Dropdown(
772
+ choices=[512], # fixado em 512 para garantir funcionamento em GPU limitada
773
+ value=512,
774
+ label="Resolução",
775
+ )
776
 
777
+ train_button = gr.Button("🚀 Iniciar Treinamento LoRA", variant="primary", size="lg")
778
+ train_output = gr.Textbox(label="📊 Resultado", lines=5)
779
+
780
+ train_button.click(
781
+ start_training_wrapper,
782
+ inputs=[model_dropdown, image_files, captions_text, trigger_word, r, lora_alpha, lora_dropout,
783
+ num_epochs, learning_rate, batch_size, resolution],
784
+ outputs=train_output
785
+ )
786
 
787
+ # Aba de Status
788
+ with gr.TabItem("📊 Status do Treinamento"):
789
+ gr.Markdown("### Verificar Progresso do Treinamento")
790
+
791
+ job_id_input = gr.Textbox(
792
+ label="🆔 ID do Job",
793
+ placeholder="Cole aqui o ID do job de treinamento...",
794
+ )
795
+
796
+ status_button = gr.Button("🔍 Verificar Status", variant="secondary")
797
+ status_output = gr.Textbox(label="📈 Status", lines=12)
798
+
799
+ status_button.click(
800
+ check_status_wrapper,
801
+ inputs=job_id_input,
802
+ outputs=status_output
803
+ )
804
+
805
+ gr.Markdown("💡 **Dica:** Atualize o status regularmente para acompanhar o progresso do treinamento.")
806
+
807
+ # Aba de Modelos e Download
808
+ with gr.TabItem("📚 Modelos e Download"):
809
+ gr.Markdown("### Visualizar e Baixar Modelos LoRA Treinados")
810
+
811
+ with gr.Row():
812
+ with gr.Column(scale=1):
813
+ list_button = gr.Button("📋 Listar Modelos", variant="secondary")
814
+ models_output = gr.Textbox(label="📚 Modelos Disponíveis", lines=10)
815
+
816
+ list_button.click(
817
+ list_models_wrapper,
818
+ outputs=models_output
819
+ )
820
+
821
+ with gr.Column(scale=1):
822
+ gr.Markdown("#### 💾 Download de Modelo")
823
+
824
+ download_job_id = gr.Textbox(
825
+ label="🆔 ID do Job para Download",
826
+ placeholder="Cole o ID do job concluído...",
827
+ )
828
+
829
+ download_button = gr.Button("📦 Preparar Download", variant="primary")
830
+ download_file = gr.File(label="📁 Arquivo para Download")
831
+ download_status = gr.Textbox(label="📊 Status do Download", lines=3)
832
+
833
+ download_button.click(
834
+ download_model_wrapper,
835
+ inputs=download_job_id,
836
+ outputs=[download_file, download_status]
837
+ )
838
+
839
+ # Aba de Informações
840
+ with gr.TabItem("ℹ️ Sobre"):
841
+ gr.Markdown("""
842
+ ### 🎯 Sobre o LoRA Image Trainer
843
+
844
+ Esta ferramenta foi desenvolvida para democratizar o acesso ao treinamento de modelos LoRA para geração de imagens,
845
+ permitindo que qualquer pessoa possa criar adaptações personalizadas de modelos de difusão (como Stable Diffusion)
846
+ sem a necessidade de hardware especializado.
847
+
848
+ #### ✨ Características Principais:
849
+
850
+ - **🔋 Otimizado para Baixa GPU**: Utiliza técnicas como mixed precision, gradient checkpointing e configurações otimizadas
851
+ - **📱 Compatível com Móveis**: Interface responsiva que funciona em smartphones e tablets
852
+ - **⚡ Rápido e Eficiente**: Treinamento otimizado com bibliotecas Diffusers e PEFT do Hugging Face
853
+ - **🎛️ Configurável**: Controle total sobre parâmetros LoRA e de treinamento
854
+ - **☁️ Pronto para Deploy**: Facilmente implantável no Hugging Face Spaces
855
+ - **🎨 Focado em Imagens**: Especificamente projetado para modelos de difusão e geração de imagens
856
+
857
+ #### 🛠️ Tecnologias Utilizadas:
858
+
859
+ - **Hugging Face Diffusers**: Para modelos de difusão e pipeline de treinamento
860
+ - **PEFT (Parameter-Efficient Fine-Tuning)**: Para treinamento eficiente de LoRA
861
+ - **PyTorch**: Framework de deep learning
862
+ - **Gradio**: Interface web interativa e responsiva
863
+ - **LoRA (Low-Rank Adaptation)**: Técnica de fine-tuning eficiente para modelos de difusão
864
+
865
+ #### 📖 Como Usar:
866
+
867
+ 1. **Prepare suas imagens**: Colete 3-50 imagens de alta qualidade do estilo/conceito que deseja treinar
868
+ 2. **Escolha um modelo base** na aba "Treinar LoRA" (recomendado: Stable Diffusion 1.5)
869
+ 3. **Faça upload das imagens** e defina uma trigger word (palavra-chave)
870
+ 4. **Configure os parâmetros** conforme necessário (valores padrão funcionam bem)
871
+ 5. **Inicie o treinamento** e anote o ID do job
872
+ 6. **Acompanhe o progresso** na aba "Status do Treinamento"
873
+ 7. **Baixe seu LoRA** na aba "Modelos e Download" quando concluído
874
+ 8. **Use em suas ferramentas favoritas** (ComfyUI, Automatic1111, etc.)
875
+
876
+ #### 💡 Dicas para Melhores Resultados:
877
+
878
+ - **Qualidade > Quantidade**: 10-20 imagens de alta qualidade são melhores que 50 imagens ruins
879
+ - **Consistência**: Use imagens com estilo/conceito consistente
880
+ - **Resolução**: Para GPUs com pouca VRAM, use resolução 512x512
881
+ - **Trigger Word**: Escolha uma palavra única e fácil de lembrar
882
+ - **Legendas**: Descreva o que há nas imagens para melhor controle
883
+ - **Parâmetros**: Para iniciantes, use os valores padrão
884
+
885
+ #### 🎮 Compatibilidade:
886
+
887
+ Os LoRAs gerados são compatíveis com:
888
+ - **ComfyUI**: Carregue os arquivos .safetensors
889
+ - **Automatic1111**: Coloque na pasta models/Lora
890
+ - **SeaArt**: Faça upload do modelo
891
+ - **Outras ferramentas**: Qualquer ferramenta que suporte LoRA para Stable Diffusion
892
+
893
+ ---
894
+
895
+ **Desenvolvido com ❤️ para a comunidade de IA e arte digital**
896
+ """)
897
+
898
+ # Footer
899
+ gr.Markdown("""
900
+ ---
901
+ <div style="text-align: center; color: #666; font-size: 0.9em;">
902
+ 🎨 LoRA Image Trainer v1.0 | Otimizado para Baixa GPU | Compatível com Dispositivos Móveis
903
+ </div>
904
+ """)
905
 
906
  return interface
907
 
908
+ # Criar e configurar interface
909
  if __name__ == "__main__":
910
+ # Criar diretórios necessários
911
  os.makedirs("./lora_models", exist_ok=True)
 
 
912
 
913
+ # Configurar interface
914
  interface = create_gradio_interface()
915
+
916
+ # Lançar aplicação
917
  interface.launch(
918
  server_name="0.0.0.0",
919
  server_port=7860,
920
+ share=False,
921
  show_error=True,
922
  quiet=False
923
  )