# EchoLVFM / src / custom_callbacks.py
# Author: EngEmmanuel — "Initial commit" (0f5513d)
import time
import torch
import pandas as pd
from tqdm import tqdm
from pathlib import Path
from lightning.pytorch.callbacks import Callback
from utils.latent_utils import UnscaleLatents
from utils.train import SamplerConductor
class SampleAndCheckpointCallback(Callback):
    '''
    Callback to sample latents from the model and save checkpoints at specified
    intervals during training. Checkpoints saved here contain only model weights.
    '''

    def __init__(self, cfg, sample_dir: Path, sample_dl, checkpoint_dir: Path, device='cuda'):
        """
        Args:
            cfg: Run config; passed to SamplerConductor and to the sampler.
            sample_dir: Where sampled latents are written; may be None to
                disable sampling (see the guard in ``_sample_step``).
            sample_dl: DataLoader used as the sampling source.
            checkpoint_dir: Where weights-only checkpoints are written.
            device: Device the model is moved to before sampling.
        """
        super().__init__()
        self.cfg = cfg
        self.sample_dir = sample_dir
        self.sample_dl = sample_dl
        self.checkpoint_dir = checkpoint_dir
        self._last_sample_epoch = 0
        self.device = device
        # Only create the sample dir when sampling is enabled; _sample_step
        # treats sample_dir=None as "sampling disabled", so an unconditional
        # mkdir here would crash before that guard could ever apply.
        if self.sample_dir is not None:
            self.sample_dir.mkdir(parents=True, exist_ok=True)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.conductor = SamplerConductor(cfg)

    def on_validation_epoch_end(self, trainer, pl_module):
        self._sample_step(trainer, pl_module)

    def on_fit_end(self, trainer, pl_module):
        # Force one final sampling pass when training ends.
        self._sample_step(trainer, pl_module, last=True)

    def _sample_step(self, trainer, pl_module, last=False):
        """Sample latents (rank 0 only) and save a weights-only checkpoint."""
        # Bail out before touching the device: moving the model is pointless
        # during sanity checks or when sampling is disabled.
        if self.sample_dir is None or trainer.sanity_checking:
            return
        pl_module.model.to(self.device)
        epoch = trainer.current_epoch
        is_sample_step = self.conductor.is_sample_step(
            epoch=epoch,
            last_sample_epoch=self._last_sample_epoch,
            last_step=last
        )
        if not is_sample_step:
            return
        if trainer.is_global_zero:
            out_name = 'last' if last else None
            sample_latents_from_model(
                model=pl_module.model,
                dl_list=[self.sample_dl],
                run_cfg=self.cfg,
                epoch=epoch,
                step=trainer.global_step,
                device=self.device,
                samples_dir=self.sample_dir,
                out_name=out_name
            )
        # Updated on every rank so is_sample_step stays consistent across
        # ranks; the barrier keeps non-zero ranks in step with the sampler.
        self._last_sample_epoch = epoch
        trainer.strategy.barrier()
        if not last:
            ckpt_name = f"sample-epoch={epoch}-step={trainer.global_step}.ckpt"
            trainer.save_checkpoint(
                str(self.checkpoint_dir / ckpt_name),
                weights_only=True
            )
def _make_gen_batch(batch) -> dict:
"""Create a generation batch by wrapping EF values."""
wrapped_ef = ((batch['encoder_hidden_states'] + 0.5) % 1.0).clamp(0.15, 0.85)
return {
'cond_image': batch['cond_image'],
'encoder_hidden_states': wrapped_ef
}
def sample_latents_from_model(model, dl_list, run_cfg, epoch, step, device, samples_dir, out_name=None):
    """Sample from the model for each DataLoader and save latent videos with metadata.

    For every batch two passes are run: 'rec' (original EF conditioning) and
    'gen' (wrapped EF targets, see _make_gen_batch). Each sampled video is
    saved as ``<name>.pt`` and described by one row in ``metadata.csv``.

    Args:
        model: Model exposing ``.sample(**batch, batch_size=..., data_shape=..., ...)``.
        dl_list: DataLoaders whose datasets yield (reference_batch, input_batch) pairs.
        run_cfg: Run config; reads ``sample.model_sample_kwargs``.
        epoch, step: Used to label the output subdirectory.
        device: Device conditioning tensors are moved to before sampling.
        samples_dir: Root output directory; a per-call subdirectory is created.
        out_name: Optional subdirectory name; defaults to
            ``sample-epoch={epoch}-step={step}``.

    Returns:
        Path to the directory containing the saved latents and metadata.
    """
    t_start = time.perf_counter()
    model.eval()
    model_sample_kwargs = run_cfg.sample.get('model_sample_kwargs', {})
    out_name = out_name or f"sample-epoch={epoch}-step={step}"
    samples_dir = Path(samples_dir) / out_name
    samples_dir.mkdir(parents=True, exist_ok=True)
    metadata_rows = []
    for dl in tqdm(dl_list, desc=f"Sampling: Epoch {epoch}"):
        unscale = UnscaleLatents(run_cfg, dl.dataset)
        data_shape = tuple(dl.dataset[0]['x'].shape)
        nmf = dl.dataset.kwargs.get('n_missing_frames', 'max')
        # Float values encode a fraction of missing frames; render as e.g. "25p".
        nmf = f"{int(100*nmf)}p" if isinstance(nmf, float) else str(nmf)
        for batch in tqdm(dl, desc="Batches"):
            reference_batch, input_batch = batch
            batch_size = input_batch['cond_image'].shape[0]
            # 'rec' reconstructs with the original EF conditioning;
            # 'gen' samples with the wrapped EF targets.
            sub_batches = {
                'rec': input_batch,
                'gen': _make_gen_batch(input_batch)
            }
            for tag, sub_batch in sub_batches.items():
                sub_batch = {k: v.to(device) for k, v in sub_batch.items()}
                videos = model.sample(
                    **sub_batch,
                    batch_size=batch_size,
                    data_shape=data_shape,
                    **model_sample_kwargs
                ).detach().cpu()
                videos = unscale(videos)
                # assumes encoder_hidden_states is (batch, seq, dim) with the
                # scalar EF broadcast along seq/dim — TODO confirm with caller.
                efs = sub_batch['encoder_hidden_states'][:, 0, 0].tolist()
                for j, (ef, video) in enumerate(zip(efs, videos)):
                    # Integer EF percentage for the filename. The original
                    # `round(int(100 * ef), 2)` was a no-op round on an int.
                    ef = int(100 * ef)
                    real_name = reference_batch['video_name'][j]
                    video_name = f"{real_name}_ef{ef}_nmf{nmf}"
                    metadata_rows.append({
                        'video_name': video_name,
                        'n_missing_frames': nmf,
                        'EF': ef,
                        'rec_or_gen': tag,
                        'original_real_video_name': real_name,
                        'observed_mask': reference_batch['observed_mask'][j].tolist(),
                        'not_pad_mask': reference_batch['not_pad_mask'][j].tolist()
                    })
                    torch.save({'video': video}, samples_dir / f"{video_name}.pt")
    pd.DataFrame(metadata_rows).to_csv(samples_dir / 'metadata.csv', index=False)
    elapsed = time.perf_counter() - t_start
    print(f"Sampling done: epoch={epoch}, videos={len(metadata_rows)}, "
          f"time={elapsed/60:.1f}m, out='{samples_dir}'")
    return samples_dir