Spaces:
Running on Zero
Running on Zero
File size: 5,539 Bytes
0f5513d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 | import time
import torch
import pandas as pd
from tqdm import tqdm
from pathlib import Path
from lightning.pytorch.callbacks import Callback
from utils.latent_utils import UnscaleLatents
from utils.train import SamplerConductor
class SampleAndCheckpointCallback(Callback):
    '''
    Callback to sample latents from the model and save checkpoints at specified
    intervals during training. Checkpoints saved here contain only model weights.

    At the end of every validation epoch (skipping the sanity check) and once at
    the end of ``fit``, a ``SamplerConductor`` built from ``cfg`` decides whether
    this is a sampling step. If so, global rank 0 samples latent videos with
    ``sample_latents_from_model``, all ranks synchronise on a barrier and,
    except for the final end-of-fit call, a weights-only checkpoint named
    ``sample-epoch=<epoch>-step=<step>.ckpt`` is written to ``checkpoint_dir``.

    Args:
        cfg: Run configuration, passed to ``SamplerConductor`` and to
            ``sample_latents_from_model``.
        sample_dir: Root directory for sample outputs, or ``None`` to disable
            sampling (and the checkpoints tied to it) entirely.
        sample_dl: DataLoader whose batches are sampled.
        checkpoint_dir: Directory the weights-only checkpoints are written to.
        device: Device the model is moved to before sampling.
    '''

    def __init__(self, cfg, sample_dir: Path, sample_dl, checkpoint_dir: Path, device='cuda'):
        super().__init__()
        self.cfg = cfg
        self.sample_dir = sample_dir
        self.sample_dl = sample_dl
        self.checkpoint_dir = checkpoint_dir
        # Epoch of the most recent sampling step; fed back to the conductor.
        self._last_sample_epoch = 0
        self.device = device
        # sample_dir=None disables sampling (see _sample_step), so only create
        # the directory when one was given; previously None crashed here.
        if self.sample_dir is not None:
            self.sample_dir.mkdir(parents=True, exist_ok=True)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.conductor = SamplerConductor(cfg)

    def on_validation_epoch_end(self, trainer, pl_module):
        """Possibly sample/checkpoint after each validation epoch."""
        self._sample_step(trainer, pl_module)

    def on_fit_end(self, trainer, pl_module):
        """Run a final sampling step (saved as 'last', no checkpoint)."""
        self._sample_step(trainer, pl_module, last=True)

    def _sample_step(self, trainer, pl_module, last=False):
        """Sample and checkpoint if the conductor marks this as a sampling step.

        This must run on every rank: the barrier and ``save_checkpoint`` are
        collective operations, and only the sampling itself is restricted to
        global rank 0.

        Args:
            trainer: The Lightning trainer.
            pl_module: The LightningModule; its ``.model`` is sampled.
            last: True for the end-of-fit call. Samples go to a directory named
                'last' and no checkpoint is saved.
        """
        pl_module.model.to(self.device)
        if self.sample_dir is None or trainer.sanity_checking:
            return
        epoch = trainer.current_epoch
        is_sample_step = self.conductor.is_sample_step(
            epoch=epoch,
            last_sample_epoch=self._last_sample_epoch,
            last_step=last,
        )
        if not is_sample_step:
            return
        if trainer.is_global_zero:
            sample_latents_from_model(
                model=pl_module.model,
                dl_list=[self.sample_dl],
                run_cfg=self.cfg,
                epoch=epoch,
                step=trainer.global_step,
                device=self.device,
                samples_dir=self.sample_dir,
                out_name='last' if last else None,
            )
        # Update on every rank (not just rank 0) so all ranks agree on the next
        # sampling step and reach the barrier/checkpoint below together.
        self._last_sample_epoch = epoch
        # Non-zero ranks wait here while rank 0 samples.
        trainer.strategy.barrier()
        if not last:
            ckpt_name = f"sample-epoch={epoch}-step={trainer.global_step}.ckpt"
            trainer.save_checkpoint(
                str(self.checkpoint_dir / ckpt_name),
                weights_only=True
            )
def _make_gen_batch(batch) -> dict:
"""Create a generation batch by wrapping EF values."""
wrapped_ef = ((batch['encoder_hidden_states'] + 0.5) % 1.0).clamp(0.15, 0.85)
return {
'cond_image': batch['cond_image'],
'encoder_hidden_states': wrapped_ef
}
def sample_latents_from_model(model, dl_list, run_cfg, epoch, step, device, samples_dir, out_name=None):
    """Sample from the model for each DataLoader and save latent videos with metadata.

    Every batch is sampled twice:
    * ``rec`` uses the batch's own EF conditioning;
    * ``gen`` uses wrapped EF values from ``_make_gen_batch``.

    Each sampled video is passed through ``UnscaleLatents`` and saved as
    ``<samples_dir>/<out_name>/<video_name>.pt`` containing ``{'video': tensor}``.
    One metadata row per video is written to ``metadata.csv`` in the same
    directory.

    Args:
        model: Module exposing ``sample(**batch, batch_size=..., data_shape=..., **kwargs)``.
            Its train/eval mode is restored on return.
        dl_list: DataLoaders yielding ``(reference_batch, input_batch)`` pairs.
        run_cfg: Run config. ``run_cfg.sample['model_sample_kwargs']`` (optional)
            is forwarded to ``model.sample``.
        epoch: Used in the default output name and in progress/log messages.
        step: Used in the default output name.
        device: Device the input tensors are moved to before sampling.
        samples_dir: Root output directory.
        out_name: Output sub-directory name. Defaults to
            ``"sample-epoch={epoch}-step={step}"``.

    Returns:
        Path of the directory the samples and ``metadata.csv`` were written to.
    """
    _t0 = time.perf_counter()
    model_sample_kwargs = run_cfg.sample.get('model_sample_kwargs', {})
    out_name = out_name or f"sample-epoch={epoch}-step={step}"
    samples_dir = Path(samples_dir) / out_name
    samples_dir.mkdir(parents=True, exist_ok=True)
    metadata_rows = []

    was_training = model.training
    model.eval()
    try:
        # Sampling needs no autograd graph. Callers (e.g. an end-of-fit hook)
        # may not already be inside a no-grad context.
        with torch.no_grad():
            for dl in tqdm(dl_list, desc=f"Sampling: Epoch {epoch}"):
                unscale = UnscaleLatents(run_cfg, dl.dataset)
                data_shape = tuple(dl.dataset[0]['x'].shape)
                nmf = dl.dataset.kwargs.get('n_missing_frames', 'max')
                # round(), not int(): e.g. 0.29 * 100 == 28.999999999999996.
                nmf = f"{round(100 * nmf)}p" if isinstance(nmf, float) else str(nmf)
                for reference_batch, input_batch in tqdm(dl, desc="Batches"):
                    batch_size = input_batch['cond_image'].shape[0]
                    sub_batches = {
                        'rec': input_batch,
                        'gen': _make_gen_batch(input_batch),
                    }
                    for tag, sub_batch in sub_batches.items():
                        sub_batch = {k: v.to(device) for k, v in sub_batch.items()}
                        videos = model.sample(
                            **sub_batch,
                            batch_size=batch_size,
                            data_shape=data_shape,
                            **model_sample_kwargs,
                        ).detach().cpu()
                        videos = unscale(videos)
                        efs = sub_batch['encoder_hidden_states'][:, 0, 0].tolist()
                        for j, (ef, video) in enumerate(zip(efs, videos)):
                            # Integer percent label. round() avoids float32
                            # truncation (0.57 -> 56.99999... -> 56 with int()).
                            ef = round(100 * ef)
                            real_name = reference_batch['video_name'][j]
                            video_name = f"{real_name}_ef{ef}_nmf{nmf}"
                            metadata_rows.append({
                                'video_name': video_name,
                                'n_missing_frames': nmf,
                                'EF': ef,
                                'rec_or_gen': tag,
                                'original_real_video_name': real_name,
                                'observed_mask': reference_batch['observed_mask'][j].tolist(),
                                'not_pad_mask': reference_batch['not_pad_mask'][j].tolist()
                            })
                            torch.save({'video': video}, samples_dir / f"{video_name}.pt")
    finally:
        # Leave the model in the same train/eval mode the caller had.
        model.train(was_training)

    pd.DataFrame(metadata_rows).to_csv(samples_dir / 'metadata.csv', index=False)
    elapsed = time.perf_counter() - _t0
    print(f"Sampling done: epoch={epoch}, videos={len(metadata_rows)}, "
          f"time={elapsed/60:.1f}m, out='{samples_dir}'")
    return samples_dir
|