import time

import torch
import pandas as pd
from tqdm import tqdm
from pathlib import Path
from lightning.pytorch.callbacks import Callback

from utils.latent_utils import UnscaleLatents
from utils.train import SamplerConductor


class SampleAndCheckpointCallback(Callback):
    """Sample latents from the model and save checkpoints at sampling epochs.

    Sampling is triggered at epochs chosen by a ``SamplerConductor`` (and
    once more at the very end of fit). Checkpoints saved here contain only
    model weights.
    """

    def __init__(self, cfg, sample_dir: Path, sample_dl, checkpoint_dir: Path,
                 device: str = 'cuda'):
        """
        Args:
            cfg: run configuration (also passed to ``SamplerConductor`` and
                ``sample_latents_from_model``).
            sample_dir: directory where sampled latents are written.
            sample_dl: DataLoader providing conditioning batches for sampling.
            checkpoint_dir: directory where weight-only checkpoints are written.
            device: device the model is moved to before sampling.
        """
        super().__init__()
        self.cfg = cfg
        self.sample_dir = sample_dir
        self.sample_dl = sample_dl
        self.checkpoint_dir = checkpoint_dir
        self._last_sample_epoch = 0
        self.device = device
        # NOTE(review): these mkdirs assume sample_dir/checkpoint_dir are not
        # None, yet _sample_step guards on `self.sample_dir is None` — confirm
        # whether a None sample_dir is actually a supported configuration.
        self.sample_dir.mkdir(parents=True, exist_ok=True)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        # Decides at which epochs sampling should run.
        self.conductor = SamplerConductor(cfg)

    def on_validation_epoch_end(self, trainer, pl_module):
        self._sample_step(trainer, pl_module)

    def on_fit_end(self, trainer, pl_module):
        # Final sampling pass; no checkpoint is written for this one.
        self._sample_step(trainer, pl_module, last=True)

    def _sample_step(self, trainer, pl_module, last: bool = False):
        """Run one sampling pass (rank 0 only) and checkpoint the weights.

        Args:
            trainer: the Lightning trainer.
            pl_module: the LightningModule; its ``.model`` attribute is sampled.
            last: True when called from ``on_fit_end`` — uses the fixed output
                name ``'last'`` and skips the checkpoint.
        """
        # Guard first: do nothing during the sanity-check loop or when
        # sampling is disabled. (Fix: the original moved the model to the
        # device before this check, causing a device transfer even during
        # sanity checking.)
        if self.sample_dir is None or trainer.sanity_checking:
            return
        pl_module.model.to(self.device)

        epoch = trainer.current_epoch
        is_sample_step = self.conductor.is_sample_step(
            epoch=epoch,
            last_sample_epoch=self._last_sample_epoch,
            last_step=last,
        )
        if not is_sample_step:
            return

        # Only the global-zero rank actually samples; the barrier below keeps
        # the other ranks in step with it.
        if trainer.is_global_zero:
            out_name = 'last' if last else None
            sample_latents_from_model(
                model=pl_module.model,
                dl_list=[self.sample_dl],
                run_cfg=self.cfg,
                epoch=epoch,
                step=trainer.global_step,
                device=self.device,
                samples_dir=self.sample_dir,
                out_name=out_name,
            )
        # Fix: update the last-sampled epoch on EVERY rank, not just rank 0.
        # The original updated it inside the is_global_zero branch, so in
        # distributed runs the conductor's decision could diverge across
        # ranks and deadlock at the barrier.
        self._last_sample_epoch = epoch
        trainer.strategy.barrier()

        if not last:
            ckpt_name = f"sample-epoch={epoch}-step={trainer.global_step}.ckpt"
            trainer.save_checkpoint(
                str(self.checkpoint_dir / ckpt_name),
                weights_only=True,
            )


def _make_gen_batch(batch) -> dict:
    """Create a generation batch by wrapping EF values.

    Shifts each EF conditioning value by 0.5 (mod 1.0) and clamps to
    [0.15, 0.85], producing a counterfactual EF for each real sample.
    """
    wrapped_ef = ((batch['encoder_hidden_states'] + 0.5) % 1.0).clamp(0.15, 0.85)
    return {
        'cond_image': batch['cond_image'],
        'encoder_hidden_states': wrapped_ef,
    }


def sample_latents_from_model(model, dl_list, run_cfg, epoch, step, device,
                              samples_dir, out_name=None):
    """Sample from the model for each DataLoader and save latent videos with metadata.

    For every batch, samples both a reconstruction ('rec') and a
    counterfactual-EF generation ('gen'), unscales the latents, and writes
    one ``.pt`` file per video plus a ``metadata.csv`` for the whole run.

    Args:
        model: model exposing ``.sample(**batch, batch_size=..., data_shape=..., ...)``.
        dl_list: list of DataLoaders to sample from.
        run_cfg: run configuration (provides ``sample.model_sample_kwargs``).
        epoch: current epoch, used for the default output directory name.
        step: current global step, used for the default output directory name.
        device: device the batches are moved to.
        samples_dir: root directory for sample outputs.
        out_name: optional fixed subdirectory name (e.g. 'last'); defaults to
            ``sample-epoch={epoch}-step={step}``.

    Returns:
        Path to the directory the samples were written to.
    """
    _t0 = time.perf_counter()
    # NOTE(review): eval mode is set but never restored; Lightning restores it
    # around validation, but confirm for the on_fit_end call path.
    model.eval()
    model_sample_kwargs = run_cfg.sample.get('model_sample_kwargs', {})
    out_name = out_name or f"sample-epoch={epoch}-step={step}"
    samples_dir = Path(samples_dir) / out_name
    samples_dir.mkdir(parents=True, exist_ok=True)

    metadata_rows = []
    for dl in tqdm(dl_list, desc=f"Sampling: Epoch {epoch}"):
        unscale = UnscaleLatents(run_cfg, dl.dataset)
        data_shape = tuple(dl.dataset[0]['x'].shape)
        # n_missing_frames tag for filenames: floats become percentages
        # ("25p"), anything else is stringified ("max", "3", ...).
        nmf = dl.dataset.kwargs.get('n_missing_frames', 'max')
        nmf = f"{int(100 * nmf)}p" if isinstance(nmf, float) else str(nmf)

        for batch in tqdm(dl, desc="Batches"):
            reference_batch, input_batch = batch
            batch_size = input_batch['cond_image'].shape[0]
            sub_batches = {
                'rec': input_batch,                    # reconstruct with true EF
                'gen': _make_gen_batch(input_batch),   # generate with wrapped EF
            }
            for tag, sub_batch in sub_batches.items():
                sub_batch = {k: v.to(device) for k, v in sub_batch.items()}
                videos = model.sample(
                    **sub_batch,
                    batch_size=batch_size,
                    data_shape=data_shape,
                    **model_sample_kwargs,
                ).detach().cpu()
                videos = unscale(videos)
                efs = sub_batch['encoder_hidden_states'][:, 0, 0].tolist()
                for j, (ef, video) in enumerate(zip(efs, videos)):
                    # Fix: original was `round(int(100 * ef), 2)` — rounding
                    # an int to 2 decimals is a no-op, so this is equivalent.
                    # NOTE(review): confirm whether `round(100 * ef, 2)`
                    # (keeping decimals) was the actual intent.
                    ef = int(100 * ef)
                    real_name = reference_batch['video_name'][j]
                    video_name = f"{real_name}_ef{ef}_nmf{nmf}"
                    metadata_rows.append({
                        'video_name': video_name,
                        'n_missing_frames': nmf,
                        'EF': ef,
                        'rec_or_gen': tag,
                        'original_real_video_name': real_name,
                        'observed_mask': reference_batch['observed_mask'][j].tolist(),
                        'not_pad_mask': reference_batch['not_pad_mask'][j].tolist(),
                    })
                    torch.save({'video': video}, samples_dir / f"{video_name}.pt")

    pd.DataFrame(metadata_rows).to_csv(samples_dir / 'metadata.csv', index=False)
    elapsed = time.perf_counter() - _t0
    print(f"Sampling done: epoch={epoch}, videos={len(metadata_rows)}, "
          f"time={elapsed/60:.1f}m, out='{samples_dir}'")
    return samples_dir