File size: 2,505 Bytes
f3b067f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import argparse
import os
import torch
from diffusers import StableDiffusionPipeline
from peft import LoraConfig, get_peft_model
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms

# Custom dataset: a flat folder of images.
class ImageDataset(Dataset):
    """Folder of images -> resized float tensors in [0, 1], shape (3, size, size)."""

    # Accepted file extensions, checked case-insensitively.
    _EXTENSIONS = (".png", ".jpg", ".jpeg")

    def __init__(self, folder, size=512):
        # Sort for a deterministic sample order across runs/platforms
        # (os.listdir order is unspecified); lower-case the name so
        # .PNG / .JPG files are not silently skipped.
        self.files = sorted(
            os.path.join(folder, f)
            for f in os.listdir(folder)
            if f.lower().endswith(self._EXTENSIONS)
        )
        self.transform = transforms.Compose([
            transforms.Resize((size, size)),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        # Force RGB so grayscale/RGBA inputs yield a consistent 3 channels.
        img = Image.open(self.files[idx]).convert("RGB")
        return self.transform(img)

def main(args):
    """Fine-tune Stable Diffusion v1.5 on a folder of images using LoRA.

    Args:
        args: argparse namespace with ``images_dir``, ``output_dir``,
            ``learning_rate``, ``num_epochs`` and ``rank``.
    """
    # Load the base pipeline in fp32: training fp16 master weights with
    # AdamW underflows gradients and diverges.
    model_id = "runwayml/stable-diffusion-v1-5"
    pipe = StableDiffusionPipeline.from_pretrained(model_id).to("cuda")

    # Freeze the base model; only the LoRA adapters injected below train.
    pipe.vae.requires_grad_(False)
    pipe.text_encoder.requires_grad_(False)
    pipe.unet.requires_grad_(False)

    # LoRA configuration. No task_type: the UNet is not a causal language
    # model — "CAUSAL_LM" makes PEFT apply LM-specific wrapping.
    lora_config = LoraConfig(
        r=args.rank,
        lora_alpha=16,
        target_modules=["to_q", "to_v"],
        lora_dropout=0.1,
        bias="none",
    )
    pipe.unet = get_peft_model(pipe.unet, lora_config)

    # Dataset
    dataset = ImageDataset(args.images_dir)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

    # Optimize only the trainable (LoRA) parameters, not the whole UNet.
    optimizer = torch.optim.AdamW(
        [p for p in pipe.unet.parameters() if p.requires_grad],
        lr=args.learning_rate,
    )

    # Pre-compute the text conditioning once: unconditional (empty prompt).
    tokens = pipe.tokenizer(
        "",
        padding="max_length",
        max_length=pipe.tokenizer.model_max_length,
        return_tensors="pt",
    ).input_ids.to("cuda")
    with torch.no_grad():
        encoder_hidden_states = pipe.text_encoder(tokens)[0]

    # NOTE(review): assumes the pipeline's scheduler implements add_noise
    # (true for the SD v1.5 default) — confirm if the scheduler is swapped.
    scheduler = pipe.scheduler

    # Standard denoising-diffusion training: encode images to latents, add
    # noise at a random timestep, regress the UNet's noise prediction.
    # (The original loop called pipe.unet(batch, noise)["loss"], which does
    # not match UNet2DConditionModel's interface and fed raw pixels where
    # latents belong.)
    for epoch in range(args.num_epochs):
        for batch in dataloader:
            batch = batch.to("cuda")
            # ToTensor() yields [0, 1]; the VAE expects [-1, 1].
            pixels = batch * 2.0 - 1.0
            with torch.no_grad():
                latents = pipe.vae.encode(pixels).latent_dist.sample()
                latents = latents * pipe.vae.config.scaling_factor

            noise = torch.randn_like(latents)
            timesteps = torch.randint(
                0,
                scheduler.config.num_train_timesteps,
                (latents.shape[0],),
                device=latents.device,
            ).long()
            noisy_latents = scheduler.add_noise(latents, noise, timesteps)

            optimizer.zero_grad()
            noise_pred = pipe.unet(
                noisy_latents,
                timesteps,
                encoder_hidden_states=encoder_hidden_states.expand(
                    latents.shape[0], -1, -1
                ),
            ).sample
            loss = torch.nn.functional.mse_loss(noise_pred, noise)
            loss.backward()
            optimizer.step()
        print(f"✅ Epoch {epoch+1}/{args.num_epochs} finalizado.")

    # Save only the LoRA adapter weights in PEFT's standard layout.
    # (torch.save of the full state_dict to a ".safetensors" path wrote a
    # pickle, not safetensors, and included the frozen base weights.)
    os.makedirs(args.output_dir, exist_ok=True)
    pipe.unet.save_pretrained(args.output_dir)
    print(f"✅ Treinamento concluído. Adapter LoRA salvo em {args.output_dir}")

if __name__ == "__main__":
    # CLI: required data/output paths plus training hyperparameters.
    cli = argparse.ArgumentParser()
    for flag, opts in (
        ("--images_dir", {"type": str, "required": True}),
        ("--output_dir", {"type": str, "required": True}),
        ("--learning_rate", {"type": float, "default": 1e-4}),
        ("--num_epochs", {"type": int, "default": 10}),
        ("--rank", {"type": int, "default": 4}),
    ):
        cli.add_argument(flag, **opts)
    main(cli.parse_args())