Spaces:
Sleeping
Sleeping
update LatentNoiseTrainer class to include iteration callbacks
Browse files- training/trainer.py +4 -1
training/trainer.py
CHANGED
|
@@ -51,6 +51,7 @@ class LatentNoiseTrainer:
|
|
| 51 |
prompt: str,
|
| 52 |
optimizer: torch.optim.Optimizer,
|
| 53 |
save_dir: Optional[str] = None,
|
|
|
|
| 54 |
) -> Tuple[PIL.Image.Image, Dict[str, float], Dict[str, float]]:
|
| 55 |
logging.info(f"Optimizing latents for prompt '{prompt}'.")
|
| 56 |
best_loss = torch.inf
|
|
@@ -120,6 +121,8 @@ class LatentNoiseTrainer:
|
|
| 120 |
image_numpy = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
|
| 121 |
image_pil = DiffusionPipeline.numpy_to_pil(image_numpy)[0]
|
| 122 |
image_pil.save(f"{save_dir}/{iteration}.png")
|
|
|
|
|
|
|
| 123 |
image_numpy = best_image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
|
| 124 |
image_pil = DiffusionPipeline.numpy_to_pil(image_numpy)[0]
|
| 125 |
-
return image_pil, initial_rewards, best_rewards
|
|
|
|
| 51 |
prompt: str,
|
| 52 |
optimizer: torch.optim.Optimizer,
|
| 53 |
save_dir: Optional[str] = None,
|
| 54 |
+
progress_callback=None,
|
| 55 |
) -> Tuple[PIL.Image.Image, Dict[str, float], Dict[str, float]]:
|
| 56 |
logging.info(f"Optimizing latents for prompt '{prompt}'.")
|
| 57 |
best_loss = torch.inf
|
|
|
|
| 121 |
image_numpy = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
|
| 122 |
image_pil = DiffusionPipeline.numpy_to_pil(image_numpy)[0]
|
| 123 |
image_pil.save(f"{save_dir}/{iteration}.png")
|
| 124 |
+
if progress_callback:
|
| 125 |
+
progress_callback(iteration + 1)
|
| 126 |
image_numpy = best_image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
|
| 127 |
image_pil = DiffusionPipeline.numpy_to_pil(image_numpy)[0]
|
| 128 |
+
return image_pil, initial_rewards, best_rewards
|