Allex21 commited on
Commit
3e57297
·
verified ·
1 Parent(s): a82fdb3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +555 -128
app.py CHANGED
@@ -35,31 +35,33 @@ class LoRAImageTrainer:
35
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
  self.training_jobs = {}
37
  self.models_cache = {}
38
- # Criar diretórios
39
- os.makedirs("./lora_models", exist_ok=True)
40
- logger.info("Diretório ./lora_models criado.")
41
-
42
  def get_available_models(self) -> List[str]:
 
43
  return [
44
  "runwayml/stable-diffusion-v1-5",
45
  "stabilityai/stable-diffusion-2-1",
46
  "CompVis/stable-diffusion-v1-4"
 
47
  ]
48
 
49
  def load_base_model(self, model_name: str):
 
50
  try:
51
  if model_name in self.models_cache:
52
  return self.models_cache[model_name]
53
 
54
  logger.info(f"Carregando modelo base: {model_name}")
55
 
 
56
  model_kwargs = {
57
  "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
58
  "use_safetensors": True,
59
  "variant": "fp16" if torch.cuda.is_available() else None,
60
- "safety_checker": None,
61
  }
62
 
 
63
  pipeline = StableDiffusionPipeline.from_pretrained(
64
  model_name,
65
  **model_kwargs
@@ -67,14 +69,17 @@ class LoRAImageTrainer:
67
 
68
  if torch.cuda.is_available():
69
  pipeline = pipeline.to(self.device)
 
70
  pipeline.enable_attention_slicing()
 
71
  try:
72
  pipeline.enable_xformers_memory_efficient_attention()
73
- except:
74
- logger.warning("xformers não disponível")
75
- pipeline.unet.enable_gradient_checkpointing()
76
 
 
77
  self.models_cache[model_name] = pipeline
 
78
  return pipeline
79
 
80
  except Exception as e:
@@ -82,11 +87,15 @@ class LoRAImageTrainer:
82
  raise e
83
 
84
  def prepare_image_dataset(self, image_files: List[str], captions: List[str], resolution: int = 512) -> List[Dict]:
 
85
  dataset = []
86
 
87
  for img_path, caption in zip(image_files, captions):
88
  try:
 
89
  image = Image.open(img_path).convert("RGB")
 
 
90
  image = self.resize_image(image, resolution)
91
 
92
  dataset.append({
@@ -102,8 +111,10 @@ class LoRAImageTrainer:
102
  return dataset
103
 
104
  def resize_image(self, image: Image.Image, target_size: int) -> Image.Image:
 
105
  width, height = image.size
106
 
 
107
  if width > height:
108
  new_width = target_size
109
  new_height = int((height * target_size) / width)
@@ -111,13 +122,16 @@ class LoRAImageTrainer:
111
  new_height = target_size
112
  new_width = int((width * target_size) / height)
113
 
 
114
  image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
115
 
 
116
  if new_width != target_size or new_height != target_size:
117
  left = (new_width - target_size) // 2
118
  top = (new_height - target_size) // 2
119
  right = left + target_size
120
  bottom = top + target_size
 
121
  image = image.crop((left, top, right, bottom))
122
 
123
  return image
@@ -126,35 +140,22 @@ class LoRAImageTrainer:
126
  job_id: str,
127
  model_name: str,
128
  dataset: List[Dict],
129
- r: int = 8,
130
- lora_alpha: int = 16,
131
- lora_dropout: float = 0.0,
132
- num_epochs: int = 5,
133
  learning_rate: float = 1e-4,
134
  batch_size: int = 1,
135
  resolution: int = 512) -> None:
 
 
136
  try:
137
- # Inicializar job
138
- if job_id not in self.training_jobs:
139
- self.training_jobs[job_id] = {
140
- "id": job_id,
141
- "status": "queued",
142
- "progress": 0,
143
- "created_at": datetime.now().isoformat(),
144
- "model_name": model_name,
145
- "num_images": len(dataset),
146
- "logs": [],
147
- "error": None,
148
- "model_path": None,
149
- "completed_at": None
150
- }
151
-
152
  self.training_jobs[job_id]["status"] = "loading_model"
153
  self.training_jobs[job_id]["progress"] = 5
154
- log_msg = f"{datetime.now().strftime('%H:%M:%S')} - Carregando modelo base: {model_name}"
155
- self.training_jobs[job_id]["logs"].append(log_msg)
156
- logger.info(log_msg)
157
-
158
  pipeline = self.load_base_model(model_name)
159
  unet = pipeline.unet
160
  text_encoder = pipeline.text_encoder
@@ -162,14 +163,12 @@ class LoRAImageTrainer:
162
  tokenizer = pipeline.tokenizer
163
  scheduler = pipeline.scheduler
164
 
 
165
  unet.requires_grad_(False)
166
  text_encoder.requires_grad_(False)
167
  vae.requires_grad_(False)
168
 
169
- # Remover adaptador se existir
170
- if hasattr(unet, "peft_config") and "default" in unet.peft_config:
171
- unet.delete_adapter("default")
172
-
173
  lora_config = LoraConfig(
174
  r=r,
175
  lora_alpha=lora_alpha,
@@ -178,96 +177,105 @@ class LoRAImageTrainer:
178
  bias="none"
179
  )
180
 
 
181
  unet.add_adapter(lora_config, adapter_name="default")
182
- unet.set_adapter("default")
183
-
184
- # Ativar apenas parâmetros do LoRA
185
- unet.requires_grad_(False)
186
- for name, param in unet.named_parameters():
187
- if "lora_" in name:
188
- param.requires_grad = True
189
 
 
 
190
  unet.train()
191
  unet.to(self.device)
192
 
193
- optimizer = torch.optim.AdamW([p for p in unet.parameters() if p.requires_grad], lr=learning_rate)
 
194
 
 
195
  self.training_jobs[job_id]["status"] = "preparing_data"
196
  self.training_jobs[job_id]["progress"] = 20
197
 
 
198
  def preprocess_image(image):
199
  image = np.array(image).astype(np.float32) / 255.0
200
  image = image.transpose(2, 0, 1)
201
  image = torch.from_numpy(image).unsqueeze(0)
202
  return image
203
 
 
204
  total_steps = num_epochs * len(dataset)
205
  current_step = 0
206
 
207
  self.training_jobs[job_id]["status"] = "training"
208
- log_msg = f"{datetime.now().strftime('%H:%M:%S')} - Iniciando treinamento real..."
209
- self.training_jobs[job_id]["logs"].append(log_msg)
210
- logger.info(log_msg)
211
 
212
  for epoch in range(num_epochs):
213
  for item in dataset:
214
  current_step += 1
215
 
 
216
  image = item["image"]
217
  caption = item["caption"]
 
 
218
  image_tensor = preprocess_image(image).to(self.device)
219
  if torch.cuda.is_available():
220
  image_tensor = image_tensor.half()
221
 
 
222
  with torch.no_grad():
223
  latents = vae.encode(image_tensor * 2 - 1).latent_dist.sample() * 0.18215
224
 
 
225
  inputs = tokenizer(caption, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
226
  input_ids = inputs.input_ids.to(self.device)
227
 
 
228
  timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (1,), device=self.device).long()
 
 
229
  noise = torch.randn_like(latents)
230
  noisy_latents = scheduler.add_noise(latents, noise, timesteps)
231
 
 
232
  encoder_hidden_states = text_encoder(input_ids)[0]
233
  noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
234
 
 
235
  loss = torch.nn.functional.mse_loss(noise_pred, noise)
 
 
236
  optimizer.zero_grad()
237
  loss.backward()
238
  optimizer.step()
239
 
240
- if torch.cuda.is_available():
241
- torch.cuda.empty_cache()
242
-
243
  progress = 30 + int((current_step / total_steps) * 60)
244
  self.training_jobs[job_id]["progress"] = min(progress, 90)
245
 
246
  if current_step % max(1, len(dataset)//2) == 0:
247
  log_msg = f"Época {epoch+1}, Step {current_step} - Loss: {loss.item():.4f}"
248
- self.training_jobs[job_id]["logs"].append(log_msg)
249
- logger.info(log_msg)
250
 
 
251
  self.training_jobs[job_id]["status"] = "saving"
252
  self.training_jobs[job_id]["progress"] = 95
253
 
254
  output_dir = f"./lora_models/{job_id}"
255
  os.makedirs(output_dir, exist_ok=True)
256
 
257
- # ✅ CORREÇÃO PRINCIPAL: SALVAR APENAS O ADAPTADOR LORA
258
  unet.save_pretrained(
259
  output_dir,
260
  safe_serialization=True,
261
  selected_adapters=["default"]
262
  )
263
 
 
264
  lora_config_dict = {
265
  "r": r,
266
  "lora_alpha": lora_alpha,
267
  "target_modules": ["to_k", "to_q", "to_v", "to_out.0"],
268
  "lora_dropout": lora_dropout,
269
  "bias": "none",
270
- "task_type": "CAUSAL_LM",
271
  "base_model_name": model_name,
272
  "training_info": {
273
  "num_epochs": num_epochs,
@@ -280,37 +288,56 @@ class LoRAImageTrainer:
280
  with open(f"{output_dir}/adapter_config.json", "w") as f:
281
  json.dump(lora_config_dict, f, indent=2)
282
 
 
283
  readme_content = f"""# LoRA Model - {job_id}
284
- Treinado com sucesso!
 
 
285
  Modelo Base: {model_name}
286
- Data: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
 
 
 
 
 
 
 
 
 
 
 
 
 
287
  """
288
  with open(f"{output_dir}/README.md", "w") as f:
289
  f.write(readme_content)
290
 
 
291
  self.training_jobs[job_id]["status"] = "completed"
292
  self.training_jobs[job_id]["progress"] = 100
293
  self.training_jobs[job_id]["model_path"] = output_dir
294
  self.training_jobs[job_id]["completed_at"] = datetime.now().isoformat()
295
- log_msg = f"{datetime.now().strftime('%H:%M:%S')} - ✅ Treinamento concluído! LoRA salvo em {output_dir}"
296
- self.training_jobs[job_id]["logs"].append(log_msg)
297
- logger.info(log_msg)
298
 
299
  except Exception as e:
300
- error_msg = f"Erro no treinamento: {str(e)}"
301
  logger.error(error_msg)
302
- if job_id in self.training_jobs:
303
- self.training_jobs[job_id]["status"] = "error"
304
- self.training_jobs[job_id]["error"] = error_msg
305
- self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - ❌ {error_msg}")
306
 
307
  def start_training(self,
308
  model_name: str,
309
  image_files: List[str],
310
  captions: List[str],
311
  **kwargs) -> str:
 
 
312
  job_id = str(uuid.uuid4())
313
 
 
314
  dataset = self.prepare_image_dataset(image_files, captions, kwargs.get('resolution', 512))
315
 
316
  self.training_jobs[job_id] = {
@@ -326,6 +353,7 @@ Data: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
326
  "completed_at": None
327
  }
328
 
 
329
  thread = threading.Thread(
330
  target=self.real_lora_training,
331
  args=(job_id, model_name, dataset),
@@ -337,9 +365,11 @@ Data: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
337
  return job_id
338
 
339
  def get_training_status(self, job_id: str) -> Dict[str, Any]:
 
340
  return self.training_jobs.get(job_id, {"error": "Job não encontrado"})
341
 
342
  def list_trained_models(self) -> List[Dict[str, str]]:
 
343
  models = []
344
  lora_models_dir = Path("./lora_models")
345
 
@@ -351,6 +381,7 @@ Data: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
351
  try:
352
  with open(config_file, 'r') as f:
353
  config = json.load(f)
 
354
  models.append({
355
  "id": model_dir.name,
356
  "path": str(model_dir),
@@ -358,7 +389,7 @@ Data: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
358
  "r": config.get("r", "Unknown"),
359
  "created": datetime.fromtimestamp(model_dir.stat().st_mtime).isoformat()
360
  })
361
- except:
362
  models.append({
363
  "id": model_dir.name,
364
  "path": str(model_dir),
@@ -366,135 +397,531 @@ Data: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
366
  "r": "Unknown",
367
  "created": datetime.fromtimestamp(model_dir.stat().st_mtime).isoformat()
368
  })
 
369
  return models
370
 
371
  def create_download_zip(self, model_path: str) -> str:
 
372
  zip_path = f"{model_path}.zip"
 
373
  with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
374
  model_dir = Path(model_path)
375
  for file_path in model_dir.rglob('*'):
376
  if file_path.is_file():
377
  arcname = file_path.relative_to(model_dir)
378
  zipf.write(file_path, arcname)
 
379
  return zip_path
380
 
381
 
 
382
  trainer = LoRAImageTrainer()
383
 
384
  def create_gradio_interface():
 
 
 
385
  custom_css = """
 
386
  @media (max-width: 768px) {
387
- .gradio-container { padding: 8px !important; }
388
- .btn { width: 100% !important; padding: 12px !important; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
389
  }
 
 
390
  .lora-header {
391
  background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
392
- color: white; padding: 20px; border-radius: 12px; margin-bottom: 20px;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
393
  }
394
- .status-indicator { padding: 4px 8px; border-radius: 6px; font-size: 12px; font-weight: 600; }
 
 
 
 
 
 
395
  .status-completed { background-color: #34d399; color: #065f46; }
396
  .status-error { background-color: #f87171; color: #991b1b; }
 
 
 
 
 
 
 
 
 
 
 
 
 
397
  """
398
 
399
- def start_training_wrapper(model_name, files, captions_text, trigger_word, r, lora_alpha, num_epochs, learning_rate):
400
- if not files: return "❌ Erro: Nenhuma imagem enviada!"
401
- if len(files) < 3: return "❌ Forneça pelo menos 3 imagens!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
402
 
403
  try:
 
404
  image_files = [f.name for f in files]
405
- captions = [line.strip() for line in captions_text.split('\n') if line.strip()] if captions_text.strip() else []
 
 
 
 
 
 
406
  while len(captions) < len(files):
407
- captions.append(f"{trigger_word.strip() or 'training image'}, high quality photo" if trigger_word.strip() else f"training image {len(captions) + 1}, high quality photo")
 
 
 
 
 
408
  captions = captions[:len(files)]
409
 
410
  job_id = trainer.start_training(
411
- model_name=model_name,
412
- image_files=image_files,
413
  captions=captions,
414
- r=int(r),
415
- lora_alpha=int(lora_alpha),
416
- lora_dropout=0.0, # Fixado
417
- num_epochs=int(num_epochs),
418
  learning_rate=float(learning_rate),
419
- batch_size=1, # Fixado
420
- resolution=512 # Fixado
421
  )
422
- return f"✅ Treinamento iniciado! ID: {job_id}\n📊 Imagens: {len(files)}\nUse este ID para acompanhar o progresso."
 
 
423
  except Exception as e:
424
- return f"❌ Erro: {str(e)}"
425
 
426
  def check_status_wrapper(job_id):
427
- if not job_id.strip(): return "❌ Forneça um ID válido!"
 
 
 
428
  status = trainer.get_training_status(job_id.strip())
429
- if "error" in status: return "❌ Job não encontrado!"
430
 
431
- status_emoji = {'completed': '✅', 'error': '❌', 'training': '🏋️', 'queued': '⏳'}.get(status['status'], '📊')
432
- progress = status['progress']
433
- progress_bar = f'<div style="width:100%;background:#e5e7eb;border-radius:4px;overflow:hidden;margin:8px 0;"><div style="width:{progress}%;height:8px;background:linear-gradient(90deg,#3b82f6,#8b5cf6);border-radius:4px;"></div></div>'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
434
 
435
- status_text = f"🆔 Job ID: {status['id']}\n{status_emoji} Status: {status['status'].upper()}\n⏳ Progresso: {progress}%\n{progress_bar}\n🤖 Modelo: {status['model_name']}\n🖼️ Imagens: {status.get('num_images','N/A')}\n📅 Criado: {status['created_at']}\n"
436
- if status['logs']: status_text += "\n📝 Logs:\n" + "\n".join([f"• {log}" for log in status['logs'][-5:]])
437
- if status['status'] == 'completed': status_text += f"\n✅ Concluído! Modelo salvo em: {status['model_path']}"
438
- elif status['status'] == 'error': status_text += f"\n❌ Erro: {status['error']}"
439
  return status_text
440
 
441
  def list_models_wrapper():
 
442
  models = trainer.list_trained_models()
443
- if not models: return "📭 Nenhum modelo encontrado."
444
- return "\n\n".join([f"🆔 {m['id']}\n🤖 {m['base_model']}\n📅 {m['created']}" for m in models])
 
 
 
 
 
 
 
 
 
 
 
 
445
 
446
  def download_model_wrapper(job_id):
447
- if not job_id.strip(): return None, "❌ ID inválido!"
 
 
 
448
  status = trainer.get_training_status(job_id.strip())
449
- if status.get("status") != "completed": return None, "❌ Treinamento não concluído!"
 
 
 
 
 
 
450
  try:
451
- zip_path = trainer.create_download_zip(status['model_path'])
452
- return zip_path, "✅ ZIP criado! Clique acima para baixar."
 
 
 
453
  except Exception as e:
454
- return None, f"❌ Erro: {str(e)}"
455
 
456
- with gr.Blocks(title="🎨 LoRA Image Trainer", theme=gr.themes.Soft(), css=custom_css) as interface:
457
- gr.HTML('<div class="lora-header"><h1>🎨 LoRA Image Trainer</h1><p>Treine seu LoRA para imagens</p></div>')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
458
 
459
  with gr.Tabs():
460
- with gr.TabItem("🎯 Treinar"):
 
 
 
 
461
  with gr.Row():
462
- with gr.Column():
463
- model_dropdown = gr.Dropdown(trainer.get_available_models(), value="runwayml/stable-diffusion-v1-5", label="🤖 Modelo Base")
464
- image_files = gr.File(file_count="multiple", file_types=["image"], label="🖼️ Imagens")
465
- trigger_word = gr.Textbox(label="🏷️ Trigger Word", placeholder="ex: meuEstilo")
466
- captions_text = gr.Textbox(lines=4, placeholder="Legenda por linha...", label="📝 Legendas (Opcional)")
467
- with gr.Column():
468
- r = gr.Slider(4, 32, 8, step=4, label="r (Rank)")
469
- lora_alpha = gr.Slider(1, 32, 16, step=1, label="LoRA Alpha")
470
- num_epochs = gr.Slider(1, 10, 5, step=1, label="Épocas")
471
- learning_rate = gr.Slider(1e-5, 1e-3, 1e-4, step=1e-5, label="Taxa de Aprendizado")
472
- train_button = gr.Button("🚀 Iniciar Treinamento", variant="primary")
473
- train_output = gr.Textbox(label="📊 Resultado")
474
- train_button.click(start_training_wrapper, [model_dropdown, image_files, captions_text, trigger_word, r, lora_alpha, num_epochs, learning_rate], train_output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
475
 
476
- with gr.TabItem("📊 Status"):
477
- job_id_input = gr.Textbox(label="🆔 ID do Job")
478
- status_button = gr.Button("🔍 Verificar Status")
479
- status_output = gr.Textbox(label="📈 Status", lines=10)
480
- status_button.click(check_status_wrapper, job_id_input, status_output)
 
 
 
 
481
 
482
- with gr.TabItem("📚 Download"):
483
- download_job_id = gr.Textbox(label="🆔 ID do Job")
484
- download_button = gr.Button("📦 Baixar Modelo")
485
- download_file = gr.File(label="📁 Arquivo ZIP")
486
- download_status = gr.Textbox(label="📊 Status")
487
- download_button.click(download_model_wrapper, download_job_id, [download_file, download_status])
488
-
489
- gr.Markdown("---\n<center>🎨 LoRA Image Trainer v1.0 | Treinamento Real de LoRA</center>")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
490
 
491
  return interface
492
 
 
493
  if __name__ == "__main__":
 
 
 
 
494
  interface = create_gradio_interface()
 
 
495
  interface.launch(
496
  server_name="0.0.0.0",
497
  server_port=7860,
 
498
  show_error=True,
499
  quiet=False
500
  )
 
35
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
  self.training_jobs = {}
37
  self.models_cache = {}
38
+
 
 
 
39
  def get_available_models(self) -> List[str]:
40
+ """Retorna lista de modelos base disponíveis para treinamento LoRA."""
41
  return [
42
  "runwayml/stable-diffusion-v1-5",
43
  "stabilityai/stable-diffusion-2-1",
44
  "CompVis/stable-diffusion-v1-4"
45
+ # XL removido por ser pesado demais para Spaces gratuitos
46
  ]
47
 
48
  def load_base_model(self, model_name: str):
49
+ """Carrega modelo base de difusão com otimizações para baixo uso de GPU."""
50
  try:
51
  if model_name in self.models_cache:
52
  return self.models_cache[model_name]
53
 
54
  logger.info(f"Carregando modelo base: {model_name}")
55
 
56
+ # Configurações para otimização de memória
57
  model_kwargs = {
58
  "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
59
  "use_safetensors": True,
60
  "variant": "fp16" if torch.cuda.is_available() else None,
61
+ "safety_checker": None, # Desativa verificador de segurança para economizar memória
62
  }
63
 
64
+ # Carregar pipeline completo
65
  pipeline = StableDiffusionPipeline.from_pretrained(
66
  model_name,
67
  **model_kwargs
 
69
 
70
  if torch.cuda.is_available():
71
  pipeline = pipeline.to(self.device)
72
+ # Habilitar attention slicing para economia de memória
73
  pipeline.enable_attention_slicing()
74
+ # Habilitar memory efficient attention se disponível
75
  try:
76
  pipeline.enable_xformers_memory_efficient_attention()
77
+ except Exception as e:
78
+ logger.warning("xformers não disponível, usando attention padrão")
 
79
 
80
+ # Cache do modelo
81
  self.models_cache[model_name] = pipeline
82
+
83
  return pipeline
84
 
85
  except Exception as e:
 
87
  raise e
88
 
89
  def prepare_image_dataset(self, image_files: List[str], captions: List[str], resolution: int = 512) -> List[Dict]:
90
+ """Prepara dataset de imagens para treinamento."""
91
  dataset = []
92
 
93
  for img_path, caption in zip(image_files, captions):
94
  try:
95
+ # Carregar e redimensionar imagem
96
  image = Image.open(img_path).convert("RGB")
97
+
98
+ # Redimensionar mantendo aspect ratio
99
  image = self.resize_image(image, resolution)
100
 
101
  dataset.append({
 
111
  return dataset
112
 
113
  def resize_image(self, image: Image.Image, target_size: int) -> Image.Image:
114
+ """Redimensiona imagem mantendo aspect ratio e fazendo crop central se necessário."""
115
  width, height = image.size
116
 
117
+ # Calcular novo tamanho mantendo aspect ratio
118
  if width > height:
119
  new_width = target_size
120
  new_height = int((height * target_size) / width)
 
122
  new_height = target_size
123
  new_width = int((width * target_size) / height)
124
 
125
+ # Redimensionar
126
  image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
127
 
128
+ # Crop central para obter tamanho exato
129
  if new_width != target_size or new_height != target_size:
130
  left = (new_width - target_size) // 2
131
  top = (new_height - target_size) // 2
132
  right = left + target_size
133
  bottom = top + target_size
134
+
135
  image = image.crop((left, top, right, bottom))
136
 
137
  return image
 
140
  job_id: str,
141
  model_name: str,
142
  dataset: List[Dict],
143
+ r: int = 16,
144
+ lora_alpha: int = 32,
145
+ lora_dropout: float = 0.1,
146
+ num_epochs: int = 10,
147
  learning_rate: float = 1e-4,
148
  batch_size: int = 1,
149
  resolution: int = 512) -> None:
150
+ """TREINAMENTO REAL DE LoRA PARA IMAGENS - CORRIGIDO PARA DIFFUSERS + PEFT."""
151
+
152
  try:
153
+ # Atualizar status
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  self.training_jobs[job_id]["status"] = "loading_model"
155
  self.training_jobs[job_id]["progress"] = 5
156
+ self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - Carregando modelo base: {model_name}")
157
+
158
+ # Carregar modelo base
 
159
  pipeline = self.load_base_model(model_name)
160
  unet = pipeline.unet
161
  text_encoder = pipeline.text_encoder
 
163
  tokenizer = pipeline.tokenizer
164
  scheduler = pipeline.scheduler
165
 
166
+ # Congelar parâmetros
167
  unet.requires_grad_(False)
168
  text_encoder.requires_grad_(False)
169
  vae.requires_grad_(False)
170
 
171
+ # Criar configuração LoRA
 
 
 
172
  lora_config = LoraConfig(
173
  r=r,
174
  lora_alpha=lora_alpha,
 
177
  bias="none"
178
  )
179
 
180
+ # Aplicar LoRA ao UNet manualmente, sem usar get_peft_model diretamente
181
  unet.add_adapter(lora_config, adapter_name="default")
 
 
 
 
 
 
 
182
 
183
+ # Ativar o adaptador
184
+ unet.set_adapter("default")
185
  unet.train()
186
  unet.to(self.device)
187
 
188
+ # Otimizador
189
+ optimizer = torch.optim.AdamW(unet.parameters(), lr=learning_rate)
190
 
191
+ # Preparar scheduler para treinamento
192
  self.training_jobs[job_id]["status"] = "preparing_data"
193
  self.training_jobs[job_id]["progress"] = 20
194
 
195
+ # Normalização de imagem
196
  def preprocess_image(image):
197
  image = np.array(image).astype(np.float32) / 255.0
198
  image = image.transpose(2, 0, 1)
199
  image = torch.from_numpy(image).unsqueeze(0)
200
  return image
201
 
202
+ # Loop de treinamento real
203
  total_steps = num_epochs * len(dataset)
204
  current_step = 0
205
 
206
  self.training_jobs[job_id]["status"] = "training"
207
+ self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - Iniciando treinamento real...")
 
 
208
 
209
  for epoch in range(num_epochs):
210
  for item in dataset:
211
  current_step += 1
212
 
213
+ # Obter imagem e legenda
214
  image = item["image"]
215
  caption = item["caption"]
216
+
217
+ # Pré-processar imagem
218
  image_tensor = preprocess_image(image).to(self.device)
219
  if torch.cuda.is_available():
220
  image_tensor = image_tensor.half()
221
 
222
+ # Codificar imagem para latentes
223
  with torch.no_grad():
224
  latents = vae.encode(image_tensor * 2 - 1).latent_dist.sample() * 0.18215
225
 
226
+ # Tokenizar texto
227
  inputs = tokenizer(caption, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
228
  input_ids = inputs.input_ids.to(self.device)
229
 
230
+ # Gerar timesteps aleatórios
231
  timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (1,), device=self.device).long()
232
+
233
+ # Adicionar ruído aos latentes
234
  noise = torch.randn_like(latents)
235
  noisy_latents = scheduler.add_noise(latents, noise, timesteps)
236
 
237
+ # Forward pass
238
  encoder_hidden_states = text_encoder(input_ids)[0]
239
  noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
240
 
241
+ # Calcular perda
242
  loss = torch.nn.functional.mse_loss(noise_pred, noise)
243
+
244
+ # Backward pass
245
  optimizer.zero_grad()
246
  loss.backward()
247
  optimizer.step()
248
 
249
+ # Atualizar progresso
 
 
250
  progress = 30 + int((current_step / total_steps) * 60)
251
  self.training_jobs[job_id]["progress"] = min(progress, 90)
252
 
253
  if current_step % max(1, len(dataset)//2) == 0:
254
  log_msg = f"Época {epoch+1}, Step {current_step} - Loss: {loss.item():.4f}"
255
+ self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - {log_msg}")
 
256
 
257
+ # Salvar LoRA
258
  self.training_jobs[job_id]["status"] = "saving"
259
  self.training_jobs[job_id]["progress"] = 95
260
 
261
  output_dir = f"./lora_models/{job_id}"
262
  os.makedirs(output_dir, exist_ok=True)
263
 
264
+ # ✅ ÚNICA ALTERAÇÃO: SALVAR APENAS OS ADAPTADORES LORA
265
  unet.save_pretrained(
266
  output_dir,
267
  safe_serialization=True,
268
  selected_adapters=["default"]
269
  )
270
 
271
+ # Criar adapter_config.json
272
  lora_config_dict = {
273
  "r": r,
274
  "lora_alpha": lora_alpha,
275
  "target_modules": ["to_k", "to_q", "to_v", "to_out.0"],
276
  "lora_dropout": lora_dropout,
277
  "bias": "none",
278
+ "task_type": "CAUSAL_LM", # Mantido por compatibilidade, mas não é usado
279
  "base_model_name": model_name,
280
  "training_info": {
281
  "num_epochs": num_epochs,
 
288
  with open(f"{output_dir}/adapter_config.json", "w") as f:
289
  json.dump(lora_config_dict, f, indent=2)
290
 
291
+ # README
292
  readme_content = f"""# LoRA Model - {job_id}
293
+
294
+ Informações do Treinamento
295
+
296
  Modelo Base: {model_name}
297
+ Rank (r): {r}
298
+ LoRA Alpha: {lora_alpha}
299
+ Dropout: {lora_dropout}
300
+ Épocas: {num_epochs}
301
+ Taxa de Aprendizado: {learning_rate}
302
+ Resolução: {resolution}x{resolution}
303
+ Número de Imagens: {len(dataset)}
304
+ Data de Treinamento: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
305
+
306
+ Como Usar
307
+
308
+ 1. Baixe os arquivos adapter_config.json e adapter_model.safetensors
309
+ 2. Carregue em sua ferramenta de geração de imagens favorita (ComfyUI, Automatic1111, etc.)
310
+ 3. Use o trigger word ou estilo aprendido durante o treinamento
311
  """
312
  with open(f"{output_dir}/README.md", "w") as f:
313
  f.write(readme_content)
314
 
315
+ # Finalizar
316
  self.training_jobs[job_id]["status"] = "completed"
317
  self.training_jobs[job_id]["progress"] = 100
318
  self.training_jobs[job_id]["model_path"] = output_dir
319
  self.training_jobs[job_id]["completed_at"] = datetime.now().isoformat()
320
+ self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - ✅ Treinamento REAL concluído! LoRA salvo em {output_dir}")
321
+
322
+ logger.info(f"Treinamento LoRA REAL concluído para job {job_id}")
323
 
324
  except Exception as e:
325
+ error_msg = f"Erro REAL no treinamento: {str(e)}"
326
  logger.error(error_msg)
327
+ self.training_jobs[job_id]["status"] = "error"
328
+ self.training_jobs[job_id]["error"] = error_msg
329
+ self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - ❌ {error_msg}")
 
330
 
331
  def start_training(self,
332
  model_name: str,
333
  image_files: List[str],
334
  captions: List[str],
335
  **kwargs) -> str:
336
+ """Inicia treinamento LoRA assíncrono."""
337
+
338
  job_id = str(uuid.uuid4())
339
 
340
+ # Preparar dataset
341
  dataset = self.prepare_image_dataset(image_files, captions, kwargs.get('resolution', 512))
342
 
343
  self.training_jobs[job_id] = {
 
353
  "completed_at": None
354
  }
355
 
356
+ # Iniciar treinamento em thread separada
357
  thread = threading.Thread(
358
  target=self.real_lora_training,
359
  args=(job_id, model_name, dataset),
 
365
  return job_id
366
 
367
  def get_training_status(self, job_id: str) -> Dict[str, Any]:
368
+ """Retorna status do treinamento."""
369
  return self.training_jobs.get(job_id, {"error": "Job não encontrado"})
370
 
371
  def list_trained_models(self) -> List[Dict[str, str]]:
372
+ """Lista modelos LoRA treinados."""
373
  models = []
374
  lora_models_dir = Path("./lora_models")
375
 
 
381
  try:
382
  with open(config_file, 'r') as f:
383
  config = json.load(f)
384
+
385
  models.append({
386
  "id": model_dir.name,
387
  "path": str(model_dir),
 
389
  "r": config.get("r", "Unknown"),
390
  "created": datetime.fromtimestamp(model_dir.stat().st_mtime).isoformat()
391
  })
392
+ except Exception as e:
393
  models.append({
394
  "id": model_dir.name,
395
  "path": str(model_dir),
 
397
  "r": "Unknown",
398
  "created": datetime.fromtimestamp(model_dir.stat().st_mtime).isoformat()
399
  })
400
+
401
  return models
402
 
403
  def create_download_zip(self, model_path: str) -> str:
404
+ """Cria um arquivo ZIP com os arquivos do modelo LoRA para download."""
405
  zip_path = f"{model_path}.zip"
406
+
407
  with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
408
  model_dir = Path(model_path)
409
  for file_path in model_dir.rglob('*'):
410
  if file_path.is_file():
411
  arcname = file_path.relative_to(model_dir)
412
  zipf.write(file_path, arcname)
413
+
414
  return zip_path
415
 
416
 
417
+ # Instância global do trainer
418
  trainer = LoRAImageTrainer()
419
 
420
  def create_gradio_interface():
421
+ """Cria interface Gradio para a ferramenta LoRA de geração de imagens."""
422
+
423
+ # CSS personalizado para responsividade móvel
424
  custom_css = """
425
+ /* Mobile-first responsive design */
426
  @media (max-width: 768px) {
427
+ .gradio-container {
428
+ padding: 8px !important;
429
+ margin: 0 !important;
430
+ }
431
+
432
+ .tab-nav {
433
+ flex-wrap: wrap !important;
434
+ gap: 4px !important;
435
+ }
436
+
437
+ .tab-nav button {
438
+ font-size: 14px !important;
439
+ padding: 8px 12px !important;
440
+ min-width: auto !important;
441
+ flex: 1 1 auto !important;
442
+ }
443
+
444
+ .form-container {
445
+ padding: 12px !important;
446
+ }
447
+
448
+ .btn {
449
+ width: 100% !important;
450
+ padding: 12px !important;
451
+ font-size: 16px !important;
452
+ margin-bottom: 8px !important;
453
+ min-height: 44px !important;
454
+ }
455
+
456
+ .textbox textarea {
457
+ font-size: 16px !important;
458
+ min-height: 120px !important;
459
+ }
460
+
461
+ .dropdown select {
462
+ font-size: 16px !important;
463
+ padding: 12px !important;
464
+ }
465
+
466
+ .output-text {
467
+ font-size: 14px !important;
468
+ line-height: 1.5 !important;
469
+ }
470
+
471
+ .column {
472
+ margin-bottom: 16px !important;
473
+ }
474
+
475
+ .file-upload {
476
+ min-height: 100px !important;
477
+ }
478
  }
479
+
480
+ /* Enhanced visual styles */
481
  .lora-header {
482
  background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
483
+ color: white;
484
+ padding: 20px;
485
+ border-radius: 12px;
486
+ margin-bottom: 20px;
487
+ text-align: center;
488
+ box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);
489
+ }
490
+
491
+ .status-indicator {
492
+ display: inline-block;
493
+ padding: 4px 8px;
494
+ border-radius: 6px;
495
+ font-size: 12px;
496
+ font-weight: 600;
497
+ text-transform: uppercase;
498
+ letter-spacing: 0.5px;
499
+ margin-right: 8px;
500
  }
501
+
502
+ .status-queued { background-color: #fbbf24; color: #92400e; }
503
+ .status-loading_model { background-color: #60a5fa; color: #1e40af; }
504
+ .status-preparing_lora { background-color: #8b5cf6; color: #5b21b6; }
505
+ .status-preparing_data { background-color: #06b6d4; color: #0e7490; }
506
+ .status-training { background-color: #a78bfa; color: #5b21b6; }
507
+ .status-saving { background-color: #f59e0b; color: #92400e; }
508
  .status-completed { background-color: #34d399; color: #065f46; }
509
  .status-error { background-color: #f87171; color: #991b1b; }
510
+
511
+ /* Touch device optimizations */
512
+ @media (hover: none) and (pointer: coarse) {
513
+ .btn {
514
+ min-height: 44px !important;
515
+ min-width: 44px !important;
516
+ }
517
+
518
+ .tab-nav button {
519
+ min-height: 44px !important;
520
+ min-width: 44px !important;
521
+ }
522
+ }
523
  """
524
 
525
+ def process_images_and_captions(files, captions_text):
526
+ """Processa imagens e legendas enviadas pelo usuário."""
527
+ if not files:
528
+ return "❌ Erro: Nenhuma imagem foi enviada!"
529
+
530
+ # Processar legendas
531
+ captions = []
532
+ if captions_text.strip():
533
+ captions = [line.strip() for line in captions_text.split('\n') if line.strip()]
534
+
535
+ # Se não há legendas suficientes, usar legendas padrão
536
+ while len(captions) < len(files):
537
+ captions.append(f"training image {len(captions) + 1}")
538
+
539
+ # Truncar legendas se houver mais que imagens
540
+ captions = captions[:len(files)]
541
+
542
+ return files, captions
543
+
544
+ def start_training_wrapper(model_name, files, captions_text, trigger_word, r, lora_alpha, lora_dropout,
545
+ num_epochs, learning_rate, batch_size, resolution):
546
+ """Wrapper para iniciar treinamento via Gradio."""
547
+
548
+ if not files:
549
+ return "❌ Erro: Nenhuma imagem foi enviada para treinamento!"
550
+
551
+ if len(files) < 3:
552
+ return "❌ Erro: Forneça pelo menos 3 imagens para treinamento!"
553
 
554
  try:
555
+ # Processar imagens e legendas
556
  image_files = [f.name for f in files]
557
+
558
+ # Processar legendas
559
+ captions = []
560
+ if captions_text.strip():
561
+ captions = [line.strip() for line in captions_text.split('\n') if line.strip()]
562
+
563
+ # Se não há legendas suficientes, usar trigger word + descrição padrão
564
  while len(captions) < len(files):
565
+ if trigger_word.strip():
566
+ captions.append(f"{trigger_word.strip()}, high quality photo")
567
+ else:
568
+ captions.append(f"training image {len(captions) + 1}, high quality photo")
569
+
570
+ # Truncar legendas se houver mais que imagens
571
  captions = captions[:len(files)]
572
 
573
  job_id = trainer.start_training(
574
+ model_name=model_name,
575
+ image_files=image_files,
576
  captions=captions,
577
+ r=int(r),
578
+ lora_alpha=int(lora_alpha),
579
+ lora_dropout=float(lora_dropout),
580
+ num_epochs=int(num_epochs),
581
  learning_rate=float(learning_rate),
582
+ batch_size=int(batch_size),
583
+ resolution=int(resolution)
584
  )
585
+
586
+ return f"✅ Treinamento REAL iniciado! ID do Job: {job_id}\n\n📊 Imagens: {len(files)}\n🏷️ Trigger Word: {trigger_word or 'Nenhuma'}\n\nUse o ID acima para verificar o progresso na aba 'Status do Treinamento'."
587
+
588
  except Exception as e:
589
+ return f"❌ Erro ao iniciar treinamento: {str(e)}"
590
 
591
  def check_status_wrapper(job_id):
592
+ """Wrapper para verificar status via Gradio."""
593
+ if not job_id.strip():
594
+ return "❌ Erro: Forneça um ID de job válido!"
595
+
596
  status = trainer.get_training_status(job_id.strip())
 
597
 
598
+ if "error" in status and status["error"] == "Job não encontrado":
599
+ return "❌ Job não encontrado! Verifique o ID."
600
+
601
+ # Criar indicador visual de status
602
+ status_class = f"status-{status['status']}"
603
+ status_emoji = {
604
+ 'queued': '⏳',
605
+ 'loading_model': '📥',
606
+ 'preparing_lora': '⚙️',
607
+ 'preparing_data': '📊',
608
+ 'training': '🏋️',
609
+ 'saving': '💾',
610
+ 'completed': '✅',
611
+ 'error': '❌'
612
+ }.get(status['status'], '📊')
613
+
614
+ # Barra de progresso visual
615
+ progress = status['progress']
616
+ progress_bar = f"""
617
+ <div style="width: 100%; background-color: #e5e7eb; border-radius: 4px; overflow: hidden; margin: 8px 0;">
618
+ <div style="width: {progress}%; height: 8px; background: linear-gradient(90deg, #3b82f6, #8b5cf6); transition: width 0.3s ease; border-radius: 4px;"></div>
619
+ </div>
620
+ """
621
+
622
+ status_text = f"""
623
+
624
+ 📊 Status do Treinamento LoRA
625
+
626
+ 🆔 Job ID: {status['id']}
627
+ {status_emoji} Status: <span class="{status_class}">{status['status'].upper().replace('_', ' ')}</span>
628
+ ⏳ Progresso: {status['progress']}%
629
+
630
+ {progress_bar}
631
+
632
+ 🤖 Modelo Base: {status['model_name']}
633
+ 🖼️ Imagens: {status.get('num_images', 'N/A')}
634
+ 📅 Criado em: {status['created_at']}
635
+
636
+ """
637
+
638
+ if status['logs']:
639
+ status_text += "📝 **Logs Recentes:**\n"
640
+ for log in status['logs'][-5:]: # Últimos 5 logs
641
+ status_text += f"• {log}\n"
642
+
643
+ if status['status'] == 'completed':
644
+ status_text += f"\n✅ **Treinamento Concluído!**\n📁 **Modelo salvo em:** {status['model_path']}"
645
+ status_text += f"\n⏰ **Concluído em:** {status['completed_at']}"
646
+ status_text += f"\n\n💡 **Próximos passos:** Vá para a aba 'Modelos Treinados' para baixar seu LoRA!"
647
+ elif status['status'] == 'error':
648
+ status_text += f"\n❌ **Erro:** {status['error']}"
649
 
 
 
 
 
650
  return status_text
651
 
652
  def list_models_wrapper():
653
+ """Wrapper para listar modelos via Gradio."""
654
  models = trainer.list_trained_models()
655
+
656
+ if not models:
657
+ return "📭 Nenhum modelo LoRA treinado encontrado."
658
+
659
+ models_text = "📚 **Modelos LoRA Treinados:**\n\n"
660
+ for model in models:
661
+ models_text += f"🆔 **ID:** {model['id']}\n"
662
+ models_text += f"🤖 **Modelo Base:** {model['base_model']}\n"
663
+ models_text += f"📊 **Rank (r):** {model['r']}\n"
664
+ models_text += f"📁 **Caminho:** {model['path']}\n"
665
+ models_text += f"📅 **Criado:** {model['created']}\n\n"
666
+ models_text += "---\n\n"
667
+
668
+ return models_text
669
 
670
  def download_model_wrapper(job_id):
671
+ """Wrapper para preparar download do modelo."""
672
+ if not job_id.strip():
673
+ return None, "❌ Erro: Forneça um ID de job válido!"
674
+
675
  status = trainer.get_training_status(job_id.strip())
676
+
677
+ if "error" in status and status["error"] == "Job não encontrado":
678
+ return None, "❌ Job não encontrado! Verifique o ID."
679
+
680
+ if status['status'] != 'completed':
681
+ return None, f"❌ Treinamento ainda não foi concluído. Status atual: {status['status']}"
682
+
683
  try:
684
+ model_path = status['model_path']
685
+ zip_path = trainer.create_download_zip(model_path)
686
+
687
+ return zip_path, f"✅ Arquivo ZIP criado com sucesso! Clique no link acima para baixar."
688
+
689
  except Exception as e:
690
+ return None, f"❌ Erro ao criar arquivo de download: {str(e)}"
691
 
692
+ # Interface Gradio
693
+ with gr.Blocks(
694
+ title="🎨 LoRA Image Trainer - Criador e Treinador de LoRA para Imagens",
695
+ theme=gr.themes.Soft(),
696
+ css=custom_css
697
+ ) as interface:
698
+
699
+ gr.HTML("""
700
+ <div class="lora-header">
701
+ <h1>🎨 LoRA Image Trainer</h1>
702
+ <p>Criador e Treinador de LoRA para Geração de Imagens</p>
703
+ <p style="font-size: 0.9em; opacity: 0.9; margin-top: 8px;">
704
+ Ferramenta otimizada para baixo uso de GPU, compatível com dispositivos móveis
705
+ </p>
706
+ </div>
707
+ """)
708
 
709
  with gr.Tabs():
710
+
711
+ # Aba de Treinamento
712
+ with gr.TabItem("🎯 Treinar LoRA"):
713
+ gr.Markdown("### Configurar e Iniciar Treinamento LoRA para Imagens")
714
+
715
  with gr.Row():
716
+ with gr.Column(scale=2):
717
+ model_dropdown = gr.Dropdown(
718
+ choices=trainer.get_available_models(),
719
+ value="runwayml/stable-diffusion-v1-5",
720
+ label="🤖 Modelo Base",
721
+ )
722
+
723
+ image_files = gr.File(
724
+ file_count="multiple",
725
+ file_types=["image"],
726
+ label="🖼️ Imagens de Treinamento",
727
+ )
728
+
729
+ trigger_word = gr.Textbox(
730
+ label="🏷️ Trigger Word (Opcional)",
731
+ placeholder="ex: meuEstilo, minhaPersonagem, etc.",
732
+ )
733
+
734
+ captions_text = gr.Textbox(
735
+ lines=8,
736
+ placeholder="Digite uma legenda por linha (opcional)...\n\nExemplo:\nmeuEstilo, retrato de uma mulher\nmeuEstilo, homem sorrindo\nmeuEstilo, paisagem urbana\n\nSe deixar vazio, usará a trigger word + 'high quality photo'",
737
+ label="📝 Legendas das Imagens (Opcional)",
738
+ )
739
+
740
+ with gr.Column(scale=1):
741
+ gr.Markdown("### ⚙️ Parâmetros LoRA")
742
+
743
+ r = gr.Slider(
744
+ minimum=4, maximum=64, value=8, step=4, # reduzido max para 64
745
+ label="r (Rank)",
746
+ )
747
+
748
+ lora_alpha = gr.Slider(
749
+ minimum=1, maximum=64, value=16, step=1, # reduzido max para 64
750
+ label="LoRA Alpha",
751
+ )
752
+
753
+ lora_dropout = gr.Slider(
754
+ minimum=0.0, maximum=0.5, value=0.0, step=0.05, # dropout 0 para mais estabilidade
755
+ label="LoRA Dropout",
756
+ )
757
+
758
+ gr.Markdown("### 🏋️ Parâmetros de Treinamento")
759
+
760
+ num_epochs = gr.Slider(
761
+ minimum=5, maximum=20, value=10, step=5, # reduzido max para 20
762
+ label="Épocas",
763
+ )
764
+
765
+ learning_rate = gr.Slider(
766
+ minimum=1e-5, maximum=5e-4, value=1e-4, step=1e-5, # reduzido max
767
+ label="Taxa de Aprendizado",
768
+ )
769
+
770
+ batch_size = gr.Slider(
771
+ minimum=1, maximum=1, value=1, step=1, # fixado em 1 para Spaces
772
+ label="Batch Size",
773
+ )
774
+
775
+ resolution = gr.Dropdown(
776
+ choices=[512], # fixado em 512 para garantir funcionamento em GPU limitada
777
+ value=512,
778
+ label="Resolução",
779
+ )
780
 
781
+ train_button = gr.Button("🚀 Iniciar Treinamento LoRA", variant="primary", size="lg")
782
+ train_output = gr.Textbox(label="📊 Resultado", lines=5)
783
+
784
+ train_button.click(
785
+ start_training_wrapper,
786
+ inputs=[model_dropdown, image_files, captions_text, trigger_word, r, lora_alpha, lora_dropout,
787
+ num_epochs, learning_rate, batch_size, resolution],
788
+ outputs=train_output
789
+ )
790
 
791
+ # Aba de Status
792
+ with gr.TabItem("📊 Status do Treinamento"):
793
+ gr.Markdown("### Verificar Progresso do Treinamento")
794
+
795
+ job_id_input = gr.Textbox(
796
+ label="🆔 ID do Job",
797
+ placeholder="Cole aqui o ID do job de treinamento...",
798
+ )
799
+
800
+ status_button = gr.Button("🔍 Verificar Status", variant="secondary")
801
+ status_output = gr.Textbox(label="📈 Status", lines=12)
802
+
803
+ status_button.click(
804
+ check_status_wrapper,
805
+ inputs=job_id_input,
806
+ outputs=status_output
807
+ )
808
+
809
+ gr.Markdown("💡 **Dica:** Atualize o status regularmente para acompanhar o progresso do treinamento.")
810
+
811
+ # Aba de Modelos e Download
812
+ with gr.TabItem("📚 Modelos e Download"):
813
+ gr.Markdown("### Visualizar e Baixar Modelos LoRA Treinados")
814
+
815
+ with gr.Row():
816
+ with gr.Column(scale=1):
817
+ list_button = gr.Button("📋 Listar Modelos", variant="secondary")
818
+ models_output = gr.Textbox(label="📚 Modelos Disponíveis", lines=10)
819
+
820
+ list_button.click(
821
+ list_models_wrapper,
822
+ outputs=models_output
823
+ )
824
+
825
+ with gr.Column(scale=1):
826
+ gr.Markdown("#### 💾 Download de Modelo")
827
+
828
+ download_job_id = gr.Textbox(
829
+ label="🆔 ID do Job para Download",
830
+ placeholder="Cole o ID do job concluído...",
831
+ )
832
+
833
+ download_button = gr.Button("📦 Preparar Download", variant="primary")
834
+ download_file = gr.File(label="📁 Arquivo para Download")
835
+ download_status = gr.Textbox(label="📊 Status do Download", lines=3)
836
+
837
+ download_button.click(
838
+ download_model_wrapper,
839
+ inputs=download_job_id,
840
+ outputs=[download_file, download_status]
841
+ )
842
+
843
+ # Aba de Informações
844
+ with gr.TabItem("ℹ️ Sobre"):
845
+ gr.Markdown("""
846
+ ### 🎯 Sobre o LoRA Image Trainer
847
+
848
+ Esta ferramenta foi desenvolvida para democratizar o acesso ao treinamento de modelos LoRA para geração de imagens,
849
+ permitindo que qualquer pessoa possa criar adaptações personalizadas de modelos de difusão (como Stable Diffusion)
850
+ sem a necessidade de hardware especializado.
851
+
852
+ #### ✨ Características Principais:
853
+
854
+ - **🔋 Otimizado para Baixa GPU**: Utiliza técnicas como mixed precision, gradient checkpointing e configurações otimizadas
855
+ - **📱 Compatível com Móveis**: Interface responsiva que funciona em smartphones e tablets
856
+ - **⚡ Rápido e Eficiente**: Treinamento otimizado com bibliotecas Diffusers e PEFT do Hugging Face
857
+ - **🎛️ Configurável**: Controle total sobre parâmetros LoRA e de treinamento
858
+ - **☁️ Pronto para Deploy**: Facilmente implantável no Hugging Face Spaces
859
+ - **🎨 Focado em Imagens**: Especificamente projetado para modelos de difusão e geração de imagens
860
+
861
+ #### 🛠️ Tecnologias Utilizadas:
862
+
863
+ - **Hugging Face Diffusers**: Para modelos de difusão e pipeline de treinamento
864
+ - **PEFT (Parameter-Efficient Fine-Tuning)**: Para treinamento eficiente de LoRA
865
+ - **PyTorch**: Framework de deep learning
866
+ - **Gradio**: Interface web interativa e responsiva
867
+ - **LoRA (Low-Rank Adaptation)**: Técnica de fine-tuning eficiente para modelos de difusão
868
+
869
+ #### 📖 Como Usar:
870
+
871
+ 1. **Prepare suas imagens**: Colete 3-50 imagens de alta qualidade do estilo/conceito que deseja treinar
872
+ 2. **Escolha um modelo base** na aba "Treinar LoRA" (recomendado: Stable Diffusion 1.5)
873
+ 3. **Faça upload das imagens** e defina uma trigger word (palavra-chave)
874
+ 4. **Configure os parâmetros** conforme necessário (valores padrão funcionam bem)
875
+ 5. **Inicie o treinamento** e anote o ID do job
876
+ 6. **Acompanhe o progresso** na aba "Status do Treinamento"
877
+ 7. **Baixe seu LoRA** na aba "Modelos e Download" quando concluído
878
+ 8. **Use em suas ferramentas favoritas** (ComfyUI, Automatic1111, etc.)
879
+
880
+ #### 💡 Dicas para Melhores Resultados:
881
+
882
+ - **Qualidade > Quantidade**: 10-20 imagens de alta qualidade são melhores que 50 imagens ruins
883
+ - **Consistência**: Use imagens com estilo/conceito consistente
884
+ - **Resolução**: Para GPUs com pouca VRAM, use resolução 512x512
885
+ - **Trigger Word**: Escolha uma palavra única e fácil de lembrar
886
+ - **Legendas**: Descreva o que há nas imagens para melhor controle
887
+ - **Parâmetros**: Para iniciantes, use os valores padrão
888
+
889
+ #### 🎮 Compatibilidade:
890
+
891
+ Os LoRAs gerados são compatíveis com:
892
+ - **ComfyUI**: Carregue os arquivos .safetensors
893
+ - **Automatic1111**: Coloque na pasta models/Lora
894
+ - **SeaArt**: Faça upload do modelo
895
+ - **Outras ferramentas**: Qualquer ferramenta que suporte LoRA para Stable Diffusion
896
+
897
+ ---
898
+
899
+ **Desenvolvido com ❤️ para a comunidade de IA e arte digital**
900
+ """)
901
+
902
+ # Footer
903
+ gr.Markdown("""
904
+ ---
905
+ <div style="text-align: center; color: #666; font-size: 0.9em;">
906
+ 🎨 LoRA Image Trainer v1.0 | Otimizado para Baixa GPU | Compatível com Dispositivos Móveis
907
+ </div>
908
+ """)
909
 
910
  return interface
911
 
912
+ # Criar e configurar interface
913
  if __name__ == "__main__":
914
+ # Criar diretórios necessários
915
+ os.makedirs("./lora_models", exist_ok=True)
916
+
917
+ # Configurar interface
918
  interface = create_gradio_interface()
919
+
920
+ # Lançar aplicação
921
  interface.launch(
922
  server_name="0.0.0.0",
923
  server_port=7860,
924
+ share=False,
925
  show_error=True,
926
  quiet=False
927
  )