HuggingFace Spaces build status: Runtime error.
# Imports: stdlib first, then third-party, alphabetized (PEP 8 grouping).
import os

import numpy as np
import torch
from accelerate import Accelerator
from datasets import load_dataset
from diffusers import DDPMScheduler, StableDiffusionPipeline
from peft import LoraConfig, get_peft_model
from PIL import Image
from torch.utils.data import DataLoader
# --- CONFIG -------------------------------------------------
RESOLUTION = 512                 # square edge length the images are resized to
BATCH_SIZE = 1                   # per-device batch size
GRAD_ACC = 4                     # optimizer step every GRAD_ACC batches
LR = 1e-4                        # AdamW learning rate for the LoRA parameters
MAX_STEPS = 500                  # hard stop for the training loop
OUTPUT_DIR = "./lora-out"        # directory PEFT writes the adapter into
MODEL_ID = "runwayml/stable-diffusion-v1-5"
# -----------------------------------------------------------
accelerator = Accelerator()
# 1. Dataset: resize to RESOLUTION, scale pixels to [-1, 1], channel-first.
dataset = load_dataset("huggan/smithsonian_butterflies_subset", split="train")


def transform(example):
    """Turn one dataset record into a normalized float32 CHW pixel array.

    Returns a dict with key "pixel_values": shape (3, RESOLUTION, RESOLUTION),
    values in [-1, 1] as expected by the SD VAE encoder.
    """
    image = example["image"].convert("RGB").resize((RESOLUTION, RESOLUTION))
    pixels = np.asarray(image).astype(np.float32) / 127.5 - 1.0  # [0,255] -> [-1,1]
    # BUGFIX: np.asarray(image) is HWC; the VAE expects CHW input.
    return {"pixel_values": pixels.transpose(2, 0, 1)}


dataset = dataset.map(transform, remove_columns=dataset.column_names)
# BUGFIX: without an explicit torch format the DataLoader yields nested
# Python lists instead of tensors, and batch["pixel_values"].to(...) crashes.
dataset.set_format("torch", columns=["pixel_values"])
# 2. Load SD pipeline; freeze every base component — only LoRA adapters train.
pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.float32)
pipe.scheduler = DDPMScheduler.from_config(pipe.scheduler.config)
for frozen in (pipe.vae, pipe.text_encoder, pipe.unet):
    frozen.requires_grad_(False)

# 3. Insert LoRA into the UNet attention projections.
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
    init_lora_weights="gaussian",
)
pipe.unet = get_peft_model(pipe.unet, lora_config)

# 4. Optimizer & dataloader.
from torch.utils.data import DataLoader

dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
optimizer = torch.optim.AdamW(pipe.unet.parameters(), lr=LR)
pipe.unet, optimizer, dataloader = accelerator.prepare(pipe.unet, optimizer, dataloader)
# BUGFIX: accelerator.prepare moves only the UNet; the frozen VAE and text
# encoder must be moved to the same device explicitly or encoding crashes
# with a CPU/GPU device mismatch.
pipe.vae.to(accelerator.device)
pipe.text_encoder.to(accelerator.device)
# 5. Training loop: standard DDPM noise-prediction objective on VAE latents.
pipe.unet.train()
for step, batch in enumerate(dataloader, 1):
    # Encode images into scaled latent space with the frozen VAE.
    pixel_values = batch["pixel_values"].to(device=pipe.vae.device, dtype=pipe.vae.dtype)
    latents = pipe.vae.encode(pixel_values).latent_dist.sample()
    latents = latents * pipe.vae.config.scaling_factor

    # Sample noise and one random timestep per sample.
    noise = torch.randn_like(latents)
    # BUGFIX: timesteps must be integer dtype and live on the same device as
    # the latents; the original CPU tensor crashed add_noise on GPU.
    timesteps = torch.randint(
        0,
        pipe.scheduler.config.num_train_timesteps,
        (latents.shape[0],),
        device=latents.device,
    ).long()
    noisy_latents = pipe.scheduler.add_noise(latents, noise, timesteps)

    # Fixed caption for every image in this fine-tune.
    input_ids = pipe.tokenizer(
        ["a high-resolution photo of a butterfly"] * latents.shape[0],
        padding="max_length",
        max_length=pipe.tokenizer.model_max_length,
        return_tensors="pt",
    ).input_ids.to(pipe.text_encoder.device)
    encoder_hidden_states = pipe.text_encoder(input_ids)[0]

    model_pred = pipe.unet(noisy_latents, timesteps, encoder_hidden_states).sample
    # BUGFIX: divide by GRAD_ACC so accumulated gradients average rather than
    # sum over the accumulation window (keeps the effective LR as configured).
    loss = torch.nn.functional.mse_loss(model_pred, noise) / GRAD_ACC
    accelerator.backward(loss)

    if step % GRAD_ACC == 0:
        optimizer.step()
        optimizer.zero_grad()
    if accelerator.is_main_process and step % 50 == 0:
        # Multiply back so the logged value is the unscaled per-batch loss.
        print(f"Step {step:04d}/{MAX_STEPS} | loss={loss.item() * GRAD_ACC:.4f}")
    if step >= MAX_STEPS:
        break
# 6. Save LoRA weights.
accelerator.wait_for_everyone()
# BUGFIX: only the main process should write/rename files — in multi-GPU runs
# every rank executing os.rename races and raises FileNotFoundError.
if accelerator.is_main_process:
    unwrapped = accelerator.unwrap_model(pipe.unet)
    unwrapped.save_pretrained(OUTPUT_DIR)
    print("LoRA saved to", OUTPUT_DIR)
    # Move the adapter file to the name the Gradio app expects.
    os.rename(
        os.path.join(OUTPUT_DIR, "adapter_model.safetensors"),
        "./pytorch_lora_weights.safetensors",
    )