Allex21 commited on
Commit
8da6913
·
verified ·
1 Parent(s): 665cbcf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -37
app.py CHANGED
@@ -21,7 +21,7 @@ from diffusers import (
21
  AutoencoderKL
22
  )
23
  from transformers import CLIPTextModel, CLIPTokenizer
24
- from peft import LoraConfig, get_peft_model, TaskType
25
  import logging
26
 
27
  # Configurar logging
@@ -86,30 +86,6 @@ class LoRAImageTrainer:
86
  logger.error(f"Erro ao carregar modelo {model_name}: {str(e)}")
87
  raise e
88
 
89
- def create_lora_config(self,
90
- r: int = 16,
91
- lora_alpha: int = 32,
92
- lora_dropout: float = 0.1,
93
- target_modules: Optional[List[str]] = None) -> LoraConfig:
94
- """Cria configuração LoRA otimizada para modelos de difusão."""
95
-
96
- if target_modules is None:
97
- # Módulos padrão para UNet do Stable Diffusion
98
- target_modules = [
99
- "to_k", "to_q", "to_v", "to_out.0",
100
- "proj_in", "proj_out",
101
- "ff.net.0.proj", "ff.net.2"
102
- ]
103
-
104
- return LoraConfig(
105
- r=r,
106
- lora_alpha=lora_alpha,
107
- target_modules=target_modules,
108
- lora_dropout=lora_dropout,
109
- bias="none",
110
- task_type=TaskType.CAUSAL_LM, # ✅ CORREÇÃO PRINCIPAL: DIFFUSION → CAUSAL_LM
111
- )
112
-
113
  def prepare_image_dataset(self, image_files: List[str], captions: List[str], resolution: int = 512) -> List[Dict]:
114
  """Prepara dataset de imagens para treinamento."""
115
  dataset = []
@@ -171,7 +147,7 @@ class LoRAImageTrainer:
171
  learning_rate: float = 1e-4,
172
  batch_size: int = 1,
173
  resolution: int = 512) -> None:
174
- """TREINAMENTO REAL DE LoRA PARA IMAGENS."""
175
 
176
  try:
177
  # Atualizar status
@@ -192,14 +168,25 @@ class LoRAImageTrainer:
192
  text_encoder.requires_grad_(False)
193
  vae.requires_grad_(False)
194
 
195
- # Configurar LoRA no UNet
196
- lora_config = self.create_lora_config(r, lora_alpha, lora_dropout)
197
- unet_lora = get_peft_model(unet, lora_config)
198
- unet_lora.train()
199
- unet_lora.to(self.device)
 
 
 
 
 
 
 
 
 
 
 
200
 
201
  # Otimizador
202
- optimizer = torch.optim.AdamW(unet_lora.parameters(), lr=learning_rate)
203
 
204
  # Preparar scheduler para treinamento
205
  self.training_jobs[job_id]["status"] = "preparing_data"
@@ -249,7 +236,7 @@ class LoRAImageTrainer:
249
 
250
  # Forward pass
251
  encoder_hidden_states = text_encoder(input_ids)[0]
252
- noise_pred = unet_lora(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
253
 
254
  # Calcular perda
255
  loss = torch.nn.functional.mse_loss(noise_pred, noise)
@@ -274,17 +261,17 @@ class LoRAImageTrainer:
274
  output_dir = f"./lora_models/{job_id}"
275
  os.makedirs(output_dir, exist_ok=True)
276
 
277
- # Salvar apenas os pesos LoRA do UNet
278
- unet_lora.save_pretrained(output_dir)
279
 
280
  # Criar adapter_config.json
281
  lora_config_dict = {
282
  "r": r,
283
  "lora_alpha": lora_alpha,
284
- "target_modules": lora_config.target_modules,
285
  "lora_dropout": lora_dropout,
286
  "bias": "none",
287
- "task_type": "CAUSAL_LM", # CORREÇÃO AQUI TAMBÉM: DIFFUSION CAUSAL_LM
288
  "base_model_name": model_name,
289
  "training_info": {
290
  "num_epochs": num_epochs,
 
21
  AutoencoderKL
22
  )
23
  from transformers import CLIPTextModel, CLIPTokenizer
24
+ from peft import LoraConfig
25
  import logging
26
 
27
  # Configurar logging
 
86
  logger.error(f"Erro ao carregar modelo {model_name}: {str(e)}")
87
  raise e
88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  def prepare_image_dataset(self, image_files: List[str], captions: List[str], resolution: int = 512) -> List[Dict]:
90
  """Prepara dataset de imagens para treinamento."""
91
  dataset = []
 
147
  learning_rate: float = 1e-4,
148
  batch_size: int = 1,
149
  resolution: int = 512) -> None:
150
+ """TREINAMENTO REAL DE LoRA PARA IMAGENS - CORRIGIDO PARA DIFFUSERS + PEFT."""
151
 
152
  try:
153
  # Atualizar status
 
168
  text_encoder.requires_grad_(False)
169
  vae.requires_grad_(False)
170
 
171
+ # Criar configuração LoRA
172
+ lora_config = LoraConfig(
173
+ r=r,
174
+ lora_alpha=lora_alpha,
175
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
176
+ lora_dropout=lora_dropout,
177
+ bias="none"
178
+ )
179
+
180
+ # Aplicar LoRA ao UNet manualmente, sem usar get_peft_model diretamente
181
+ unet.add_adapter(lora_config, adapter_name="default")
182
+
183
+ # Ativar o adaptador
184
+ unet.set_adapter("default")
185
+ unet.train()
186
+ unet.to(self.device)
187
 
188
  # Otimizador
189
+ optimizer = torch.optim.AdamW(unet.parameters(), lr=learning_rate)
190
 
191
  # Preparar scheduler para treinamento
192
  self.training_jobs[job_id]["status"] = "preparing_data"
 
236
 
237
  # Forward pass
238
  encoder_hidden_states = text_encoder(input_ids)[0]
239
+ noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
240
 
241
  # Calcular perda
242
  loss = torch.nn.functional.mse_loss(noise_pred, noise)
 
261
  output_dir = f"./lora_models/{job_id}"
262
  os.makedirs(output_dir, exist_ok=True)
263
 
264
+ # Salvar apenas os adaptadores LoRA
265
+ unet.save_pretrained(output_dir)
266
 
267
  # Criar adapter_config.json
268
  lora_config_dict = {
269
  "r": r,
270
  "lora_alpha": lora_alpha,
271
+ "target_modules": ["to_k", "to_q", "to_v", "to_out.0"],
272
  "lora_dropout": lora_dropout,
273
  "bias": "none",
274
+ "task_type": "CAUSAL_LM", # Mantido por compatibilidade, mas não é usado
275
  "base_model_name": model_name,
276
  "training_info": {
277
  "num_epochs": num_epochs,