Allex21 commited on
Commit
06d0e1e
·
verified ·
1 Parent(s): c6fc478

Update train_lora.py

Browse files
Files changed (1) hide show
  1. train_lora.py +16 -9
train_lora.py CHANGED
@@ -1,17 +1,19 @@
1
  # train_lora.py
2
  import os
3
  import torch
4
- from diffusers import StableDiffusionPipeline
 
5
  from peft import LoraConfig, get_peft_model
6
- from diffusers.optimization import get_scheduler
7
  from accelerate import Accelerator
8
  from torchvision import transforms
9
  from PIL import Image
10
- import argparse
11
  import glob
12
 
13
  def main(args):
14
- accelerator = Accelerator(mixed_precision="fp16" if args.mixed_precision else None)
 
 
 
15
 
16
  # Carrega pipeline
17
  print("Carregando modelo base...")
@@ -23,6 +25,7 @@ def main(args):
23
  text_encoder = pipe.text_encoder
24
  vae = pipe.vae
25
  unet = pipe.unet
 
26
 
27
  # Configura LoRA
28
  lora_config = LoraConfig(
@@ -33,6 +36,7 @@ def main(args):
33
  bias="none"
34
  )
35
  unet = get_peft_model(unet, lora_config)
 
36
 
37
  # Transformações
38
  transform = transforms.Compose([
@@ -63,7 +67,6 @@ def main(args):
63
 
64
  print(f"✅ {len(valid_images)} imagens carregadas")
65
 
66
- # Dataset simples
67
  class SimpleDataset(torch.utils.data.Dataset):
68
  def __init__(self, image_paths, captions, transform):
69
  self.image_paths = image_paths
@@ -84,20 +87,23 @@ def main(args):
84
 
85
  # Otimizador
86
  optimizer = torch.optim.AdamW(unet.parameters(), lr=args.learning_rate)
87
- lr_scheduler = get_scheduler("constant", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(dataloader) * args.num_epochs)
88
 
89
  unet, optimizer, dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, dataloader, lr_scheduler)
90
 
91
  # Treinamento
92
  unet.train()
 
93
  for epoch in range(args.num_epochs):
94
  for batch in dataloader:
95
  with accelerator.accumulate(unet):
96
- latents = vae.encode(batch["pixel_values"]).latent_dist.sample() * 0.18215
 
 
97
  noise = torch.randn_like(latents)
98
  bsz = latents.shape[0]
99
- timesteps = torch.randint(0, 1000, (bsz,), device=latents.device)
100
- noisy_latents = latents + noise * torch.sqrt(timesteps / 1000)
101
 
102
  encoder_hidden_states = text_encoder(tokenizer(
103
  batch["input_ids"],
@@ -114,6 +120,7 @@ def main(args):
114
  optimizer.step()
115
  lr_scheduler.step()
116
  optimizer.zero_grad()
 
117
 
118
  # Salva modelo
119
  accelerator.wait_for_everyone()
 
1
  # train_lora.py
2
  import os
3
  import torch
4
+ import argparse
5
+ from diffusers import StableDiffusionPipeline, DDPMScheduler
6
  from peft import LoraConfig, get_peft_model
 
7
  from accelerate import Accelerator
8
  from torchvision import transforms
9
  from PIL import Image
 
10
  import glob
11
 
12
  def main(args):
13
+ accelerator = Accelerator(
14
+ mixed_precision="fp16" if args.mixed_precision else None,
15
+ gradient_accumulation_steps=1
16
+ )
17
 
18
  # Carrega pipeline
19
  print("Carregando modelo base...")
 
25
  text_encoder = pipe.text_encoder
26
  vae = pipe.vae
27
  unet = pipe.unet
28
+ noise_scheduler = DDPMScheduler.from_config(pipe.scheduler.config)
29
 
30
  # Configura LoRA
31
  lora_config = LoraConfig(
 
36
  bias="none"
37
  )
38
  unet = get_peft_model(unet, lora_config)
39
+ unet.print_trainable_parameters()
40
 
41
  # Transformações
42
  transform = transforms.Compose([
 
67
 
68
  print(f"✅ {len(valid_images)} imagens carregadas")
69
 
 
70
  class SimpleDataset(torch.utils.data.Dataset):
71
  def __init__(self, image_paths, captions, transform):
72
  self.image_paths = image_paths
 
87
 
88
  # Otimizador
89
  optimizer = torch.optim.AdamW(unet.parameters(), lr=args.learning_rate)
90
+ lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
91
 
92
  unet, optimizer, dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, dataloader, lr_scheduler)
93
 
94
  # Treinamento
95
  unet.train()
96
+ global_step = 0
97
  for epoch in range(args.num_epochs):
98
  for batch in dataloader:
99
  with accelerator.accumulate(unet):
100
+ pixel_values = batch["pixel_values"].to(accelerator.device)
101
+ latents = vae.encode(pixel_values).latent_dist.sample() * 0.18215
102
+
103
  noise = torch.randn_like(latents)
104
  bsz = latents.shape[0]
105
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
106
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
107
 
108
  encoder_hidden_states = text_encoder(tokenizer(
109
  batch["input_ids"],
 
120
  optimizer.step()
121
  lr_scheduler.step()
122
  optimizer.zero_grad()
123
+ global_step += 1
124
 
125
  # Salva modelo
126
  accelerator.wait_for_everyone()