"""Fine-tune Stable Diffusion v1.5 with a LoRA adapter on a folder of images.

The script attaches LoRA layers to the UNet attention projections, trains them
with the standard noise-prediction objective and writes the adapter weights to
``<output_dir>/lora.safetensors``.
"""

import argparse
import os

import torch
import torch.nn.functional as F
from diffusers import DDPMScheduler, StableDiffusionPipeline
from diffusers.utils import convert_state_dict_to_diffusers
from peft import LoraConfig, get_peft_model  # noqa: F401  (import kept for backward compatibility)
from peft.utils import get_peft_model_state_dict
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

MODEL_ID = "runwayml/stable-diffusion-v1-5"
LORA_WEIGHT_NAME = "lora.safetensors"
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")


class ImageDataset(Dataset):
    """Dataset of the PNG/JPEG images found directly inside a folder.

    Args:
        folder: Directory that is listed (non-recursively) for image files.
        size: Every image is resized to ``size`` x ``size`` pixels.

    Each item is the tensor produced by ``transforms.ToTensor`` for the
    RGB-converted, resized image (values in ``[0, 1]``).
    """

    def __init__(self, folder, size=512):
        # Sort for a reproducible order and match extensions case-insensitively
        # so files such as "IMG_001.JPG" are not silently skipped.
        self.files = sorted(
            os.path.join(folder, name)
            for name in os.listdir(folder)
            if name.lower().endswith(IMAGE_EXTENSIONS)
        )
        self.transform = transforms.Compose(
            [
                transforms.Resize((size, size)),
                transforms.ToTensor(),
            ]
        )

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        # The context manager closes the file handle once the pixels are read.
        with Image.open(self.files[idx]) as img:
            return self.transform(img.convert("RGB"))


def main(args):
    """Train a LoRA adapter for the Stable Diffusion UNet and save it.

    Args:
        args: Namespace with ``images_dir``, ``output_dir``, ``learning_rate``,
            ``num_epochs`` and ``rank``. An optional ``prompt`` attribute
            supplies the text conditioning used for every image (defaults to
            the empty prompt, because the dataset carries no captions).

    Raises:
        ValueError: If ``images_dir`` contains no supported image.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    prompt = getattr(args, "prompt", "")

    # Dataset: checked first so an empty folder fails before the model loads.
    dataset = ImageDataset(args.images_dir)
    if len(dataset) == 0:
        raise ValueError(
            f"Nenhuma imagem .png/.jpg/.jpeg encontrada em {args.images_dir!r}"
        )
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

    # Load the base model in fp32: optimising fp16 weights directly with AdamW
    # is numerically unstable.
    pipe = StableDiffusionPipeline.from_pretrained(
        MODEL_ID, torch_dtype=torch.float32
    ).to(device)
    vae = pipe.vae
    text_encoder = pipe.text_encoder
    tokenizer = pipe.tokenizer
    unet = pipe.unet
    # Training needs a scheduler that exposes add_noise(); reuse the betas and
    # timestep count from the pipeline's own scheduler config.
    noise_scheduler = DDPMScheduler.from_config(pipe.scheduler.config)

    # Only the LoRA layers are trained; everything else stays frozen.
    for frozen_module in (vae, text_encoder, unet):
        frozen_module.requires_grad_(False)

    # LoRA configuration. No task_type: the UNet is not a causal language
    # model, so the PEFT task-specific wrappers do not apply.
    lora_config = LoraConfig(
        r=args.rank,
        lora_alpha=16,
        target_modules=["to_q", "to_v"],
        lora_dropout=0.1,
        bias="none",
    )
    unet.add_adapter(lora_config)
    unet.train()  # enables the LoRA dropout

    # Optimizer over the trainable (LoRA) parameters only.
    lora_parameters = [p for p in unet.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(lora_parameters, lr=args.learning_rate)

    # The prompt is identical for every image, so encode it once.
    with torch.no_grad():
        token_ids = tokenizer(
            prompt,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        ).input_ids.to(device)
        prompt_embeds = text_encoder(token_ids)[0]

    # Training loop
    for epoch in range(args.num_epochs):
        for batch in dataloader:
            # ToTensor yields [0, 1]; the VAE expects inputs in [-1, 1].
            pixel_values = batch.to(device) * 2.0 - 1.0

            with torch.no_grad():
                latents = vae.encode(pixel_values).latent_dist.sample()
                latents = latents * vae.config.scaling_factor

            batch_size = latents.shape[0]
            noise = torch.randn_like(latents)
            timesteps = torch.randint(
                0,
                noise_scheduler.config.num_train_timesteps,
                (batch_size,),
                device=device,
            ).long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Stable Diffusion v1.5 is an epsilon-prediction model: the UNet
            # is trained to recover the noise that was added.
            noise_pred = unet(
                noisy_latents,
                timesteps,
                encoder_hidden_states=prompt_embeds.expand(batch_size, -1, -1),
            ).sample
            loss = F.mse_loss(noise_pred.float(), noise.float())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f"✅ Epoch {epoch+1}/{args.num_epochs} finalizado.")

    # Save only the LoRA weights, as a real safetensors file that
    # `pipe.load_lora_weights(output_dir, weight_name="lora.safetensors")`
    # can read back.
    os.makedirs(args.output_dir, exist_ok=True)
    lora_state_dict = convert_state_dict_to_diffusers(
        get_peft_model_state_dict(unet)
    )
    StableDiffusionPipeline.save_lora_weights(
        save_directory=args.output_dir,
        unet_lora_layers=lora_state_dict,
        weight_name=LORA_WEIGHT_NAME,
        safe_serialization=True,
    )
    print("✅ Treinamento concluído. Arquivo salvo em lora.safetensors")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--images_dir", type=str, required=True)
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--num_epochs", type=int, default=10)
    parser.add_argument("--rank", type=int, default=4)
    parser.add_argument(
        "--prompt",
        type=str,
        default="",
        help="Text prompt used as conditioning for every training image.",
    )
    args = parser.parse_args()
    main(args)