# train_lora.py
#
# Fine-tunes a LoRA adapter on the UNet of a Stable Diffusion model from a
# folder of images, each optionally paired with a sibling ".txt" caption file.
import argparse
import glob
import os

import torch
from accelerate import Accelerator
from diffusers import DDPMScheduler, StableDiffusionPipeline
from peft import LoraConfig, get_peft_model
from PIL import Image
from torchvision import transforms


class SimpleDataset(torch.utils.data.Dataset):
    """Pairs preprocessed images with raw caption strings.

    Captions are returned untokenized under the "input_ids" key; the
    training loop tokenizes them per batch.
    """

    def __init__(self, image_list, caption_list, transform):
        self.images = image_list
        self.captions = caption_list
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = Image.open(self.images[idx]).convert("RGB")
        image = self.transform(image)
        return {"pixel_values": image, "input_ids": self.captions[idx]}


def main(args):
    """Train a LoRA adapter for the Stable Diffusion UNet.

    Args:
        args: parsed CLI namespace (see the argparse definition below).

    Side effects: reads images/captions from ``args.dataset_dir`` and writes
    the trained adapter (``save_pretrained``) to ``args.output_dir``.
    """
    accelerator = Accelerator(
        mixed_precision="fp16" if args.mixed_precision else "no"
    )

    print(f"🚀 Carregando modelo: {args.model_name}")
    try:
        # Load in fp32 even under mixed precision: the UNet is the trainable
        # module, and training fp16 master weights with AdamW is numerically
        # unstable. Frozen components are cast to fp16 further below; Accelerate
        # handles autocast for the forward pass.
        pipe = StableDiffusionPipeline.from_pretrained(args.model_name)
    except Exception as e:
        print(f"❌ Falha ao carregar modelo: {e}")
        return

    # Extract components
    unet = pipe.unet
    tokenizer = pipe.tokenizer
    text_encoder = pipe.text_encoder
    vae = pipe.vae
    noise_scheduler = DDPMScheduler.from_config(pipe.scheduler.config)

    # Freeze everything except the LoRA layers injected into the UNet, so no
    # gradients are computed for the VAE or the text encoder.
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    # Configure LoRA on the attention projections of the UNet
    lora_config = LoraConfig(
        r=args.lora_rank,
        lora_alpha=args.lora_alpha,
        target_modules=["to_q", "to_v", "to_k", "to_out.0"],
        lora_dropout=0.0,
        bias="none",
    )
    unet = get_peft_model(unet, lora_config)
    unet.print_trainable_parameters()  # shows % of trainable parameters

    # Image preprocessing: 512x512 center crop, scaled to [-1, 1] as the VAE expects
    transform = transforms.Compose([
        transforms.Resize(512),
        transforms.CenterCrop(512),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    # === Load dataset ===
    image_paths = []
    for ext in ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.webp"]:
        image_paths.extend(glob.glob(os.path.join(args.dataset_dir, ext)))

    if len(image_paths) == 0:
        print("❌ Nenhuma imagem encontrada no diretório!")
        return

    print(f"✅ {len(image_paths)} imagens encontradas. Carregando legendas...")
    captions = []
    valid_images = []
    for img_path in image_paths:
        # Caption lives in a sibling file with the same stem and a .txt suffix
        txt_path = os.path.splitext(img_path)[0] + ".txt"
        if os.path.exists(txt_path):
            with open(txt_path, "r", encoding="utf-8") as f:
                caption = f.read().strip()
        else:
            caption = "person"  # fallback caption when no .txt file exists
        captions.append(caption)
        valid_images.append(img_path)

    dataset = SimpleDataset(valid_images, captions, transform)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size, shuffle=True
    )

    # Optimizer over the trainable (LoRA) parameters only
    optimizer = torch.optim.AdamW(
        (p for p in unet.parameters() if p.requires_grad),
        lr=args.learning_rate,
    )

    # Prepare trainables with Accelerate (moves them to the training device)
    unet, optimizer, dataloader = accelerator.prepare(unet, optimizer, dataloader)

    # Frozen components are NOT handled by accelerator.prepare — move them to
    # the training device explicitly, in fp16 when mixed precision is enabled.
    weight_dtype = torch.float16 if args.mixed_precision else torch.float32
    vae.to(accelerator.device, dtype=weight_dtype)
    text_encoder.to(accelerator.device, dtype=weight_dtype)
    vae.eval()
    text_encoder.eval()

    # Training loop: standard epsilon-prediction denoising objective
    unet.train()
    step = 0
    for epoch in range(args.num_epochs):
        for batch in dataloader:
            with accelerator.accumulate(unet):
                # Encode images to latents; 0.18215 is SD's latent scale factor
                pixel_values = batch["pixel_values"].to(
                    device=accelerator.device, dtype=weight_dtype
                )
                with torch.no_grad():
                    latents = vae.encode(pixel_values).latent_dist.sample() * 0.18215

                # Add noise at a random timestep per sample
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                timesteps = torch.randint(
                    0,
                    noise_scheduler.config.num_train_timesteps,
                    (bsz,),
                    device=latents.device,
                ).long()
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Encode text (77 is the CLIP text encoder's max length)
                inputs = tokenizer(
                    batch["input_ids"],
                    max_length=77,
                    padding="max_length",
                    truncation=True,
                    return_tensors="pt",
                ).to(latents.device)
                with torch.no_grad():
                    encoder_hidden_states = text_encoder(**inputs)[0]

                # Predict the added noise and regress it (epsilon objective;
                # loss computed in fp32 for numerical stability)
                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
                loss = torch.nn.functional.mse_loss(
                    noise_pred.float(), noise.float()
                )

                # Backpropagation
                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()
                step += 1

        print(f"Epoch {epoch+1}/{args.num_epochs} - Loss: {loss.item():.4f}")

    # Save the LoRA adapter (adapter weights + adapter_config.json).
    # NOTE: the previous `from peft import save_model` was removed — peft
    # exports no such function; save_pretrained is the supported API.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        unwrapped_unet = accelerator.unwrap_model(unet)
        unwrapped_unet.save_pretrained(args.output_dir)
        print(f"✅ Modelo LoRA salvo em: {args.output_dir}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Treina um modelo LoRA para Stable Diffusion")
    parser.add_argument("--model_name", type=str, default="runwayml/stable-diffusion-v1-5", help="Modelo base do HF")
    parser.add_argument("--dataset_dir", type=str, required=True, help="Pasta com imagens e .txt")
    parser.add_argument("--output_dir", type=str, default="lora-output", help="Onde salvar o LoRA")
    parser.add_argument("--lora_rank", type=int, default=4, help="Rank LoRA (4-64)")
    parser.add_argument("--lora_alpha", type=int, default=32, help="Alpha LoRA")
    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Taxa de aprendizado")
    parser.add_argument("--num_epochs", type=int, default=10, help="Número de épocas")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
    parser.add_argument("--mixed_precision", action="store_true", help="Usa FP16")
    args = parser.parse_args()
    main(args)