Allex21 committed on
Commit
2a37b4f
·
verified ·
1 Parent(s): 870bf43

Update train_lora.py

Browse files
Files changed (1) hide show
  1. train_lora.py +109 -41
train_lora.py CHANGED
@@ -1,58 +1,126 @@
 
1
  import os
2
- import argparse
3
- from accelerate import Accelerator
4
- from diffusers import StableDiffusionPipeline, UNet2DConditionModel
5
  from peft import LoraConfig, get_peft_model
6
- from transformers import AutoTokenizer, AutoModel
 
 
 
 
 
7
 
8
  def main(args):
9
- accelerator = Accelerator()
10
-
11
- # Carrega modelo base
12
- pipeline = StableDiffusionPipeline.from_pretrained(
 
13
  args.model_name,
14
- revision="fp16" if args.mixed_precision else None,
15
- torch_dtype=torch.float16 if args.mixed_precision else None
16
  )
17
-
 
 
 
 
18
  # Configura LoRA
19
- unet = pipeline.unet
20
  lora_config = LoraConfig(
21
  r=args.lora_rank,
22
  lora_alpha=args.lora_alpha,
23
- target_modules=["to_q", "to_v"],
24
  lora_dropout=0.0,
25
  bias="none"
26
  )
27
  unet = get_peft_model(unet, lora_config)
28
-
29
- # Prepara dados
30
- from datasets import load_dataset
31
- dataset = load_dataset("imagefolder", data_dir=args.dataset_dir, split="train")
32
-
33
- # Treinamento
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  optimizer = torch.optim.AdamW(unet.parameters(), lr=args.learning_rate)
35
- train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size)
36
-
37
- unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
38
-
 
 
39
  for epoch in range(args.num_epochs):
40
- for step, batch in enumerate(train_dataloader):
41
- # Lógica de treinamento simplificada (para demonstração)
42
- loss = unet(batch["pixel_values"]).sample.mean()
43
- accelerator.backward(loss)
44
- optimizer.step()
45
- optimizer.zero_grad()
46
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  # Salva modelo
48
- unet.save_pretrained(args.output_dir)
49
- if args.push_to_hub:
50
- from huggingface_hub import upload_folder
51
- upload_folder(
52
- repo_id=args.hub_model_id,
53
- folder_path=args.output_dir,
54
- commit_message=f"LoRA fine-tuning epoch {epoch}"
55
- )
56
 
57
  if __name__ == "__main__":
58
  parser = argparse.ArgumentParser()
@@ -60,10 +128,10 @@ if __name__ == "__main__":
60
  parser.add_argument("--dataset_dir", type=str, required=True)
61
  parser.add_argument("--output_dir", type=str, default="lora-output")
62
  parser.add_argument("--lora_rank", type=int, default=4)
 
63
  parser.add_argument("--learning_rate", type=float, default=1e-4)
64
  parser.add_argument("--num_epochs", type=int, default=10)
65
- parser.add_argument("--batch_size", type=int, default=4)
66
- parser.add_argument("--push_to_hub", action="store_true")
67
- parser.add_argument("--hub_model_id", type=str, default="my-lora-model")
68
  args = parser.parse_args()
69
  main(args)
 
1
+ # train_lora.py
2
  import os
3
+ import torch
4
+ from diffusers import StableDiffusionPipeline
 
5
  from peft import LoraConfig, get_peft_model
6
+ from diffusers.optimization import get_scheduler
7
+ from accelerate import Accelerator
8
+ from torchvision import transforms
9
+ from PIL import Image
10
+ import argparse
11
+ import glob
12
 
13
  def main(args):
14
+ accelerator = Accelerator(mixed_precision="fp16" if args.mixed_precision else None)
15
+
16
+ # Carrega pipeline
17
+ print("Carregando modelo base...")
18
+ pipe = StableDiffusionPipeline.from_pretrained(
19
  args.model_name,
20
+ torch_dtype=torch.float16 if args.mixed_precision else torch.float32
 
21
  )
22
+ tokenizer = pipe.tokenizer
23
+ text_encoder = pipe.text_encoder
24
+ vae = pipe.vae
25
+ unet = pipe.unet
26
+
27
  # Configura LoRA
 
28
  lora_config = LoraConfig(
29
  r=args.lora_rank,
30
  lora_alpha=args.lora_alpha,
31
+ target_modules=["to_q", "to_v", "to_k", "to_out.0"],
32
  lora_dropout=0.0,
33
  bias="none"
34
  )
35
  unet = get_peft_model(unet, lora_config)
36
+
37
+ # Transformações
38
+ transform = transforms.Compose([
39
+ transforms.Resize(512),
40
+ transforms.CenterCrop(512),
41
+ transforms.ToTensor(),
42
+ transforms.Normalize([0.5], [0.5]),
43
+ ])
44
+
45
+ # Carrega imagens e legendas
46
+ image_paths = sorted(glob.glob(os.path.join(args.dataset_dir, "*.*")))
47
+ image_paths = [p for p in image_paths if p.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.webp'))]
48
+
49
+ captions = []
50
+ valid_images = []
51
+ for img_path in image_paths:
52
+ txt_path = os.path.splitext(img_path)[0] + ".txt"
53
+ if os.path.exists(txt_path):
54
+ with open(txt_path, "r", encoding="utf-8") as f:
55
+ captions.append(f.read().strip())
56
+ else:
57
+ captions.append("person")
58
+ valid_images.append(img_path)
59
+
60
+ if len(valid_images) == 0:
61
+ print("❌ Nenhuma imagem encontrada!")
62
+ return
63
+
64
+ print(f"✅ {len(valid_images)} imagens carregadas")
65
+
66
+ # Dataset simples
67
+ class SimpleDataset(torch.utils.data.Dataset):
68
+ def __init__(self, image_paths, captions, transform):
69
+ self.image_paths = image_paths
70
+ self.captions = captions
71
+ self.transform = transform
72
+
73
+ def __len__(self):
74
+ return len(self.image_paths)
75
+
76
+ def __getitem__(self, idx):
77
+ image = Image.open(self.image_paths[idx]).convert("RGB")
78
+ image = self.transform(image)
79
+ caption = self.captions[idx]
80
+ return {"pixel_values": image, "input_ids": caption}
81
+
82
+ dataset = SimpleDataset(valid_images, captions, transform)
83
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
84
+
85
+ # Otimizador
86
  optimizer = torch.optim.AdamW(unet.parameters(), lr=args.learning_rate)
87
+ lr_scheduler = get_scheduler("constant", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(dataloader) * args.num_epochs)
88
+
89
+ unet, optimizer, dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, dataloader, lr_scheduler)
90
+
91
+ # Treinamento
92
+ unet.train()
93
  for epoch in range(args.num_epochs):
94
+ for batch in dataloader:
95
+ with accelerator.accumulate(unet):
96
+ latents = vae.encode(batch["pixel_values"]).latent_dist.sample() * 0.18215
97
+ noise = torch.randn_like(latents)
98
+ bsz = latents.shape[0]
99
+ timesteps = torch.randint(0, 1000, (bsz,), device=latents.device)
100
+ noisy_latents = latents + noise * torch.sqrt(timesteps / 1000)
101
+
102
+ encoder_hidden_states = text_encoder(tokenizer(
103
+ batch["input_ids"],
104
+ padding="max_length",
105
+ max_length=77,
106
+ truncation=True,
107
+ return_tensors="pt"
108
+ ).input_ids.to(latents.device))[0]
109
+
110
+ noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
111
+ loss = torch.nn.functional.mse_loss(noise_pred, noise)
112
+
113
+ accelerator.backward(loss)
114
+ optimizer.step()
115
+ lr_scheduler.step()
116
+ optimizer.zero_grad()
117
+
118
  # Salva modelo
119
+ accelerator.wait_for_everyone()
120
+ if accelerator.is_main_process:
121
+ unwrapped_unet = accelerator.unwrap_model(unet)
122
+ unwrapped_unet.save_pretrained(args.output_dir)
123
+ print(f"✅ Modelo salvo em {args.output_dir}")
 
 
 
124
 
125
  if __name__ == "__main__":
126
  parser = argparse.ArgumentParser()
 
128
  parser.add_argument("--dataset_dir", type=str, required=True)
129
  parser.add_argument("--output_dir", type=str, default="lora-output")
130
  parser.add_argument("--lora_rank", type=int, default=4)
131
+ parser.add_argument("--lora_alpha", type=int, default=32)
132
  parser.add_argument("--learning_rate", type=float, default=1e-4)
133
  parser.add_argument("--num_epochs", type=int, default=10)
134
+ parser.add_argument("--batch_size", type=int, default=1)
135
+ parser.add_argument("--mixed_precision", action="store_true")
 
136
  args = parser.parse_args()
137
  main(args)