Allex21 committed
Commit 525da55 · verified · 1 Parent(s): a01e5d2

Upload 2 files

Files changed (2)
  1. app_1.py +242 -0
  2. requirements.txt +13 -0
app_1.py ADDED
@@ -0,0 +1,242 @@
+
+ import gradio as gr
+ import os
+ import shutil
+ import zipfile
+
+ import torch
+ from accelerate import Accelerator
+ from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
+ from diffusers.optimization import get_scheduler
+ from peft import LoraConfig
+ from PIL import Image
+ from safetensors.torch import save_file
+ from torch.utils.data import Dataset
+ from torchvision import transforms
+ from transformers import CLIPTextModel, CLIPTokenizer
+
+ # LoRA training script
+ def train_lora(
+     instance_data_dir: str,
+     output_dir: str,
+     resolution: int = 512,
+     learning_rate: float = 1e-4,
+     batch_size: int = 1,
+     num_epochs: int = 1,
+     train_prompt: str = "a photo of sks dog",
+     pretrained_model_name_or_path: str = "runwayml/stable-diffusion-v1-5",
+ ):
+     # Basic configuration
+     accelerator = Accelerator(
+         gradient_accumulation_steps=1,
+         mixed_precision="fp16",
+     )
+
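+     # Assumption: fp16 mixed precision presumes a CUDA GPU runtime; on a
+     # CPU-only Space, Accelerator would need mixed_precision="no".
+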
+     # Load the tokenizer and the base model components
+     tokenizer = CLIPTokenizer.from_pretrained(
+         pretrained_model_name_or_path, subfolder="tokenizer"
+     )
+     text_encoder = CLIPTextModel.from_pretrained(
+         pretrained_model_name_or_path, subfolder="text_encoder"
+     )
+     vae = AutoencoderKL.from_pretrained(
+         pretrained_model_name_or_path, subfolder="vae"
+     )
+     unet = UNet2DConditionModel.from_pretrained(
+         pretrained_model_name_or_path, subfolder="unet"
+     )
+     # Load the noise scheduler once, with the model's own beta schedule,
+     # rather than constructing a fresh DDPMScheduler on every training step
+     noise_scheduler = DDPMScheduler.from_pretrained(
+         pretrained_model_name_or_path, subfolder="scheduler"
+     )
+
+     # Freeze the VAE and text encoder parameters
+     vae.requires_grad_(False)
+     text_encoder.requires_grad_(False)
+     # They are not passed through accelerator.prepare, so move them manually
+     vae.to(accelerator.device)
+     text_encoder.to(accelerator.device)
+
+     # Configure LoRA: unet.add_adapter expects a peft LoraConfig (passing the
+     # LoRAAttnProcessor class is not a valid argument). add_adapter attaches
+     # the LoRA modules to the UNet attention layers and makes only them
+     # trainable (r and lora_alpha below are assumed small defaults).
+     unet.add_adapter(
+         LoraConfig(
+             r=4,
+             lora_alpha=4,
+             init_lora_weights="gaussian",
+             target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+         )
+     )
+
+     # Optimizer: only the LoRA parameters require gradients after add_adapter,
+     # so filtering on requires_grad selects exactly the adapter weights
+     lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))
+
+     optimizer = torch.optim.AdamW(
+         lora_parameters,
+         lr=learning_rate,
+     )
+
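+     # Optional sanity check (illustrative, not required by the app): confirm
+     # that only the adapter weights are trainable.
+     # trainable = sum(p.numel() for p in unet.parameters() if p.requires_grad)
+     # total = sum(p.numel() for p in unet.parameters())
+     # print(f"Training {trainable} of {total} UNet parameters")
+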
+     # Learning-rate schedule: "constant" ignores the step counts, but
+     # get_scheduler still takes them; os.listdir approximates one step per file
+     lr_scheduler = get_scheduler(
+         "constant",
+         optimizer=optimizer,
+         num_warmup_steps=0,
+         num_training_steps=num_epochs * len(os.listdir(instance_data_dir)),
+     )
+
+     # Dataset and DataLoader (simplified for this example)
+     class DreamBoothDataset(Dataset):
+         def __init__(self, instance_data_root, tokenizer, size=512, train_prompt="a photo of sks dog"):
+             self.instance_data_root = instance_data_root
+             self.tokenizer = tokenizer
+             self.size = size
+             self.train_prompt = train_prompt
+             self.instance_images_path = [
+                 os.path.join(instance_data_root, file_path)
+                 for file_path in os.listdir(instance_data_root)
+                 if file_path.endswith((".png", ".jpg", ".jpeg"))
+             ]
+             self.transform = transforms.Compose(
+                 [
+                     transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+                     transforms.CenterCrop(size),
+                     transforms.ToTensor(),
+                     transforms.Normalize([0.5], [0.5]),
+                 ]
+             )
+
+         def __len__(self):
+             return len(self.instance_images_path)
+
+         def __getitem__(self, index):
+             instance_image = Image.open(self.instance_images_path[index])
+             if instance_image.mode != "RGB":
+                 instance_image = instance_image.convert("RGB")
+             example = {}
+             example["instance_images"] = self.transform(instance_image)
+             example["instance_prompt_ids"] = self.tokenizer(
+                 self.train_prompt,
+                 truncation=True,
+                 padding="max_length",
+                 max_length=self.tokenizer.model_max_length,
+                 return_tensors="pt",
+             ).input_ids[0]
+             return example
+
+     train_dataset = DreamBoothDataset(instance_data_dir, tokenizer, resolution, train_prompt)
+     train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
+
+     # Prepare for training with Accelerator (device placement, fp16 autocast)
+     unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+         unet, optimizer, train_dataloader, lr_scheduler
+     )
+
+     # Training loop
+     for epoch in range(num_epochs):
+         unet.train()
+         for step, batch in enumerate(train_dataloader):
+             with accelerator.accumulate(unet):
+                 # Forward pass: encode the images into the VAE latent space
+                 latents = vae.encode(batch["instance_images"]).latent_dist.sample()
+                 latents = latents * vae.config.scaling_factor
+
+                 # Sample noise and a random timestep for each image in the
+                 # batch (latents.shape[0] handles a final partial batch)
+                 noise = torch.randn_like(latents)
+                 timesteps = torch.randint(
+                     0,
+                     noise_scheduler.config.num_train_timesteps,
+                     (latents.shape[0],),
+                     device=latents.device,
+                 ).long()
+
+                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+                 encoder_hidden_states = text_encoder(batch["instance_prompt_ids"])[0]
+
+                 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+                 # Compute the loss against the noise the scheduler added
+                 loss = torch.nn.functional.mse_loss(model_pred.float(), noise.float(), reduction="mean")
+
+                 # Backward pass
+                 accelerator.backward(loss)
+                 optimizer.step()
+                 lr_scheduler.step()
+                 optimizer.zero_grad()
+
+             # accelerator.log is a no-op unless a tracker is configured
+             accelerator.log({"loss": loss.item()}, step=epoch * len(train_dataloader) + step)
+             print(f"Epoch {epoch}, Step {step}, Loss: {loss.item()}")
+
+     # Save the trained model: keep only the LoRA weights
+     unwrapped_unet = accelerator.unwrap_model(unet)
+     lora_state_dict = {}
+     for name, param in unwrapped_unet.named_parameters():
+         if "lora" in name:
+             # Detach and move to CPU so safetensors receives plain tensors
+             lora_state_dict[name] = param.detach().cpu()
+
+     lora_path = os.path.join(output_dir, "lora_model.safetensors")
+     # Use safetensors to serialize the model
+     save_file(lora_state_dict, lora_path)
+
+     return lora_path
+
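+ # Note: the keys saved above follow peft's internal naming, which
+ # pipeline.load_lora_weights does not consume directly. A sketch of the
+ # diffusers-native alternative, assuming the helpers from diffusers' LoRA
+ # training examples (not wired into this app), run inside train_lora:
+ #
+ #     from diffusers import StableDiffusionPipeline
+ #     from diffusers.utils import convert_state_dict_to_diffusers
+ #     from peft.utils import get_peft_model_state_dict
+ #
+ #     lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unwrapped_unet))
+ #     StableDiffusionPipeline.save_lora_weights(output_dir, unet_lora_layers=lora_layers)
+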
+ def run_training(
+     dataset_zip_file,
+     resolution,
+     learning_rate,
+     batch_size,
+     num_epochs,
+     train_prompt,
+ ):
+     if dataset_zip_file is None:
+         return "Please upload a ZIP file containing your dataset.", None
+
+     # Clear out directories from previous runs
+     if os.path.exists("./data/dataset"):
+         shutil.rmtree("./data/dataset")
+     if os.path.exists("./outputs"):
+         shutil.rmtree("./outputs")
+     os.makedirs("./data/dataset", exist_ok=True)
+     os.makedirs("./outputs", exist_ok=True)
+
+     # Save and extract the dataset
+     dataset_dir = "./data/dataset"
+     # Gradio's file object exposes .name, the path of the temporary upload
+     zip_path = dataset_zip_file.name
+
+     with zipfile.ZipFile(zip_path, "r") as zip_ref:
+         zip_ref.extractall(dataset_dir)
+
+     # Start training (int() guards against sliders delivering float values)
+     output_dir = "./outputs"
+
+     lora_model_path = train_lora(
+         instance_data_dir=dataset_dir,
+         output_dir=output_dir,
+         resolution=int(resolution),
+         learning_rate=learning_rate,
+         batch_size=int(batch_size),
+         num_epochs=int(num_epochs),
+         train_prompt=train_prompt,
+     )
+
+     return f"Training finished! Model saved to: {lora_model_path}", lora_model_path
+
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# LoRA Trainer for Hugging Face Spaces")
+
+     with gr.Row():
+         with gr.Column():
+             dataset_zip = gr.File(label="Dataset upload (ZIP)", file_types=[".zip"])
+             resolution = gr.Slider(minimum=128, maximum=1024, value=512, step=128, label="Image resolution")
+             learning_rate = gr.Number(value=1e-4, label="Learning rate")
+             batch_size = gr.Slider(minimum=1, maximum=8, value=1, step=1, label="Batch size")
+             num_epochs = gr.Slider(minimum=1, maximum=100, value=10, step=1, label="Number of epochs")
+             train_prompt = gr.Textbox(label="Training prompt (e.g. a photo of sks dog)", value="a photo of sks dog")
+             train_button = gr.Button("Start Training")
+
+         with gr.Column():
+             output_text = gr.Textbox(label="Training status")
+             output_file = gr.File(label="Trained LoRA model")
+
+     train_button.click(
+         run_training,
+         inputs=[
+             dataset_zip,
+             resolution,
+             learning_rate,
+             batch_size,
+             num_epochs,
+             train_prompt,
+         ],
+         outputs=[output_text, output_file],
+     )
+
+
+ if __name__ == "__main__":
+     demo.launch(debug=True)
requirements.txt ADDED
@@ -0,0 +1,13 @@
+ torch
+ torchvision
+ accelerate
+ transformers
+ diffusers
+ peft
+ safetensors
+ xformers
+ gradio
+ Pillow
+ datasets