File size: 6,278 Bytes
2a37b4f 7865b76 2a37b4f 06d0e1e 7865b76 2a37b4f 7865b76 dc3cfdb 7865b76 dc3cfdb 06d0e1e dc3cfdb 06d0e1e 2a37b4f dc3cfdb 2a37b4f 06d0e1e 2a37b4f 7865b76 2a37b4f 7865b76 dc3cfdb 2a37b4f dc3cfdb 2a37b4f dc3cfdb 2a37b4f dc3cfdb 2a37b4f dc3cfdb 2a37b4f dc3cfdb 2a37b4f dc3cfdb 2a37b4f dc3cfdb 2a37b4f dc3cfdb 2a37b4f dc3cfdb 2a37b4f dc3cfdb 2a37b4f 7865b76 2a37b4f dc3cfdb 2a37b4f dc3cfdb 7865b76 2a37b4f dc3cfdb 06d0e1e dc3cfdb 2a37b4f 06d0e1e 2a37b4f dc3cfdb 2a37b4f dc3cfdb 2a37b4f dc3cfdb 2a37b4f dc3cfdb 2a37b4f dc3cfdb 2a37b4f dc3cfdb 2a37b4f dc3cfdb 2a37b4f dc3cfdb 2a37b4f dc3cfdb 7865b76 dc3cfdb 7865b76 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 |
# train_lora.py
import os
import torch
import argparse
from diffusers import StableDiffusionPipeline, DDPMScheduler
from peft import LoraConfig, get_peft_model
from accelerate import Accelerator
from torchvision import transforms
from PIL import Image
import glob
# Glob patterns (matched directly inside the dataset directory) that count as images.
_IMAGE_PATTERNS = ("*.jpg", "*.jpeg", "*.png", "*.bmp", "*.webp")

# Caption used for an image that has no sibling ``.txt`` file.
_DEFAULT_CAPTION = "person"


def _find_images(dataset_dir):
    """Return the paths of all image files located directly in *dataset_dir*."""
    image_paths = []
    for pattern in _IMAGE_PATTERNS:
        image_paths.extend(glob.glob(os.path.join(dataset_dir, pattern)))
    return image_paths


def _load_captions(image_paths):
    """Return one caption per image, read from the sibling ``.txt`` file if present."""
    captions = []
    for img_path in image_paths:
        txt_path = os.path.splitext(img_path)[0] + ".txt"
        if os.path.exists(txt_path):
            with open(txt_path, "r", encoding="utf-8") as f:
                captions.append(f.read().strip())
        else:
            captions.append(_DEFAULT_CAPTION)
    return captions


def main(args: argparse.Namespace) -> None:
    """Train LoRA adapters on a Stable Diffusion UNet and save them to disk.

    Args:
        args: Parsed command-line options. The attributes read are
            ``model_name``, ``dataset_dir``, ``output_dir``, ``lora_rank``,
            ``lora_alpha``, ``learning_rate``, ``num_epochs``, ``batch_size``
            and ``mixed_precision``.

    Returns:
        None. The function returns early (after printing a message) when the
        dataset directory contains no images or the base model fails to load.

    Raises:
        ValueError: If the scheduler's ``prediction_type`` is neither
            ``"epsilon"`` nor ``"v_prediction"``.
    """
    # Validate the dataset first so an empty/wrong directory fails fast,
    # before any model is downloaded or loaded.
    image_paths = _find_images(args.dataset_dir)
    if not image_paths:
        print("❌ Nenhuma imagem encontrada no diretório!")
        return
    print(f"✅ {len(image_paths)} imagens encontradas. Carregando legendas...")
    captions = _load_captions(image_paths)

    # Initialise the Accelerator.
    accelerator = Accelerator(
        mixed_precision="fp16" if args.mixed_precision else None
    )

    # Dtype for the frozen modules. The trainable UNet stays in float32:
    # the fp16 gradient scaler cannot unscale gradients of fp16 parameters.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    print(f"🚀 Carregando modelo: {args.model_name}")
    try:
        pipe = StableDiffusionPipeline.from_pretrained(
            args.model_name,
            torch_dtype=torch.float32,
        )
    except Exception as e:
        print(f"❌ Falha ao carregar modelo: {e}")
        return

    # Extract the pipeline components.
    unet = pipe.unet
    tokenizer = pipe.tokenizer
    text_encoder = pipe.text_encoder
    vae = pipe.vae
    noise_scheduler = DDPMScheduler.from_config(pipe.scheduler.config)

    # Only the LoRA weights are trained: freeze the VAE and text encoder and
    # place them on the training device (they are not passed to `prepare`).
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    vae.to(accelerator.device, dtype=weight_dtype)
    text_encoder.to(accelerator.device, dtype=weight_dtype)
    vae.eval()
    text_encoder.eval()

    # Configure LoRA on the attention projections.
    lora_config = LoraConfig(
        r=args.lora_rank,
        lora_alpha=args.lora_alpha,
        target_modules=["to_q", "to_v", "to_k", "to_out.0"],
        lora_dropout=0.0,
        bias="none",
    )
    unet = get_peft_model(unet, lora_config)
    unet.print_trainable_parameters()  # Shows the share of trainable parameters.

    # Image transforms: 512x512 centre crop, scaled to [-1, 1].
    transform = transforms.Compose([
        transforms.Resize(512),
        transforms.CenterCrop(512),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    class SimpleDataset(torch.utils.data.Dataset):
        """Pairs each image tensor with its raw caption string."""

        def __init__(self, image_list, caption_list, transform):
            self.images = image_list
            self.captions = caption_list
            self.transform = transform

        def __len__(self):
            return len(self.images)

        def __getitem__(self, idx):
            # The context manager closes the file handle once pixels are loaded.
            with Image.open(self.images[idx]) as img:
                image = self.transform(img.convert("RGB"))
            return {"pixel_values": image, "caption": self.captions[idx]}

    dataset = SimpleDataset(image_paths, captions, transform)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
    )

    # Optimise only the parameters that actually require gradients (LoRA).
    trainable_params = [p for p in unet.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=args.learning_rate)

    unet, optimizer, dataloader = accelerator.prepare(unet, optimizer, dataloader)

    # Read model-specific constants from the configs, falling back to the
    # Stable Diffusion v1 values the script previously hard-coded.
    scaling_factor = getattr(vae.config, "scaling_factor", 0.18215)
    prediction_type = getattr(noise_scheduler.config, "prediction_type", "epsilon")
    num_train_timesteps = noise_scheduler.config.num_train_timesteps

    # Training loop.
    unet.train()
    for epoch in range(args.num_epochs):
        last_loss = None
        for batch in dataloader:
            with accelerator.accumulate(unet):
                # The frozen encoders need no autograd graph.
                with torch.no_grad():
                    pixel_values = batch["pixel_values"].to(
                        accelerator.device, dtype=weight_dtype
                    )
                    latents = vae.encode(pixel_values).latent_dist.sample()
                    latents = latents * scaling_factor

                    inputs = tokenizer(
                        list(batch["caption"]),
                        max_length=tokenizer.model_max_length,
                        padding="max_length",
                        truncation=True,
                        return_tensors="pt",
                    ).to(accelerator.device)
                    encoder_hidden_states = text_encoder(**inputs)[0]

                # Add noise at a random timestep per sample.
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                timesteps = torch.randint(
                    0, num_train_timesteps, (bsz,), device=latents.device
                ).long()
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # The regression target depends on how the model was trained.
                if prediction_type == "epsilon":
                    target = noise
                elif prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(
                        f"Unsupported prediction type: {prediction_type}"
                    )

                # Predict and compute the loss in float32 for numerical stability.
                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float())

                # Backpropagation.
                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()
                last_loss = loss.item()

        if last_loss is not None:
            print(f"Epoch {epoch+1}/{args.num_epochs} - Loss: {last_loss:.4f}")

    # Save the LoRA adapter (main process only).
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        output_dir = args.output_dir
        unwrapped_unet = accelerator.unwrap_model(unet)
        # `save_pretrained` writes the adapter itself; request safetensors
        # explicitly instead of importing a separate save helper from peft.
        unwrapped_unet.save_pretrained(output_dir, safe_serialization=True)
        print(f"✅ Modelo LoRA salvo em: {output_dir}")
if __name__ == "__main__":
    # Command-line entry point: parse the training options and run `main`.
    parser = argparse.ArgumentParser(description="Treina um modelo LoRA para Stable Diffusion")
    parser.add_argument("--model_name", type=str, default="runwayml/stable-diffusion-v1-5", help="Modelo base do HF")
    parser.add_argument("--dataset_dir", type=str, required=True, help="Pasta com imagens e .txt")
    parser.add_argument("--output_dir", type=str, default="lora-output", help="Onde salvar o LoRA")
    parser.add_argument("--lora_rank", type=int, default=4, help="Rank LoRA (4-64)")
    parser.add_argument("--lora_alpha", type=int, default=32, help="Alpha LoRA")
    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Taxa de aprendizado")
    parser.add_argument("--num_epochs", type=int, default=10, help="Número de épocas")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
    parser.add_argument("--mixed_precision", action="store_true", help="Usa FP16")
    args = parser.parse_args()
    # The original final line ended with a stray `|`, which is a syntax error.
    main(args)