Allex21 committed · verified
Commit dc3cfdb · 1 Parent(s): 58c1f99

Update train_lora.py

Files changed (1): train_lora.py (+82, -52)

train_lora.py CHANGED
@@ -9,22 +9,28 @@ from torchvision import transforms
 from PIL import Image
 import glob
 
+
 def main(args):
+    # Initialize the Accelerator
     accelerator = Accelerator(
-        mixed_precision="fp16" if args.mixed_precision else None,
-        gradient_accumulation_steps=1
+        mixed_precision="fp16" if args.mixed_precision else None
     )
 
-    # Load pipeline
-    print("Loading base model...")
-    pipe = StableDiffusionPipeline.from_pretrained(
-        args.model_name,
-        torch_dtype=torch.float16 if args.mixed_precision else torch.float32
-    )
+    print(f"🚀 Loading model: {args.model_name}")
+    try:
+        pipe = StableDiffusionPipeline.from_pretrained(
+            args.model_name,
+            torch_dtype=torch.float16 if args.mixed_precision else torch.float32
+        )
+    except Exception as e:
+        print(f"❌ Failed to load model: {e}")
+        return
+
+    # Extract components
+    unet = pipe.unet
     tokenizer = pipe.tokenizer
     text_encoder = pipe.text_encoder
    vae = pipe.vae
-    unet = pipe.unet
     noise_scheduler = DDPMScheduler.from_config(pipe.scheduler.config)
 
     # Configure LoRA
@@ -36,9 +42,9 @@ def main(args):
         bias="none"
     )
     unet = get_peft_model(unet, lora_config)
-    unet.print_trainable_parameters()
+    unet.print_trainable_parameters()  # Shows the % of trainable parameters
 
-    # Transforms
+    # Image transforms
     transform = transforms.Compose([
         transforms.Resize(512),
         transforms.CenterCrop(512),
@@ -46,9 +52,16 @@ def main(args):
         transforms.Normalize([0.5], [0.5]),
     ])
 
-    # Load images and captions
-    image_paths = sorted(glob.glob(os.path.join(args.dataset_dir, "*.*")))
-    image_paths = [p for p in image_paths if p.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.webp'))]
+    # === Load dataset ===
+    image_paths = []
+    for ext in ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.webp"]:
+        image_paths.extend(glob.glob(os.path.join(args.dataset_dir, ext)))
+
+    if len(image_paths) == 0:
+        print("❌ No images found in the directory!")
+        return
+
+    print(f"✅ {len(image_paths)} images found. Loading captions...")
 
     captions = []
     valid_images = []
@@ -56,89 +69,106 @@ def main(args):
         txt_path = os.path.splitext(img_path)[0] + ".txt"
         if os.path.exists(txt_path):
             with open(txt_path, "r", encoding="utf-8") as f:
-                captions.append(f.read().strip())
+                caption = f.read().strip()
         else:
-            captions.append("person")
+            caption = "person"
+        captions.append(caption)
         valid_images.append(img_path)
 
-    if len(valid_images) == 0:
-        print("❌ No images found!")
-        return
-
-    print(f"✅ {len(valid_images)} images loaded")
-
+    # PyTorch dataset
     class SimpleDataset(torch.utils.data.Dataset):
-        def __init__(self, image_paths, captions, transform):
-            self.image_paths = image_paths
-            self.captions = captions
+        def __init__(self, image_list, caption_list, transform):
+            self.images = image_list
+            self.captions = caption_list
            self.transform = transform
 
         def __len__(self):
-            return len(self.image_paths)
+            return len(self.images)
 
         def __getitem__(self, idx):
-            image = Image.open(self.image_paths[idx]).convert("RGB")
+            image = Image.open(self.images[idx]).convert("RGB")
             image = self.transform(image)
-            caption = self.captions[idx]
-            return {"pixel_values": image, "input_ids": caption}
+            return {"pixel_values": image, "input_ids": self.captions[idx]}
 
     dataset = SimpleDataset(valid_images, captions, transform)
-    dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        batch_size=args.batch_size,
+        shuffle=True
+    )
 
     # Optimizer
     optimizer = torch.optim.AdamW(unet.parameters(), lr=args.learning_rate)
-    lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
 
-    unet, optimizer, dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, dataloader, lr_scheduler)
+    # Prepare with Accelerator
+    unet, optimizer, dataloader = accelerator.prepare(unet, optimizer, dataloader)
+
+    # VAE and text encoder stay frozen (only the UNet is trained); move them
+    # to the training device, since accelerator.prepare only handles the UNet
+    vae.to(accelerator.device)
+    text_encoder.to(accelerator.device)
+    vae.eval()
+    text_encoder.eval()
 
     # Training
     unet.train()
-    global_step = 0
+    step = 0
     for epoch in range(args.num_epochs):
         for batch in dataloader:
             with accelerator.accumulate(unet):
-                pixel_values = batch["pixel_values"].to(accelerator.device)
+                # Generate latents (cast pixels to the VAE dtype for fp16 runs)
+                pixel_values = batch["pixel_values"].to(vae.dtype)
                 latents = vae.encode(pixel_values).latent_dist.sample() * 0.18215
 
+                # Add noise at a random timestep (DDPM training objective)
                 noise = torch.randn_like(latents)
                 bsz = latents.shape[0]
                 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
-                encoder_hidden_states = text_encoder(tokenizer(
+                # Encode text
+                inputs = tokenizer(
                     batch["input_ids"],
-                    padding="max_length",
                     max_length=77,
+                    padding="max_length",
                     truncation=True,
                     return_tensors="pt"
-                ).input_ids.to(latents.device))[0]
+                ).to(latents.device)
+                encoder_hidden_states = text_encoder(**inputs)[0]
 
+                # Noise prediction
                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
                 loss = torch.nn.functional.mse_loss(noise_pred, noise)
 
+                # Backpropagation
                 accelerator.backward(loss)
                 optimizer.step()
-                lr_scheduler.step()
                 optimizer.zero_grad()
-                global_step += 1
+                step += 1
 
-    # Save model
+        print(f"Epoch {epoch+1}/{args.num_epochs} - Loss: {loss.item():.4f}")
+
+    # Save LoRA model
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
+        output_dir = args.output_dir
         unwrapped_unet = accelerator.unwrap_model(unet)
-        unwrapped_unet.save_pretrained(args.output_dir)
-        print(f"✅ Model saved to {args.output_dir}")
+        # save_pretrained stores the adapter weights as adapter_model.safetensors
+        unwrapped_unet.save_pretrained(output_dir, safe_serialization=True)
+        print(f"✅ LoRA model saved to: {output_dir}")
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--model_name", type=str, default="runwayml/stable-diffusion-v1-5")
-    parser.add_argument("--dataset_dir", type=str, required=True)
-    parser.add_argument("--output_dir", type=str, default="lora-output")
-    parser.add_argument("--lora_rank", type=int, default=4)
-    parser.add_argument("--lora_alpha", type=int, default=32)
-    parser.add_argument("--learning_rate", type=float, default=1e-4)
-    parser.add_argument("--num_epochs", type=int, default=10)
-    parser.add_argument("--batch_size", type=int, default=1)
-    parser.add_argument("--mixed_precision", action="store_true")
+    parser = argparse.ArgumentParser(description="Train a LoRA model for Stable Diffusion")
+    parser.add_argument("--model_name", type=str, default="runwayml/stable-diffusion-v1-5", help="Base HF model")
+    parser.add_argument("--dataset_dir", type=str, required=True, help="Folder with images and .txt captions")
+    parser.add_argument("--output_dir", type=str, default="lora-output", help="Where to save the LoRA")
+    parser.add_argument("--lora_rank", type=int, default=4, help="LoRA rank (4-64)")
    parser.add_argument("--lora_alpha", type=int, default=32, help="LoRA alpha")
+    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate")
+    parser.add_argument("--num_epochs", type=int, default=10, help="Number of epochs")
+    parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
+    parser.add_argument("--mixed_precision", action="store_true", help="Use FP16")
+
    args = parser.parse_args()
    main(args)
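
For reference, a typical way to run the updated script; the dataset path is illustrative, and the flags are the ones defined in the argparse block above:

    accelerate launch train_lora.py --dataset_dir ./my_dataset --output_dir lora-output --mixed_precision

Launching through accelerate lets the Accelerator pick up device and precision settings; a plain python train_lora.py invocation also works for single-GPU runs.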
 
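After training, the saved adapter can be attached back onto the base pipeline for inference. A minimal sketch, assuming the default runwayml/stable-diffusion-v1-5 base and the lora-output directory this script writes; the prompt and output filename are illustrative:

    from diffusers import StableDiffusionPipeline
    from peft import PeftModel

    # Load the same base model the adapter was trained against
    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

    # Re-attach the trained LoRA weights to the pipeline's UNet
    pipe.unet = PeftModel.from_pretrained(pipe.unet, "lora-output")
    pipe.to("cuda")

    image = pipe("a photo of a person").images[0]
    image.save("sample.png")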