EchoLVFM / utils /train.py
EngEmmanuel's picture
Initial commit
0f5513d
from src.model import UNet3D
from src.flows import LinearFlow, RMMFlow
from vae.util import load_vae_and_processor
def load_model(cfg, dummy_data, device):
if isinstance(dummy_data, dict):
C, T, H, W = dummy_data['x'].shape
Cc, _, _, _ = dummy_data['cond_image'].shape
print(f'Input shape: {(C+Cc, T, H, W)}, with {Cc} cond channels')
match cfg.model.type.lower():
case "unet":
return UNet3D(
sample_size=W,
in_channels=C + Cc,
out_channels=C,
num_frames=T,
**cfg.model.kwargs
).to(device)
raise ValueError(f"Unsupported model type: {cfg.model.type}")
def load_flow(cfg, model):
match cfg.flow.type.lower():
case "linear":
return LinearFlow(model=model, **cfg.flow.get('kwargs', {}))
case "mean":
return RMMFlow(model=model, **cfg.flow.get('kwargs', {}))
raise ValueError(f"Unsupported flow type: {cfg.flow.type}")
def load_vae_processor(cfg, device):
return load_vae_and_processor(
vae_locator=cfg.vae.repo_id,
subfolder=cfg.vae.get('subfolder', None),
device=device
)
class SamplerConductor:
"""Controls how often the model is sampled during training."""
def __init__(self, run_cfg):
self.max_epochs = run_cfg.trainer.kwargs.max_epochs
self.sample_every_n_epochs = run_cfg.sample.get('every_n_epochs', self.max_epochs)
self.scheduler = run_cfg.trainer.get('lr_scheduler', None)
def is_sample_step(self, epoch, last_sample_epoch, last_step):
if last_step:
return True
epoch_freq = self.sample_every_n_epochs
# Sample more frequently near the cosine annealing transition region
if self.scheduler == 'cosineannealing':
progress = epoch / self.max_epochs
roi = (0.4, 0.55)
if roi[0] <= progress <= roi[1]:
epoch_freq = max(1, self.sample_every_n_epochs // 2)
return epoch % epoch_freq == 0 and epoch != last_sample_epoch