AryanRathod3097 committed on
Commit
555e83c
·
verified ·
1 Parent(s): 34adbea

Create train.py

Browse files
Files changed (1) hide show
  1. train.py +86 -0
train.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, torch
2
+ from datasets import load_dataset
3
+ from diffusers import StableDiffusionPipeline, DDPMScheduler
4
+ from peft import LoraConfig, get_peft_model
5
+ from accelerate import Accelerator
6
+ from PIL import Image
7
+ import numpy as np
8
+
# --- CONFIG -------------------------------------------------
# Plain module-level settings read by the rest of the training script.
RESOLUTION = 512  # side length (pixels) every training image is resized to
BATCH_SIZE = 1  # samples per DataLoader batch
GRAD_ACC = 4  # batches accumulated before each optimizer step
LR = 1e-4  # learning rate handed to AdamW
MAX_STEPS = 500  # number of batches processed before training stops
OUTPUT_DIR = "./lora-out"  # directory the trained adapter is saved into
MODEL_ID = "runwayml/stable-diffusion-v1-5"  # base pipeline checkpoint id
# -----------------------------------------------------------
18
+
# Single Accelerator instance for the run; the script uses it to prepare the
# model/optimizer/dataloader, run backward passes and gate main-process work.
accelerator = Accelerator()

# 1. Dataset
dataset = load_dataset("huggan/smithsonian_butterflies_subset", split="train")
23
+
def transform(example, resolution=None):
    """Convert one dataset example into a normalized, channels-first array.

    Args:
        example: Mapping whose ``"image"`` entry provides ``convert`` and
            ``resize`` and can be converted with ``np.array`` (presumably a
            ``PIL.Image.Image`` -- confirm against the dataset schema).
        resolution: Side length in pixels of the square output. When omitted,
            the module-level ``RESOLUTION`` is used.

    Returns:
        ``{"pixel_values": array}`` where ``array`` is ``float32``, shaped
        ``(3, resolution, resolution)`` for an RGB image, with values rescaled
        from ``[0, 255]`` to ``[-1.0, 1.0]``.
    """
    side = RESOLUTION if resolution is None else resolution
    image = example["image"].convert("RGB").resize((side, side))
    pixels = np.array(image).astype(np.float32) / 127.5 - 1.0
    # np.array() on an RGB image is height x width x channel, while the VAE
    # that consumes these values expects channel x height x width.
    return {"pixel_values": np.ascontiguousarray(pixels.transpose(2, 0, 1))}
27
+
dataset = dataset.map(transform, remove_columns=dataset.column_names)
# Serve the mapped column as torch tensors. Without an explicit format the
# rows come back as nested Python lists, which the DataLoader cannot collate
# into the single batched tensor the training loop calls `.to(...)` on.
dataset.set_format(type="torch", columns=["pixel_values"])

# 2. Load SD pipeline
pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.float32)
# Training needs a scheduler exposing add_noise(); rebuild it as DDPM from
# the pipeline's own scheduler config.
pipe.scheduler = DDPMScheduler.from_config(pipe.scheduler.config)
# Freeze every base weight; only the LoRA matrices inserted below are trained.
for frozen_module in (pipe.vae, pipe.text_encoder, pipe.unet):
    frozen_module.requires_grad_(False)

# 3. Insert LoRA
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
    init_lora_weights="gaussian",
)
pipe.unet = get_peft_model(pipe.unet, lora_config)

# 4. Optimizer & dataloader
from torch.utils.data import DataLoader

dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
# Hand the optimizer only the parameters that can receive gradients instead
# of every frozen UNet weight.
trainable_params = [p for p in pipe.unet.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=LR)

pipe.unet, optimizer, dataloader = accelerator.prepare(pipe.unet, optimizer, dataloader)
49
+
# 5. Training loop
device = accelerator.device
# The VAE and text encoder are not passed through accelerator.prepare(), so
# they must be placed on the training device explicitly.
pipe.vae.to(device)
pipe.text_encoder.to(device)

# The prompt never changes, so tokenize and encode it once instead of on
# every step; it is expanded to the batch size inside the loop.
with torch.no_grad():
    prompt_ids = pipe.tokenizer(
        ["a high-resolution photo of a butterfly"],
        padding="max_length",
        max_length=pipe.tokenizer.model_max_length,
        return_tensors="pt",
    ).input_ids.to(device)
    prompt_embeds = pipe.text_encoder(prompt_ids)[0]

num_train_timesteps = pipe.scheduler.config.num_train_timesteps

pipe.unet.train()
optimizer.zero_grad()
# Single pass over the dataloader, stopping after MAX_STEPS batches.
for step, batch in enumerate(dataloader, 1):
    pixel_values = batch["pixel_values"].to(device=device, dtype=pipe.vae.dtype)
    # The VAE is frozen; skip autograd bookkeeping while encoding.
    with torch.no_grad():
        latents = pipe.vae.encode(pixel_values).latent_dist.sample()
        latents = latents * pipe.vae.config.scaling_factor

    batch_size = latents.shape[0]
    noise = torch.randn_like(latents)
    # Sample timesteps on the same device as the latents so the scheduler
    # and the UNet never mix CPU and accelerator tensors.
    timesteps = torch.randint(
        0, num_train_timesteps, (batch_size,), device=latents.device, dtype=torch.long
    )
    noisy_latents = pipe.scheduler.add_noise(latents, noise, timesteps)

    encoder_hidden_states = prompt_embeds.expand(batch_size, -1, -1)

    # The UNet is trained to predict the noise that was added to the latents.
    model_pred = pipe.unet(noisy_latents, timesteps, encoder_hidden_states).sample
    loss = torch.nn.functional.mse_loss(model_pred, noise)
    # Scale so the accumulated gradient is the mean over GRAD_ACC batches
    # rather than their sum.
    accelerator.backward(loss / GRAD_ACC)
    if step % GRAD_ACC == 0:
        optimizer.step()
        optimizer.zero_grad()

    if accelerator.is_main_process and step % 50 == 0:
        print(f"Step {step:04d}/{MAX_STEPS} | loss={loss.item():.4f}")
    if step >= MAX_STEPS:
        break
79
+
# 6. Save LoRA weights
accelerator.wait_for_everyone()
# Only one process writes to disk; otherwise every process would save and
# then rename the same file.
if accelerator.is_main_process:
    unwrapped = accelerator.unwrap_model(pipe.unet)
    # Request safetensors explicitly, since the move below relies on the
    # "adapter_model.safetensors" file name.
    unwrapped.save_pretrained(OUTPUT_DIR, safe_serialization=True)
    print("LoRA saved to", OUTPUT_DIR)
    # Move the file to the expected name for Gradio
    # NOTE(review): the adapter is written in PEFT's key layout -- confirm
    # that whatever loads pytorch_lora_weights.safetensors accepts it.
    # os.replace overwrites a destination left over from a previous run.
    os.replace(
        os.path.join(OUTPUT_DIR, "adapter_model.safetensors"),
        "./pytorch_lora_weights.safetensors",
    )