# butterfly-lora / train.py
# Uploaded by AryanRathod3097 ("Create train.py", commit 555e83c, verified)
import os, torch
from datasets import load_dataset
from diffusers import StableDiffusionPipeline, DDPMScheduler
from peft import LoraConfig, get_peft_model
from accelerate import Accelerator
from PIL import Image
import numpy as np
# --- CONFIG -------------------------------------------------
RESOLUTION = 512          # square edge length images are resized to before encoding
BATCH_SIZE = 1            # per-step micro-batch size fed to the DataLoader
GRAD_ACC = 4              # optimizer steps once every GRAD_ACC micro-batches
LR = 1e-4                 # AdamW learning rate for the LoRA parameters
MAX_STEPS = 500           # hard cap on training micro-batch steps
OUTPUT_DIR = "./lora-out" # where the trained LoRA adapter is saved
MODEL_ID = "runwayml/stable-diffusion-v1-5"  # base checkpoint to fine-tune
# -----------------------------------------------------------
# Accelerator handles device placement and (if launched that way) multi-process setup.
accelerator = Accelerator()
# 1. Dataset
dataset = load_dataset("huggan/smithsonian_butterflies_subset", split="train")


def transform(example):
    """Resize one example to RESOLUTION x RESOLUTION and scale pixels to [-1, 1].

    Returns a channels-first (C, H, W) float32 array under "pixel_values":
    the SD VAE expects NCHW input, while PIL/np.array produce H x W x C.
    """
    image = example["image"].convert("RGB").resize((RESOLUTION, RESOLUTION))
    pixels = np.array(image).astype(np.float32) / 127.5 - 1.0  # [0, 255] -> [-1, 1]
    return {"pixel_values": pixels.transpose(2, 0, 1)}  # HWC -> CHW


dataset = dataset.map(transform, remove_columns=dataset.column_names)
# Emit torch tensors so the default DataLoader collate yields a (B, C, H, W)
# float batch instead of deeply nested Python lists.
dataset.set_format("torch", columns=["pixel_values"])
# 2. Load SD pipeline
pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.float32)
# Use a plain DDPM scheduler for the training noise process.
pipe.scheduler = DDPMScheduler.from_config(pipe.scheduler.config)
# Freeze every base component; only the LoRA adapters injected below will train.
for frozen in (pipe.vae, pipe.text_encoder, pipe.unet):
    frozen.requires_grad_(False)

# 3. Insert LoRA
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
    init_lora_weights="gaussian",
)
pipe.unet = get_peft_model(pipe.unet, lora_config)
# 4. Optimizer & dataloader
from torch.utils.data import DataLoader

dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
# Only the injected LoRA adapter weights require grad; restricting the
# optimizer to those avoids tracking the (frozen) base UNet parameters.
trainable_params = [p for p in pipe.unet.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=LR)
pipe.unet, optimizer, dataloader = accelerator.prepare(pipe.unet, optimizer, dataloader)
# 5. Training loop
pipe.unet.train()

# The prompt never changes, so tokenize + encode it once instead of
# re-running the tokenizer and text encoder on every step.
with torch.no_grad():
    prompt_ids = pipe.tokenizer(
        ["a high-resolution photo of a butterfly"] * BATCH_SIZE,
        padding="max_length",
        max_length=pipe.tokenizer.model_max_length,
        return_tensors="pt",
    ).input_ids.to(pipe.text_encoder.device)
    prompt_embeds = pipe.text_encoder(prompt_ids)[0]

for step, batch in enumerate(dataloader, 1):
    # The VAE is frozen: no_grad avoids building a useless graph for it.
    with torch.no_grad():
        latents = pipe.vae.encode(
            batch["pixel_values"].to(pipe.vae.dtype)
        ).latent_dist.sample()
    latents = latents * pipe.vae.config.scaling_factor
    noise = torch.randn_like(latents)
    # Timesteps must live on the same device as the latents, otherwise
    # add_noise / the UNet forward fail on GPU.
    timesteps = torch.randint(
        0,
        pipe.scheduler.config.num_train_timesteps,
        (latents.shape[0],),
        device=latents.device,
    )
    noisy_latents = pipe.scheduler.add_noise(latents, noise, timesteps)
    # Slice in case the final batch is smaller than BATCH_SIZE.
    encoder_hidden_states = prompt_embeds[: latents.shape[0]]
    model_pred = pipe.unet(noisy_latents, timesteps, encoder_hidden_states).sample
    # Divide by GRAD_ACC so gradients accumulated over GRAD_ACC micro-batches
    # average rather than sum (keeps the effective LR independent of GRAD_ACC).
    loss = torch.nn.functional.mse_loss(model_pred, noise) / GRAD_ACC
    accelerator.backward(loss)
    if step % GRAD_ACC == 0:
        optimizer.step()
        optimizer.zero_grad()
    if accelerator.is_main_process and step % 50 == 0:
        # Undo the GRAD_ACC scaling so the printed loss matches the raw MSE.
        print(f"Step {step:04d}/{MAX_STEPS} | loss={loss.item() * GRAD_ACC:.4f}")
    if step >= MAX_STEPS:
        break
# 6. Save LoRA weights
accelerator.wait_for_everyone()
# Only rank 0 writes to disk: in a multi-process launch every rank would
# otherwise save and rename the same files concurrently.
if accelerator.is_main_process:
    unwrapped = accelerator.unwrap_model(pipe.unet)
    unwrapped.save_pretrained(OUTPUT_DIR)
    print("LoRA saved to", OUTPUT_DIR)
    # Move the adapter to the exact filename the Gradio app expects.
    os.rename(
        os.path.join(OUTPUT_DIR, "adapter_model.safetensors"),
        "./pytorch_lora_weights.safetensors",
    )