# train_lora.py
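"""Minimal LoRA fine-tuning script for Stable Diffusion.

Expects --dataset_dir to contain training images (.jpg/.jpeg/.png/.bmp/.webp);
a caption is read from a .txt file with the same base name when present,
otherwise the placeholder caption "person" is used.
"""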
import os
import torch
import argparse
from diffusers import StableDiffusionPipeline, DDPMScheduler
from peft import LoraConfig, get_peft_model
from accelerate import Accelerator
from torchvision import transforms
from PIL import Image
import glob


def main(args):
    # Initialize the Accelerator
    accelerator = Accelerator(
        mixed_precision="fp16" if args.mixed_precision else None
    )

    print(f"🚀 Carregando modelo: {args.model_name}")
    try:
        pipe = StableDiffusionPipeline.from_pretrained(
            args.model_name,
            torch_dtype=torch.float16 if args.mixed_precision else torch.float32
        )
    except Exception as e:
        print(f"❌ Falha ao carregar modelo: {e}")
        return

    # Extract pipeline components
    unet = pipe.unet
    tokenizer = pipe.tokenizer
    text_encoder = pipe.text_encoder
    vae = pipe.vae
    noise_scheduler = DDPMScheduler.from_config(pipe.scheduler.config)

    # Configure LoRA on the UNet attention projections
    lora_config = LoraConfig(
        r=args.lora_rank,
        lora_alpha=args.lora_alpha,
        target_modules=["to_q", "to_v", "to_k", "to_out.0"],
        lora_dropout=0.0,
        bias="none"
    )
    unet = get_peft_model(unet, lora_config)
    unet.print_trainable_parameters()  # Shows the number and percentage of trainable parameters

    # Image transforms
    transform = transforms.Compose([
        transforms.Resize(512),
        transforms.CenterCrop(512),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    # === Load the dataset ===
    image_paths = []
    for ext in ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.webp"]:
        image_paths.extend(glob.glob(os.path.join(args.dataset_dir, ext)))
    
    if len(image_paths) == 0:
        print("❌ Nenhuma imagem encontrada no diretório!")
        return

    print(f"✅ {len(image_paths)} imagens encontradas. Carregando legendas...")

    captions = []
    valid_images = []
    for img_path in image_paths:
        txt_path = os.path.splitext(img_path)[0] + ".txt"
        if os.path.exists(txt_path):
            with open(txt_path, "r", encoding="utf-8") as f:
                caption = f.read().strip()
        else:
            caption = "person"
        captions.append(caption)
        valid_images.append(img_path)

    # PyTorch dataset
    class SimpleDataset(torch.utils.data.Dataset):
        def __init__(self, image_list, caption_list, transform):
            self.images = image_list
            self.captions = caption_list
            self.transform = transform

        def __len__(self):
            return len(self.images)

        def __getitem__(self, idx):
            image = Image.open(self.images[idx]).convert("RGB")
            image = self.transform(image)
            return {"pixel_values": image, "input_ids": self.captions[idx]}

    dataset = SimpleDataset(valid_images, captions, transform)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True
    )

    # Optimizer over the trainable (LoRA) parameters only
    optimizer = torch.optim.AdamW(
        [p for p in unet.parameters() if p.requires_grad], lr=args.learning_rate
    )

    # Prepare with Accelerator
    unet, optimizer, dataloader = accelerator.prepare(unet, optimizer, dataloader)

    # Freeze the VAE and text encoder and move them to the device (only the UNet is trained)
    weight_dtype = torch.float16 if args.mixed_precision else torch.float32
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    vae.to(accelerator.device, dtype=weight_dtype)
    text_encoder.to(accelerator.device, dtype=weight_dtype)
    vae.eval()
    text_encoder.eval()

    # Training loop
    unet.train()
    step = 0
    for epoch in range(args.num_epochs):
        for batch in dataloader:
            with accelerator.accumulate(unet):
                # Encode images into VAE latents (cast to the VAE dtype for fp16 training)
                pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
                latents = vae.encode(pixel_values).latent_dist.sample() * 0.18215

                # Add noise
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Encode the captions
                inputs = tokenizer(
                    batch["input_ids"],
                    max_length=77,
                    padding="max_length",
                    truncation=True,
                    return_tensors="pt"
                ).to(latents.device)
                encoder_hidden_states = text_encoder(**inputs)[0]

                # Predict the noise residual
                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
                loss = torch.nn.functional.mse_loss(noise_pred, noise)

                # Backpropagation
                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()
                step += 1

        print(f"Epoch {epoch+1}/{args.num_epochs} - Loss: {loss.item():.4f}")

    # Save the LoRA adapter
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        output_dir = args.output_dir
        unwrapped_unet = accelerator.unwrap_model(unet)
        unwrapped_unet.save_pretrained(output_dir)
        print(f"✅ Modelo LoRA salvo em: {output_dir}")

        # Note: save_pretrained already writes the adapter weights (as
        # adapter_model.safetensors with recent peft versions), so no extra export step is needed.


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train a LoRA model for Stable Diffusion")
    parser.add_argument("--model_name", type=str, default="runwayml/stable-diffusion-v1-5", help="Base model on the Hugging Face Hub")
    parser.add_argument("--dataset_dir", type=str, required=True, help="Folder with images and .txt captions")
    parser.add_argument("--output_dir", type=str, default="lora-output", help="Where to save the LoRA")
    parser.add_argument("--lora_rank", type=int, default=4, help="LoRA rank (4-64)")
    parser.add_argument("--lora_alpha", type=int, default=32, help="LoRA alpha")
    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate")
    parser.add_argument("--num_epochs", type=int, default=10, help="Number of epochs")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
    parser.add_argument("--mixed_precision", action="store_true", help="Use FP16 mixed precision")

    args = parser.parse_args()
    main(args)
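
# Example usage (the paths below are placeholders for your own data):
#   accelerate launch train_lora.py --dataset_dir ./my_images --output_dir ./lora-output \
#       --num_epochs 10 --batch_size 1 --mixed_precision
#
# To reuse the saved adapter for inference, one option (a sketch, assuming the adapter
# was saved with save_pretrained as above) is to wrap the pipeline's UNet with peft:
#   from peft import PeftModel
#   pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
#   pipe.unet = PeftModel.from_pretrained(pipe.unet, "lora-output")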