# train_lora.py — minimal LoRA fine-tuning script for Stable Diffusion v1.5
import argparse
import os
import torch
from diffusers import StableDiffusionPipeline
from peft import LoraConfig, get_peft_model
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms
# Custom dataset: every image file in a folder, as a fixed-size tensor.
class ImageDataset(Dataset):
    """Load all images from *folder* as (3, size, size) float tensors in [0, 1]."""

    # Matched case-insensitively so files like "photo.JPG" are not silently skipped.
    _EXTENSIONS = (".png", ".jpg", ".jpeg")

    def __init__(self, folder, size=512):
        # sorted() makes sample order deterministic across runs and filesystems
        # (os.listdir order is filesystem-dependent).
        self.files = sorted(
            os.path.join(folder, f)
            for f in os.listdir(folder)
            if f.lower().endswith(self._EXTENSIONS)
        )
        self.transform = transforms.Compose([
            transforms.Resize((size, size)),
            transforms.ToTensor(),  # HWC uint8 -> CHW float32 in [0, 1]
        ])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        # convert("RGB") drops alpha and expands grayscale to 3 channels.
        img = Image.open(self.files[idx]).convert("RGB")
        return self.transform(img)
def main(args):
    """Fine-tune LoRA adapters on Stable Diffusion v1.5's UNet attention layers.

    Trains unconditionally (empty prompt) on the images in ``args.images_dir``
    and writes the adapter weights to ``args.output_dir``.
    """
    # Training-time noise scheduler; the pipeline ships an inference scheduler.
    from diffusers import DDPMScheduler

    device = "cuda"
    # Load the base model.
    model_id = "runwayml/stable-diffusion-v1-5"
    pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(device)

    # LoRA on the UNet attention query/value projections.
    # NOTE: no task_type here — "CAUSAL_LM" is for language models and is wrong
    # for a diffusion UNet.
    lora_config = LoraConfig(
        r=args.rank,
        lora_alpha=16,
        target_modules=["to_q", "to_v"],
        lora_dropout=0.1,
        bias="none",
    )
    pipe.unet = get_peft_model(pipe.unet, lora_config)

    # VAE and text encoder are frozen feature extractors during LoRA training.
    pipe.vae.requires_grad_(False)
    pipe.text_encoder.requires_grad_(False)

    # Dataset
    dataset = ImageDataset(args.images_dir)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

    # Optimize only the trainable LoRA parameters, not the whole UNet.
    trainable_params = [p for p in pipe.unet.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=args.learning_rate)

    noise_scheduler = DDPMScheduler.from_config(pipe.scheduler.config)

    # Unconditional training: encode the empty prompt once and reuse it.
    with torch.no_grad():
        token_ids = pipe.tokenizer(
            "",
            padding="max_length",
            max_length=pipe.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        ).input_ids.to(device)
        empty_prompt_emb = pipe.text_encoder(token_ids)[0]

    weight_dtype = torch.float16
    # Training loop: standard DDPM objective — predict the noise that was
    # added to the image latents at a random timestep.
    for epoch in range(args.num_epochs):
        for batch in dataloader:
            # Dataset yields [0, 1] tensors; the VAE expects inputs in [-1, 1].
            pixels = batch.to(device, dtype=weight_dtype) * 2.0 - 1.0
            with torch.no_grad():
                latents = pipe.vae.encode(pixels).latent_dist.sample()
                latents = latents * pipe.vae.config.scaling_factor

            noise = torch.randn_like(latents)
            timesteps = torch.randint(
                0,
                noise_scheduler.config.num_train_timesteps,
                (latents.shape[0],),
                device=device,
            ).long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            optimizer.zero_grad()
            pred = pipe.unet(
                noisy_latents,
                timesteps,
                encoder_hidden_states=empty_prompt_emb.expand(latents.shape[0], -1, -1),
            ).sample
            # Compute the loss in fp32 for numerical stability under fp16 weights.
            loss = torch.nn.functional.mse_loss(pred.float(), noise.float())
            loss.backward()
            optimizer.step()
        print(f"✅ Epoch {epoch+1}/{args.num_epochs} finalizado.")

    # save_pretrained writes only the adapter weights in real safetensors
    # format — the original torch.save() produced a pickle of the *entire*
    # UNet state dict under a misleading ".safetensors" name.
    os.makedirs(args.output_dir, exist_ok=True)
    pipe.unet.save_pretrained(args.output_dir)
    print(f"✅ Treinamento concluído. LoRA salvo em {args.output_dir}")
if __name__ == "__main__":
    # CLI entry point: declare the flags as a spec table, then parse and run.
    cli = argparse.ArgumentParser()
    for flag, options in (
        ("--images_dir", {"type": str, "required": True}),
        ("--output_dir", {"type": str, "required": True}),
        ("--learning_rate", {"type": float, "default": 1e-4}),
        ("--num_epochs", {"type": int, "default": 10}),
        ("--rank", {"type": int, "default": 4}),
    ):
        cli.add_argument(flag, **options)
    main(cli.parse_args())