Allex21 committed
Commit 3d74922 · verified · 1 Parent(s): e9330ee

Update app.py

Files changed (1): app.py +148 -130
app.py CHANGED
@@ -1,24 +1,67 @@
-
 import gradio as gr
 import os
 import torch
 from accelerate import Accelerator
-from accelerate.utils import set_seed
-from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, StableDiffusionPipeline
+from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import EMAModel
-from diffusers.models.attention_processor import LoRAAttnProcessor as DiffusersLoRAAttnProcessor
-from huggingface_hub import create_repo, upload_folder
 from PIL import Image
 from torch.utils.data import Dataset
 from torchvision import transforms
-from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 import zipfile
 import shutil
 from safetensors.torch import save_file
-
-# Placeholder for the training script
+import torch.nn as nn
+
+# Function to create LoRA layers
+def create_lora_layers(module, rank=4):
+    if isinstance(module, nn.Linear):
+        lora_down = nn.Linear(module.in_features, rank, bias=False)
+        lora_up = nn.Linear(rank, module.out_features, bias=False)
+        nn.init.zeros_(lora_up.weight)  # zero initialization so the adapter starts neutral
+        return lora_down, lora_up
+    return None, None
+
+# Simplified dataset
+class DreamBoothDataset(Dataset):
+    def __init__(self, instance_data_root, tokenizer, size=512, train_prompt="a photo of sks dog"):
+        self.instance_data_root = instance_data_root
+        self.tokenizer = tokenizer
+        self.size = size
+        self.train_prompt = train_prompt
+        self.instance_images_path = [
+            os.path.join(instance_data_root, file_path)
+            for file_path in os.listdir(instance_data_root)
+            if file_path.endswith((".png", ".jpg", ".jpeg"))
+        ]
+        self.transform = transforms.Compose(
+            [
+                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+                transforms.CenterCrop(size),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5], [0.5]),
+            ]
+        )
+
+    def __len__(self):
+        return len(self.instance_images_path)
+
+    def __getitem__(self, index):
+        instance_image = Image.open(self.instance_images_path[index])
+        if not instance_image.mode == "RGB":
+            instance_image = instance_image.convert("RGB")
+        example = {}
+        example["instance_images"] = self.transform(instance_image)
+        example["instance_prompt_ids"] = self.tokenizer(
+            self.train_prompt,
+            truncation=True,
+            padding="max_length",
+            max_length=self.tokenizer.model_max_length,
+            return_tensors="pt",
+        ).input_ids[0]
+        return example
+
+# Main training function
 def train_lora(
     instance_data_dir: str,
     output_dir: str,
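The zero initialization in create_lora_layers is what makes the adapters safe to bolt onto a pretrained model: with lora_up zeroed, lora_up(lora_down(x)) is identically zero, so the patched network reproduces the base model exactly until training updates the weights. A minimal standalone check of that property (the nn.Linear below is a stand-in for a real UNet projection, not part of the script):

import torch
import torch.nn as nn

# Same construction as create_lora_layers above, on a dummy 320-wide projection.
base = nn.Linear(320, 320)
rank = 4
lora_down = nn.Linear(320, rank, bias=False)
lora_up = nn.Linear(rank, 320, bias=False)
nn.init.zeros_(lora_up.weight)  # zero init: the adapter starts as a no-op

x = torch.randn(2, 320)
delta = lora_up(lora_down(x))
print(torch.count_nonzero(delta).item())      # 0
print(torch.equal(base(x), base(x) + delta))  # True: base behaviour untouched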
@@ -35,40 +78,57 @@ def train_lora(
         mixed_precision="fp16",
     )
 
-    # Load the tokenizer and base model
-    tokenizer = CLIPTokenizer.from_pretrained(
-        pretrained_model_name_or_path, subfolder="tokenizer"
-    )
-    text_encoder = CLIPTextModel.from_pretrained(
-        pretrained_model_name_or_path, subfolder="text_encoder"
-    )
-    vae = AutoencoderKL.from_pretrained(
-        pretrained_model_name_or_path, subfolder="vae"
-    )
-    unet = UNet2DConditionModel.from_pretrained(
-        pretrained_model_name_or_path, subfolder="unet"
-    )
+    # Load models
+    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
+    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder")
+    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
+    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet")
 
-    # Freeze the VAE and text encoder parameters
+    # Freeze VAE and text encoder
     vae.requires_grad_(False)
     text_encoder.requires_grad_(False)
-
-    # Set up LoRA
-    # Add LoRA adapters to the UNet
-    # diffusers' `add_adapter` already configures the LoRA modules and makes them trainable.
-    unet.add_adapter(DiffusersLoRAAttnProcessor)
+    unet.requires_grad_(False)
+
+    # Inject LoRA into the UNet
+    lora_layers = []
+    for name, module in unet.named_modules():
+        if name.endswith("to_q") or name.endswith("to_k") or name.endswith("to_v") or name.endswith("to_out.0"):
+            lora_down, lora_up = create_lora_layers(module, rank=4)
+            if lora_down is not None:
+                module.lora_down = lora_down.to(module.weight.device)
+                module.lora_up = lora_up.to(module.weight.device)
+                lora_layers.extend([module.lora_down, module.lora_up])
+
+                # Keep the original forward
+                if not hasattr(module, "_original_forward"):
+                    module._original_forward = module.forward
+
+                # New forward with the LoRA path added
+                def forward_with_lora(self, x):
+                    original_output = self._original_forward(x)
+                    lora_output = self.lora_up(self.lora_down(x))
+                    return original_output + lora_output
+
+                # Bind the new forward to the module
+                import types
+                module.forward = types.MethodType(forward_with_lora, module)
+
+    # Unfreeze only the LoRA parameters
+    for layer in lora_layers:
+        layer.requires_grad_(True)
+
+    # Collect the trainable parameters
+    lora_parameters = []
+    for layer in lora_layers:
+        lora_parameters.extend(layer.parameters())
 
     # Optimizer
-    # Only the LoRA parameters should be trainable
-    # `add_adapter` already handles that, so we can simply take the UNet's trainable parameters.
-    lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))
-
-    optimizer = torch.optim.AdamW(
-        lora_parameters,
-        lr=learning_rate,
-    )
+    optimizer = torch.optim.AdamW(lora_parameters, lr=learning_rate)
 
-    # Scheduler
+    # Noise scheduler
+    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
+
+    # Learning-rate scheduler
     lr_scheduler = get_scheduler(
         "constant",
         optimizer=optimizer,
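For an nn.Linear, the patched forward original + lora_up(lora_down(x)) is algebraically the same as running the layer with an updated weight W + W_up @ W_down, i.e. a rank-4 update to the frozen projection. (Note the hunk applies no alpha/rank scaling, which many LoRA implementations add.) A small sketch of that equivalence on a toy layer, with every name local to the example:

import torch
import torch.nn as nn

layer = nn.Linear(64, 64)
layer.lora_down = nn.Linear(64, 4, bias=False)
layer.lora_up = nn.Linear(4, 64, bias=False)  # default init here, so the update is nonzero

x = torch.randn(1, 64)
patched = layer(x) + layer.lora_up(layer.lora_down(x))

# Fold the low-rank update into a dense weight: W' = W + W_up @ W_down
# (shapes: [64, 4] @ [4, 64] -> [64, 64]).
merged = nn.Linear(64, 64)
with torch.no_grad():
    merged.weight.copy_(layer.weight + layer.lora_up.weight @ layer.lora_down.weight)
    merged.bias.copy_(layer.bias)

print(torch.allclose(patched, merged(x), atol=1e-6))  # True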
@@ -76,92 +136,61 @@
         num_training_steps=num_epochs * len(os.listdir(instance_data_dir)),
     )
 
-    # Dataset and DataLoader (simplified for this example)
-    class DreamBoothDataset(Dataset):
-        def __init__(self, instance_data_root, tokenizer, size=512, train_prompt="a photo of sks dog"):
-            self.instance_data_root = instance_data_root
-            self.tokenizer = tokenizer
-            self.size = size
-            self.train_prompt = train_prompt
-            self.instance_images_path = [os.path.join(instance_data_root, file_path) for file_path in os.listdir(instance_data_root) if file_path.endswith((".png", ".jpg", ".jpeg"))]
-            self.transform = transforms.Compose(
-                [
-                    transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
-                    transforms.CenterCrop(size),
-                    transforms.ToTensor(),
-                    transforms.Normalize([0.5], [0.5]),
-                ]
-            )
-
-        def __len__(self):
-            return len(self.instance_images_path)
-
-        def __getitem__(self, index):
-            instance_image = Image.open(self.instance_images_path[index])
-            if not instance_image.mode == "RGB":
-                instance_image = instance_image.convert("RGB")
-            example = {}
-            example["instance_images"] = self.transform(instance_image)
-            example["instance_prompt_ids"] = self.tokenizer(self.train_prompt,
-                truncation=True,
-                padding="max_length",
-                max_length=self.tokenizer.model_max_length,
-                return_tensors="pt",
-            ).input_ids[0]
-            return example
-
+    # Dataset and DataLoader
     train_dataset = DreamBoothDataset(instance_data_dir, tokenizer, resolution, train_prompt)
     train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
 
-    # Prepare for training with Accelerator
+    # Prepare with Accelerator
     unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
         unet, optimizer, train_dataloader, lr_scheduler
     )
 
-    # Training loop
+    # Training
+    global_step = 0
     for epoch in range(num_epochs):
         unet.train()
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
-                # Forward pass
-                latents = vae.encode(batch["instance_images"]).latent_dist.sample()
+                # Prepare the data
+                pixel_values = batch["instance_images"].to(accelerator.device)
+                latents = vae.encode(pixel_values).latent_dist.sample()
                 latents = latents * vae.config.scaling_factor
 
-                noise = torch.randn_like(latents)
-                timesteps = torch.randint(0, 1000, (batch_size,), device=latents.device).long()
+                noise = torch.randn_like(latents).to(accelerator.device)
+                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device).long()
 
-                noisy_latents = DDPMScheduler().add_noise(latents, noise, timesteps)
+                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
-                encoder_hidden_states = text_encoder(batch["instance_prompt_ids"])[0]
+                encoder_hidden_states = text_encoder(batch["instance_prompt_ids"].to(accelerator.device))[0]
 
+                # Prediction
                 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
-                # Compute the loss
+                # Loss
                 loss = torch.nn.functional.mse_loss(model_pred.float(), noise.float(), reduction="mean")
 
-                # Backward pass
+                # Backprop
                 accelerator.backward(loss)
                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad()
 
-            accelerator.log({"loss": loss.item()}, step=epoch * len(train_dataloader) + step)
-            print(f"Epoch {epoch}, Step {step}, Loss: {loss.item()}")
+            global_step += 1
+            print(f"Epoch {epoch + 1}/{num_epochs}, Step {step + 1}, Loss: {loss.item():.6f}")
 
-    # Save the trained model
-    # Save only the LoRA weights
-    lora_state_dict = {}
-    for name, param in unet.named_parameters():
-        if "lora" in name:
-            lora_state_dict[name] = param
+    # Save the LoRA weights
+    lora_state_dict = {}
+    for name, module in unet.named_modules():
+        if hasattr(module, "lora_down") and hasattr(module, "lora_up"):
+            lora_state_dict[f"{name}.lora_down.weight"] = module.lora_down.weight
+            lora_state_dict[f"{name}.lora_up.weight"] = module.lora_up.weight
 
     lora_path = os.path.join(output_dir, "lora_model.safetensors")
-    # Use safetensors to save the model
     save_file(lora_state_dict, lora_path)
 
     return lora_path
 
-
+# Function for Gradio
 def run_training(
     dataset_zip_file,
     resolution,
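The MSE loss trains the UNet to recover the noise that add_noise mixed in: the DDPM forward process has the closed form x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps. A short sketch verifying that identity against the scheduler's alphas_cumprod table (a default DDPMScheduler is assumed here; the training code instead loads its scheduler config from the base model):

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler()  # defaults to 1000 training timesteps
latents = torch.randn(2, 4, 64, 64)  # stand-in for VAE latents
noise = torch.randn_like(latents)
t = torch.tensor([10, 500])

noisy = scheduler.add_noise(latents, noise, t)

# Closed-form forward process, computed by hand from alphas_cumprod.
alpha_bar = scheduler.alphas_cumprod[t].reshape(-1, 1, 1, 1)
expected = alpha_bar.sqrt() * latents + (1 - alpha_bar).sqrt() * noise
print(torch.allclose(noisy, expected, atol=1e-5))  # True

One caveat in the hunk above: vae and text_encoder are neither passed to accelerator.prepare nor moved to accelerator.device, so on a GPU run the vae.encode call would hit a device mismatch unless they are moved explicitly.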
@@ -181,62 +210,51 @@ def run_training(
     os.makedirs("./data/dataset", exist_ok=True)
     os.makedirs("./outputs", exist_ok=True)
 
-    # Save and extract the dataset
+    # Extract the dataset
     dataset_dir = "./data/dataset"
-    # Gradio's dataset_zip_file object has a .name attribute holding the temporary file's path
     zip_path = dataset_zip_file.name
-
     with zipfile.ZipFile(zip_path, 'r') as zip_ref:
         zip_ref.extractall(dataset_dir)
 
-    # Start training
+    # Train
     output_dir = "./outputs"
-
-    lora_model_path = train_lora(
-        instance_data_dir=dataset_dir,
-        output_dir=output_dir,
-        resolution=resolution,
-        learning_rate=learning_rate,
-        batch_size=batch_size,
-        num_epochs=num_epochs,
-        train_prompt=train_prompt,
-    )
-
-    return f"Training complete! Model saved at: {lora_model_path}", lora_model_path
-
-
+    try:
+        lora_model_path = train_lora(
+            instance_data_dir=dataset_dir,
+            output_dir=output_dir,
+            resolution=resolution,
+            learning_rate=learning_rate,
+            batch_size=batch_size,
+            num_epochs=num_epochs,
+            train_prompt=train_prompt,
+        )
+        return f"✅ Training complete! Model saved at: {lora_model_path}", lora_model_path
+    except Exception as e:
+        return f"❌ Error during training: {str(e)}", None
+
+# Gradio interface
 with gr.Blocks() as demo:
-    gr.Markdown("# LoRA Trainer for Hugging Face Spaces")
+    gr.Markdown("# 🧠 LoRA Trainer for Stable Diffusion")
 
     with gr.Row():
         with gr.Column():
-            dataset_zip = gr.File(label="Dataset Upload (ZIP)", file_types=[".zip"])
-            resolution = gr.Slider(minimum=128, maximum=1024, value=512, step=128, label="Image Resolution")
-            learning_rate = gr.Number(value=1e-4, label="Learning Rate")
-            batch_size = gr.Slider(minimum=1, maximum=8, value=1, step=1, label="Batch Size")
-            num_epochs = gr.Slider(minimum=1, maximum=100, value=10, step=1, label="Number of Epochs")
-            train_prompt = gr.Textbox(label="Training Prompt (e.g. a photo of sks dog)", value="a photo of sks dog")
-            train_button = gr.Button("Start Training")
+            dataset_zip = gr.File(label="📁 Dataset Upload (ZIP)", file_types=[".zip"])
+            resolution = gr.Slider(minimum=128, maximum=1024, value=512, step=128, label="📏 Image Resolution")
+            learning_rate = gr.Number(value=1e-4, label="📈 Learning Rate")
+            batch_size = gr.Slider(minimum=1, maximum=8, value=1, step=1, label="📦 Batch Size")
+            num_epochs = gr.Slider(minimum=1, maximum=100, value=10, step=1, label="🔁 Number of Epochs")
+            train_prompt = gr.Textbox(label="📝 Training Prompt (e.g. a photo of sks dog)", value="a photo of sks dog")
+            train_button = gr.Button("🚀 Start Training", variant="primary")
 
         with gr.Column():
-            output_text = gr.Textbox(label="Training Status")
-            output_file = gr.File(label="Trained LoRA Model")
+            output_text = gr.Textbox(label="📊 Training Status", lines=5)
+            output_file = gr.File(label="💾 Trained LoRA Model")
 
     train_button.click(
         run_training,
-        inputs=[
-            dataset_zip,
-            resolution,
-            learning_rate,
-            batch_size,
-            num_epochs,
-            train_prompt,
-        ],
+        inputs=[dataset_zip, resolution, learning_rate, batch_size, num_epochs, train_prompt],
         outputs=[output_text, output_file],
     )
 
-
 if __name__ == "__main__":
-    demo.launch(debug=True)
-
-
+    demo.launch(debug=True)
 
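The saved lora_model.safetensors uses this script's own key layout ({module_name}.lora_down.weight / {module_name}.lora_up.weight), so it cannot be loaded through diffusers' built-in LoRA loaders. A hedged sketch of applying it at inference by folding each rank-4 update into a freshly loaded UNet; the base model id below is a placeholder, since the script's actual pretrained_model_name_or_path default is not shown in this diff:

import torch
from diffusers import UNet2DConditionModel
from safetensors.torch import load_file

# Placeholder id; substitute whatever base model train_lora was run against.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
lora_state_dict = load_file("./outputs/lora_model.safetensors")

modules = dict(unet.named_modules())
with torch.no_grad():
    for key in (k for k in lora_state_dict if k.endswith(".lora_down.weight")):
        name = key[: -len(".lora_down.weight")]
        up = lora_state_dict[f"{name}.lora_up.weight"]      # [out, rank]
        down = lora_state_dict[f"{name}.lora_down.weight"]  # [rank, in]
        # W' = W + W_up @ W_down, the same update the patched forward applies.
        modules[name].weight += up @ down

Right after initialization this merge is a no-op (lora_up is zero-initialized); after training it bakes the learned update into the base weights.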
 