Allex21 commited on
Commit
7865b76
·
verified ·
1 Parent(s): 36c6206

Create train_lora.py

Browse files
Files changed (1) hide show
  1. train_lora.py +69 -0
train_lora.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ from accelerate import Accelerator
4
+ from diffusers import StableDiffusionPipeline, UNet2DConditionModel
5
+ from peft import LoraConfig, get_peft_model
6
+ from transformers import AutoTokenizer, AutoModel
7
+
8
+ def main(args):
9
+ accelerator = Accelerator()
10
+
11
+ # Carrega modelo base
12
+ pipeline = StableDiffusionPipeline.from_pretrained(
13
+ args.model_name,
14
+ revision="fp16" if args.mixed_precision else None,
15
+ torch_dtype=torch.float16 if args.mixed_precision else None
16
+ )
17
+
18
+ # Configura LoRA
19
+ unet = pipeline.unet
20
+ lora_config = LoraConfig(
21
+ r=args.lora_rank,
22
+ lora_alpha=args.lora_alpha,
23
+ target_modules=["to_q", "to_v"],
24
+ lora_dropout=0.0,
25
+ bias="none"
26
+ )
27
+ unet = get_peft_model(unet, lora_config)
28
+
29
+ # Prepara dados
30
+ from datasets import load_dataset
31
+ dataset = load_dataset("imagefolder", data_dir=args.dataset_dir, split="train")
32
+
33
+ # Treinamento
34
+ optimizer = torch.optim.AdamW(unet.parameters(), lr=args.learning_rate)
35
+ train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size)
36
+
37
+ unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
38
+
39
+ for epoch in range(args.num_epochs):
40
+ for step, batch in enumerate(train_dataloader):
41
+ # Lógica de treinamento simplificada (para demonstração)
42
+ loss = unet(batch["pixel_values"]).sample.mean()
43
+ accelerator.backward(loss)
44
+ optimizer.step()
45
+ optimizer.zero_grad()
46
+
47
+ # Salva modelo
48
+ unet.save_pretrained(args.output_dir)
49
+ if args.push_to_hub:
50
+ from huggingface_hub import upload_folder
51
+ upload_folder(
52
+ repo_id=args.hub_model_id,
53
+ folder_path=args.output_dir,
54
+ commit_message=f"LoRA fine-tuning epoch {epoch}"
55
+ )
56
+
57
+ if __name__ == "__main__":
58
+ parser = argparse.ArgumentParser()
59
+ parser.add_argument("--model_name", type=str, default="runwayml/stable-diffusion-v1-5")
60
+ parser.add_argument("--dataset_dir", type=str, required=True)
61
+ parser.add_argument("--output_dir", type=str, default="lora-output")
62
+ parser.add_argument("--lora_rank", type=int, default=4)
63
+ parser.add_argument("--learning_rate", type=float, default=1e-4)
64
+ parser.add_argument("--num_epochs", type=int, default=10)
65
+ parser.add_argument("--batch_size", type=int, default=4)
66
+ parser.add_argument("--push_to_hub", action="store_true")
67
+ parser.add_argument("--hub_model_id", type=str, default="my-lora-model")
68
+ args = parser.parse_args()
69
+ main(args)