Update app.py
app.py CHANGED
@@ -35,14 +35,20 @@ class LoRAImageTrainer:
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.training_jobs = {}
         self.models_cache = {}
-
-
+        # ✅ Make sure the folders exist in the current directory
+        os.makedirs("./lora_models", exist_ok=True)
+        os.makedirs("./jobs", exist_ok=True)
+        logger.info("Folders ./lora_models and ./jobs created successfully.")
 
     def _save_job_state(self, job_id: str):
         """Saves the job state to disk."""
         job_file = Path(f"./jobs/{job_id}.json")
-
-
+        try:
+            with open(job_file, "w") as f:
+                json.dump(self.training_jobs[job_id], f, indent=2, default=str)
+            logger.info(f"Job {job_id} state saved to disk.")
+        except Exception as e:
+            logger.error(f"Error saving job {job_id}: {e}")
 
     def _load_job_state(self, job_id: str) -> Optional[Dict]:
         """Loads the job state from disk."""
@@ -50,9 +56,13 @@ class LoRAImageTrainer:
         if job_file.exists():
             try:
                 with open(job_file, "r") as f:
-
+                    loaded_data = json.load(f)
+                    logger.info(f"Job {job_id} state loaded from disk.")
+                    return loaded_data
             except Exception as e:
                 logger.error(f"Error loading job {job_id}: {e}")
+        else:
+            logger.warning(f"Job {job_id} file not found on disk.")
         return None
 
     def get_training_status(self, job_id: str) -> Dict[str, Any]:
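The _save_job_state/_load_job_state pair above is a plain JSON round-trip over the ./jobs layout. A minimal standalone sketch of the same pattern; save_state and load_state are hypothetical names, not part of the commit:

import json
from pathlib import Path
from typing import Optional

def save_state(job_id: str, state: dict) -> None:
    path = Path("./jobs") / f"{job_id}.json"
    path.parent.mkdir(parents=True, exist_ok=True)
    # default=str stringifies values json cannot encode natively (e.g. datetime)
    path.write_text(json.dumps(state, indent=2, default=str))

def load_state(job_id: str) -> Optional[dict]:
    path = Path("./jobs") / f"{job_id}.json"
    return json.loads(path.read_text()) if path.exists() else None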
@@ -177,8 +187,10 @@ class LoRAImageTrainer:
 
             self.training_jobs[job_id]["status"] = "loading_model"
             self.training_jobs[job_id]["progress"] = 5
-
+            log_msg = f"{datetime.now().strftime('%H:%M:%S')} - Loading base model: {model_name}"
+            self.training_jobs[job_id]["logs"].append(log_msg)
             self._save_job_state(job_id)
+            logger.info(log_msg)
 
             pipeline = self.load_base_model(model_name)
             unet = pipeline.unet
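This hunk introduces the pattern that recurs at every status change below: build a timestamped message, append it to the job's logs, persist the state, and mirror the message to the logger. A hypothetical helper that collapses the four steps; log_event and the trainer parameter (standing in for the LoRAImageTrainer instance) are illustrative, not part of the commit:

import logging
from datetime import datetime

logger = logging.getLogger(__name__)

def log_event(trainer, job_id: str, message: str) -> None:
    msg = f"{datetime.now().strftime('%H:%M:%S')} - {message}"
    trainer.training_jobs[job_id]["logs"].append(msg)  # visible in the UI
    trainer._save_job_state(job_id)                    # survives restarts
    logger.info(msg)                                   # mirrored to the console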
@@ -191,9 +203,10 @@ class LoRAImageTrainer:
             text_encoder.requires_grad_(False)
             vae.requires_grad_(False)
 
-            # ✅
+            # ✅ Remove any existing adapter
             if hasattr(unet, "peft_config") and "default" in unet.peft_config:
                 unet.delete_adapter("default")
+                logger.info("Adapter 'default' removed successfully.")
 
             lora_config = LoraConfig(
                 r=r,
@@ -206,16 +219,19 @@ class LoRAImageTrainer:
             unet.add_adapter(lora_config, adapter_name="default")
             unet.set_adapter("default")
 
-            # ✅
+            # ✅ Enable gradients only for the LoRA parameters
             unet.requires_grad_(False)
+            trainable_params = 0
             for name, param in unet.named_parameters():
                 if "lora_" in name:
                     param.requires_grad = True
+                    trainable_params += 1
+
+            logger.info(f"Number of trainable (LoRA) parameters: {trainable_params}")
 
             unet.train()
             unet.to(self.device)
 
-            # Optimizer only on the parameters that require gradients
             optimizer = torch.optim.AdamW([p for p in unet.parameters() if p.requires_grad], lr=learning_rate)
 
             self.training_jobs[job_id]["status"] = "preparing_data"
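The freeze-everything-then-unfreeze-LoRA idiom above relies on peft naming injected adapter weights with a lora_ prefix. A self-contained toy sketch of the same idiom; ToyLoRALinear is a stand-in for the real UNet, not the commit's code:

import torch
import torch.nn as nn

class ToyLoRALinear(nn.Module):
    """Stand-in for a peft-wrapped layer: frozen base plus low-rank adapters."""
    def __init__(self, dim: int = 8, rank: int = 2):
        super().__init__()
        self.base = nn.Linear(dim, dim)
        self.lora_A = nn.Linear(dim, rank, bias=False)
        self.lora_B = nn.Linear(rank, dim, bias=False)

    def forward(self, x):
        return self.base(x) + self.lora_B(self.lora_A(x))

model = ToyLoRALinear()
model.requires_grad_(False)                  # freeze everything first
trainable_params = 0
for name, param in model.named_parameters():
    if "lora_" in name:                      # re-enable only adapter weights
        param.requires_grad = True
        trainable_params += 1
print(trainable_params)                      # 2 (lora_A.weight, lora_B.weight)

# The optimizer then only ever sees the LoRA tensors.
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad], lr=1e-4
)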
@@ -232,8 +248,10 @@ class LoRAImageTrainer:
             current_step = 0
 
             self.training_jobs[job_id]["status"] = "training"
-
+            log_msg = f"{datetime.now().strftime('%H:%M:%S')} - Starting actual training..."
+            self.training_jobs[job_id]["logs"].append(log_msg)
             self._save_job_state(job_id)
+            logger.info(log_msg)
 
             for epoch in range(num_epochs):
                 for item in dataset:
@@ -271,8 +289,9 @@ class LoRAImageTrainer:
 
                     if current_step % max(1, len(dataset)//2) == 0:
                         log_msg = f"Epoch {epoch+1}, Step {current_step} - Loss: {loss.item():.4f}"
-                        self.training_jobs[job_id]["logs"].append(
+                        self.training_jobs[job_id]["logs"].append(log_msg)
                         self._save_job_state(job_id)
+                        logger.info(log_msg)
 
             self.training_jobs[job_id]["status"] = "saving"
             self.training_jobs[job_id]["progress"] = 95
@@ -281,7 +300,6 @@ class LoRAImageTrainer:
             output_dir = f"./lora_models/{job_id}"
             os.makedirs(output_dir, exist_ok=True)
 
-            # ✅ SAVE ONLY THE LORA
             unet.save_pretrained(
                 output_dir,
                 safe_serialization=True,
@@ -319,10 +337,10 @@ Data: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
             self.training_jobs[job_id]["progress"] = 100
             self.training_jobs[job_id]["model_path"] = output_dir
             self.training_jobs[job_id]["completed_at"] = datetime.now().isoformat()
-
+            log_msg = f"{datetime.now().strftime('%H:%M:%S')} - ✅ Training finished! LoRA saved to {output_dir}"
+            self.training_jobs[job_id]["logs"].append(log_msg)
             self._save_job_state(job_id)
-
-            logger.info(f"LoRA training finished for job {job_id}")
+            logger.info(log_msg)
 
         except Exception as e:
             error_msg = f"Training error: {str(e)}"
@@ -386,7 +404,8 @@ Data: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
                     "r": config.get("r", "Unknown"),
                     "created": datetime.fromtimestamp(model_dir.stat().st_mtime).isoformat()
                 })
-            except:
+            except Exception as e:
+                logger.error(f"Error reading config for {model_dir.name}: {e}")
                 models.append({
                     "id": model_dir.name,
                     "path": str(model_dir),
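Swapping the bare except: for except Exception as e does more than capture the error for logging: a bare except also traps BaseException subclasses such as SystemExit and KeyboardInterrupt, hiding Ctrl+C and sys.exit(). A quick demonstration:

import sys

try:
    sys.exit(1)
except Exception:
    print("not reached")        # SystemExit does not derive from Exception
except BaseException as e:
    print(type(e).__name__)     # prints: SystemExit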
@@ -519,6 +538,11 @@ def create_gradio_interface():
     return interface
 
 if __name__ == "__main__":
+    # ✅ Make sure the directories exist
+    os.makedirs("./lora_models", exist_ok=True)
+    os.makedirs("./jobs", exist_ok=True)
+    logger.info("Application started. Directories verified.")
+
     interface = create_gradio_interface()
     interface.launch(
         server_name="0.0.0.0",
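For reference, the same idempotent bootstrap (run both in the trainer's constructor and at startup) can be written with pathlib; a minimal sketch:

from pathlib import Path

for d in ("./lora_models", "./jobs"):
    Path(d).mkdir(parents=True, exist_ok=True)  # no-op when already present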