File size: 6,686 Bytes
5686f5b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 | from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING
import torch
try: # Support both modern and legacy Lightning imports
import lightning.pytorch as pl
except Exception: # pragma: no cover
import pytorch_lightning as pl # type: ignore
if TYPE_CHECKING: # pragma: no cover
from torchtyping import TensorType # noqa: F401
from transport_processes import results_dir
from transport_processes.data.base_databatch import TransportProcessBatch
@dataclass
class GeminiVisualizationCheckpoint(pl.Callback):
    """Full-NP-Gemini visualization logic for 1D regression plots.

    At the end of scheduled validation epochs this callback takes the first
    batch of the datamodule's validation loader, samples up to
    ``context_points`` observed points of the first task as context, queries
    ``model.predict_from_context`` on a dense 1D grid, saves the resulting
    plot (mean, +/- 2 sigma band, context points, all observed points) as a
    PNG, and forwards it to every attached logger that supports it.

    Attributes:
        context_points: Maximum number of observed points used as context.
        model_label: Sub-directory (under the plot root) for saved PNGs.
        y_limits: ``(ymin, ymax)`` applied to the plot's y-axis.
    """

    context_points: int = 10
    model_label: str = "gemini"
    y_limits: tuple[float, float] = (-3.0, 3.0)
    # Epoch for which images were last produced; prevents duplicate plots if
    # the hook fires more than once for the same epoch.
    _last_logged_epoch: int | None = field(default=None, init=False, repr=False)

    def _resolve_plot_dir(self, trainer) -> Path:
        """Return (creating it if needed) the root directory for plots.

        Uses ``<trainer.default_root_dir>/training_images`` when the trainer
        exposes a truthy ``default_root_dir``; otherwise falls back to
        ``<results_dir>/training_images``.
        """
        root_dir = getattr(trainer, "default_root_dir", None)
        base = Path(root_dir) if root_dir else Path(results_dir)
        plot_root = base / "training_images"
        plot_root.mkdir(parents=True, exist_ok=True)
        return plot_root

    def _log_figure(self, trainer, fig, image_path: str) -> None:
        """Send one plot to every attached logger that can accept it.

        Experiments exposing ``add_figure`` receive the matplotlib figure;
        otherwise experiments exposing ``log_image`` receive the saved PNG
        path. Loggers without an ``experiment`` are skipped.
        """
        loggers = getattr(trainer, "loggers", None)
        if loggers is None:
            loggers = [getattr(trainer, "logger", None)]
        step = trainer.current_epoch
        for logger in loggers:
            experiment = getattr(logger, "experiment", None) if logger is not None else None
            if experiment is None:
                continue
            if fig is not None and hasattr(experiment, "add_figure"):
                experiment.add_figure("val/predictions", fig, global_step=step)
            elif hasattr(experiment, "log_image"):
                experiment.log_image(image_path, name="val/predictions", step=step)

    def _image_logging_period(self, trainer_cfg, trainer) -> int:
        """Return the image-logging period in epochs (always >= 1).

        The period is ``round(max_epochs * log_images_every_pct)``. ``max_epochs``
        comes from the trainer when it is positive, otherwise from
        ``trainer_cfg``; missing or ``None`` config values count as 0.
        """
        max_epochs = getattr(trainer, "max_epochs", None)
        if max_epochs is None or max_epochs <= 0:
            max_epochs = getattr(trainer_cfg, "max_epochs", None) or 0
        pct = getattr(trainer_cfg, "log_images_every_pct", None) or 0.0
        return max(1, round(max_epochs * pct))

    def _should_log_images(self, trainer_cfg, trainer) -> bool:
        """Return True when image logging is enabled and due this epoch."""
        pct = getattr(trainer_cfg, "log_images_every_pct", None) or 0.0
        if pct <= 0:
            return False
        period = self._image_logging_period(trainer_cfg, trainer)
        return trainer.current_epoch % period == 0

    @staticmethod
    def _first_validation_batch(trainer, device):
        """Return the first validation batch moved to ``device``, or ``None``.

        Best-effort: returns ``None`` when the trainer has no datamodule,
        building the validation loader fails (a warning is emitted), or the
        loader is empty. When the datamodule returns several ``DataLoader``s,
        the first one is used.
        """
        datamodule = getattr(trainer, "datamodule", None)
        if datamodule is None:
            return None
        try:
            val_loader = datamodule.val_dataloader()
        except Exception as err:  # visualization must never abort training
            import warnings

            warnings.warn(
                f"GeminiVisualizationCheckpoint: could not build the validation loader: {err!r}",
                RuntimeWarning,
                stacklevel=2,
            )
            return None
        # Multiple validation sets: plot from the first loader, not the loader object itself.
        if isinstance(val_loader, (list, tuple)) and val_loader and isinstance(
            val_loader[0], torch.utils.data.DataLoader
        ):
            val_loader = val_loader[0]
        try:
            batch = next(iter(val_loader))
        except StopIteration:
            return None
        if hasattr(batch, "to"):
            return batch.to(device)
        if isinstance(batch, (list, tuple)):
            return [t.to(device) if hasattr(t, "to") else t for t in batch]
        return batch

    @staticmethod
    def _close_figure(fig) -> None:
        """Close a matplotlib figure; best-effort cleanup that never raises."""
        if fig is None:
            # plt.close(None) would close the *current* figure, which may be unrelated.
            return
        try:
            import matplotlib.pyplot as plt

            plt.close(fig)
        except Exception:  # cleanup only; never fail the validation epoch
            pass

    def on_validation_epoch_end(self, trainer, pl_module) -> None:
        """Plot predictions for one validation batch on scheduled epochs.

        Skipped during Lightning's sanity-check pass (otherwise the untrained
        model would be plotted as epoch 0 and the dedup guard would suppress
        the real epoch-0 plot) and on non-zero ranks (so distributed workers
        do not all write the same PNG).
        """
        if getattr(trainer, "sanity_checking", False):
            return
        if not getattr(trainer, "is_global_zero", True):
            return
        trainer_cfg = getattr(pl_module, "trainer_cfg", None)
        if trainer_cfg is None:
            return
        if not self._should_log_images(trainer_cfg, trainer):
            return
        epoch = trainer.current_epoch
        if self._last_logged_epoch == epoch:
            return

        batch = self._first_validation_batch(trainer, pl_module.device)
        if batch is None:
            return

        output_root = self._resolve_plot_dir(trainer)
        model = getattr(pl_module, "model", pl_module)
        image_artifacts = self.log_images(model, batch, epoch=epoch, output_root=output_root)
        for path, fig in image_artifacts:
            try:
                self._log_figure(trainer, fig, str(path))
            finally:
                # Always release the figure, even if a logger raised.
                self._close_figure(fig)
        self._last_logged_epoch = epoch

    @staticmethod
    def _to_numpy(tensor):
        """Detach ``tensor``, move it to CPU and convert it to a numpy array."""
        return tensor.detach().cpu().numpy()

    def log_images(
        self,
        model,
        batch: TransportProcessBatch,
        *,
        epoch: int,
        output_root: Path | None = None,
    ) -> list[tuple[Path, object]]:
        """Plot the model's predictive mean/uncertainty for the first task.

        Args:
            model: Object exposing ``x_dim``/``y_dim`` and
                ``predict_from_context(x_c, y_c, x_target) -> (mu, sigma)``;
                grid settings are read from ``model.exp_cfg.data`` when present.
            batch: Validation batch with ``target_input``/``target_output`` and
                an optional boolean ``target_mask``.
            epoch: Epoch number used in the title and file name.
            output_root: Directory under which ``<model_label>/`` is created;
                defaults to the relative path ``training_images``.

        Returns:
            ``[(png_path, figure)]``, or ``[]`` when the model is not 1D -> 1D,
            the batch has no targets, or no target point is observed. The
            caller owns the returned figure and must close it.
        """
        if getattr(model, "x_dim", None) != 1 or getattr(model, "y_dim", None) != 1:
            return []
        target_input = getattr(batch, "target_input", None)
        target_output = getattr(batch, "target_output", None)
        if target_input is None or target_output is None:
            return []

        import matplotlib.pyplot as plt

        device = target_input.device
        idx = 0  # only the first task in the batch is plotted
        target_mask = getattr(batch, "target_mask", None)
        if target_mask is None:
            target_mask = torch.ones(target_input.shape[:2], device=device, dtype=torch.bool)
        x_true = target_input[idx][target_mask[idx]]
        y_true = target_output[idx][target_mask[idx]]
        if x_true.numel() == 0:
            return []

        # Random subset of the observed points serves as the context set.
        n_ctx = min(int(self.context_points), x_true.size(0))
        ctx_idx = torch.randperm(x_true.size(0), device=device)[:n_ctx]
        x_c = x_true[ctx_idx].unsqueeze(0)
        y_c = y_true[ctx_idx].unsqueeze(0)

        # Grid settings fall back to defaults when the model has no exp_cfg.data.
        data_cfg = getattr(getattr(model, "exp_cfg", None), "data", None)
        x_min, x_max = getattr(data_cfg, "gp_min_max", (-2.0, 2.0))
        n_grid = int(getattr(data_cfg, "num_of_target_grid_inputs_for_logging", 100))
        x_plot = torch.linspace(x_min, x_max, n_grid, device=device).view(1, n_grid, 1)

        with torch.no_grad():
            mu, sigma = model.predict_from_context(x_c, y_c, x_plot)

        x_grid = self._to_numpy(x_plot[0, :, 0])
        mu_np = self._to_numpy(mu[0, :, 0])
        sigma_np = self._to_numpy(sigma[0, :, 0])

        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            ax.scatter(self._to_numpy(x_c[0]), self._to_numpy(y_c[0]), c="black", s=60)
            ax.plot(x_grid, mu_np, "b-")
            ax.fill_between(
                x_grid,
                mu_np - 2 * sigma_np,
                mu_np + 2 * sigma_np,
                color="b",
                alpha=0.2,
            )
            ax.scatter(
                self._to_numpy(x_true),
                self._to_numpy(y_true),
                c="gray",
                alpha=0.3,
                s=10,
            )
            ax.set_title(f"Epoch {epoch}")
            ax.set_ylim(*self.y_limits)

            output_dir = Path(output_root or Path("training_images")) / self.model_label
            output_dir.mkdir(parents=True, exist_ok=True)
            image_path = output_dir / f"epoch_{epoch:03d}.png"
            fig.savefig(image_path)
        except Exception:
            # Do not leak the figure if plotting or saving fails.
            plt.close(fig)
            raise
        return [(image_path, fig)]
|