# Trabre / train_lora.py
# Last commit by Allex21: "Update train_lora.py" (dc3cfdb, verified)
# train_lora.py
import os
import torch
import argparse
from diffusers import StableDiffusionPipeline, DDPMScheduler
from peft import LoraConfig, get_peft_model
from accelerate import Accelerator
from torchvision import transforms
from PIL import Image
import glob
def _find_images(dataset_dir):
    """Return the sorted image files directly inside ``dataset_dir``.

    Extensions ``.jpg/.jpeg/.png/.bmp/.webp`` are matched case-insensitively.
    A missing directory simply yields an empty list.
    """
    extensions = (".jpg", ".jpeg", ".png", ".bmp", ".webp")
    # glob.escape keeps characters such as "[" in the folder name literal.
    candidates = glob.glob(os.path.join(glob.escape(dataset_dir), "*"))
    return sorted(
        path
        for path in candidates
        if path.lower().endswith(extensions) and os.path.isfile(path)
    )


def _read_caption(img_path, default_caption="person"):
    """Return the caption for ``img_path``.

    The caption is read from the ``.txt`` file with the same stem
    (``foo.jpg`` -> ``foo.txt``) and stripped. ``default_caption`` is used
    when that file does not exist.
    """
    txt_path = os.path.splitext(img_path)[0] + ".txt"
    try:
        with open(txt_path, "r", encoding="utf-8") as f:
            return f.read().strip()
    except FileNotFoundError:
        return default_caption


class _CaptionedImageDataset(torch.utils.data.Dataset):
    """Map-style dataset of ``{"pixel_values": tensor, "caption": str}`` items."""

    def __init__(self, image_paths, captions, transform):
        self.image_paths = image_paths
        self.captions = captions
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # The context manager releases the file handle as soon as the
        # image has been decoded.
        with Image.open(self.image_paths[idx]) as image:
            pixel_values = self.transform(image.convert("RGB"))
        return {"pixel_values": pixel_values, "caption": self.captions[idx]}


def main(args):
    """Train LoRA adapters on a Stable Diffusion UNet from an image folder.

    Each image in ``args.dataset_dir`` is paired with the caption in the
    sibling ``.txt`` file of the same stem (or ``"person"`` if absent). LoRA
    layers are injected into the UNet modules ``to_q``, ``to_k``, ``to_v``
    and ``to_out.0``; PEFT leaves only those LoRA weights trainable. The VAE
    and text encoder are frozen. The adapter is saved to ``args.output_dir``.

    Args:
        args: ``argparse.Namespace`` providing ``model_name``, ``dataset_dir``,
            ``output_dir``, ``lora_rank``, ``lora_alpha``, ``learning_rate``,
            ``num_epochs``, ``batch_size`` and ``mixed_precision``.

    Returns:
        None. Returns early (after printing a message) when the model cannot
        be loaded or no images are found.

    Raises:
        ValueError: if the scheduler's ``prediction_type`` is neither
            ``"epsilon"`` nor ``"v_prediction"``.
    """
    accelerator = Accelerator(
        mixed_precision="fp16" if args.mixed_precision else None
    )
    # dtype used for the frozen weights (VAE, text encoder, UNet base).
    weight_dtype = torch.float16 if args.mixed_precision else torch.float32

    accelerator.print(f"🚀 Carregando modelo: {args.model_name}")
    try:
        pipe = StableDiffusionPipeline.from_pretrained(
            args.model_name,
            torch_dtype=weight_dtype,
        )
    except Exception as e:  # script boundary: report and abort the run
        accelerator.print(f"❌ Falha ao carregar modelo: {e}")
        return

    unet = pipe.unet
    tokenizer = pipe.tokenizer
    text_encoder = pipe.text_encoder
    vae = pipe.vae
    noise_scheduler = DDPMScheduler.from_config(pipe.scheduler.config)
    # Drop the pipeline's references to components not used for training.
    del pipe

    # The regression target depends on what the UNet was trained to predict.
    prediction_type = noise_scheduler.config.prediction_type
    if prediction_type not in ("epsilon", "v_prediction"):
        raise ValueError(f"Unsupported prediction_type: {prediction_type!r}")

    # The VAE and text encoder are fixed feature extractors.
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    lora_config = LoraConfig(
        r=args.lora_rank,
        lora_alpha=args.lora_alpha,
        target_modules=["to_q", "to_v", "to_k", "to_out.0"],
        lora_dropout=0.0,
        bias="none",
    )
    unet = get_peft_model(unet, lora_config)
    if args.mixed_precision:
        # The LoRA weights inherit the fp16 dtype of the base layers, but
        # Accelerate's fp16 GradScaler cannot unscale fp16 gradients.
        # Keep only the trainable parameters in fp32.
        for param in unet.parameters():
            if param.requires_grad:
                param.data = param.data.to(torch.float32)
    unet.print_trainable_parameters()  # show % of trainable parameters

    transform = transforms.Compose([
        transforms.Resize(512),
        transforms.CenterCrop(512),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),  # map [0, 1] -> [-1, 1]
    ])

    # === Load dataset ===
    image_paths = _find_images(args.dataset_dir)
    if not image_paths:
        accelerator.print("❌ Nenhuma imagem encontrada no diretório!")
        return
    accelerator.print(
        f"✅ {len(image_paths)} imagens encontradas. Carregando legendas..."
    )
    captions = [_read_caption(path) for path in image_paths]

    dataset = _CaptionedImageDataset(image_paths, captions, transform)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
    )

    trainable_params = [p for p in unet.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=args.learning_rate)

    unet, optimizer, dataloader = accelerator.prepare(unet, optimizer, dataloader)
    # prepare() only places the objects it is given; the frozen encoders must
    # be moved to the training device explicitly.
    vae.to(accelerator.device, dtype=weight_dtype)
    text_encoder.to(accelerator.device, dtype=weight_dtype)
    vae.eval()
    text_encoder.eval()

    unet.train()
    for epoch in range(args.num_epochs):
        for batch in dataloader:
            with accelerator.accumulate(unet):
                # Frozen encoders: no autograd graph is needed for them.
                with torch.no_grad():
                    pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
                    latents = vae.encode(pixel_values).latent_dist.sample()
                    latents = latents * vae.config.scaling_factor

                    # Only input_ids are passed, matching how the Stable
                    # Diffusion pipeline encodes prompts at inference time.
                    token_ids = tokenizer(
                        batch["caption"],
                        max_length=tokenizer.model_max_length,
                        padding="max_length",
                        truncation=True,
                        return_tensors="pt",
                    ).input_ids.to(latents.device)
                    encoder_hidden_states = text_encoder(token_ids)[0]

                noise = torch.randn_like(latents)
                timesteps = torch.randint(
                    0,
                    noise_scheduler.config.num_train_timesteps,
                    (latents.shape[0],),
                    device=latents.device,
                )
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                if prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    target = noise

                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
                # Compute the loss in fp32 for numerical stability under fp16.
                loss = torch.nn.functional.mse_loss(model_pred.float(), target.float())

                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()
            accelerator.print(f"Epoch {epoch+1}/{args.num_epochs} - Loss: {loss.item():.4f}")

    # Save the LoRA adapter (main process only).
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        output_dir = args.output_dir
        unwrapped_unet = accelerator.unwrap_model(unet)
        # safe_serialization=True writes the adapter weights as .safetensors.
        unwrapped_unet.save_pretrained(output_dir, safe_serialization=True)
        print(f"✅ Modelo LoRA salvo em: {output_dir}")
if __name__ == "__main__":
    cli = argparse.ArgumentParser(
        description="Treina um modelo LoRA para Stable Diffusion"
    )
    # Command-line options as (flag, add_argument keyword arguments) pairs.
    # They are registered in this order, which also fixes the --help order.
    option_table = [
        ("--model_name", dict(type=str, default="runwayml/stable-diffusion-v1-5", help="Modelo base do HF")),
        ("--dataset_dir", dict(type=str, required=True, help="Pasta com imagens e .txt")),
        ("--output_dir", dict(type=str, default="lora-output", help="Onde salvar o LoRA")),
        ("--lora_rank", dict(type=int, default=4, help="Rank LoRA (4-64)")),
        ("--lora_alpha", dict(type=int, default=32, help="Alpha LoRA")),
        ("--learning_rate", dict(type=float, default=1e-4, help="Taxa de aprendizado")),
        ("--num_epochs", dict(type=int, default=10, help="Número de épocas")),
        ("--batch_size", dict(type=int, default=1, help="Batch size")),
        ("--mixed_precision", dict(action="store_true", help="Usa FP16")),
    ]
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    # Parse sys.argv and hand the resulting namespace to the training entry point.
    args = cli.parse_args()
    main(args)