EchoLVFM / src /flows /linear_flow.py
EngEmmanuel's picture
Initial commit
0f5513d
import torch
import torch.nn as nn
from torch.nn import Module
from torch import Tensor
from torchdiffeq import odeint
from einops import rearrange, repeat
from src.custom_loss import MaskedMSELoss
# Code adapted from https://github.com/lucidrains/rectified-flow-pytorch/blob/main/rectified_flow_pytorch/rectified_flow.py
def identity(t):
return t
def exists(v):
return v is not None
def default(v, d):
return v if exists(v) else d
# tensor helpers
def append_dims(t, ndims):
shape = t.shape
return t.reshape(*shape, *((1,) * ndims))
class LinearFlow(Module):
def __init__(
self,
model,
data_shape: tuple[int, ...] | None = None,
clip_values: tuple[float, float] | None = None,
clip_flow_values: tuple[float, float] | None = None,
**kwargs
):
super().__init__()
self.model = model
self.data_shape = data_shape
self.noise_schedule = lambda x: x
self.clip_values = clip_values
self.clip_flow_values = clip_flow_values
self.loss_fn = MaskedMSELoss()
# objective - either flow or noise. CHOSE TO PREDICT FLOW
# self.predict = predict
@property
def device(self):
return next(self.model.parameters()).device
def sample_times(self, batch):
pass
@torch.no_grad()
def sample(
self,
encoder_hidden_states: torch.Tensor,
batch_size=1,
steps=16,
noise=None,
data_shape: tuple[int, ...] | None = None,
cond_image=None,
mask=None,
guidance_scale: float = 1.0,
odeint_kwargs: dict = dict(
atol = 1e-5,
rtol = 1e-5,
method = 'midpoint'
),
use_ema: bool = False,
**model_kwargs
):
model = self.model
data_shape = default(data_shape, self.data_shape)
print(f'Sampling with steps={steps}, batch_size={batch_size}, guidance_scale={guidance_scale}')
maybe_clip = (lambda t: t.clamp_(*self.clip_values)) if self.clip_values is not None else identity
maybe_clip_flow = (lambda t: t.clamp_(*self.clip_flow_values)) if self.clip_flow_values is not None else identity
# Backward-compatible lookup for learned null embedding: prefer flow.null_ehs, fallback to base model.null_ehs
uncond_ehs = getattr(self, "null_ehs", None)
if uncond_ehs is None:
uncond_ehs = getattr(self.model, "null_ehs", None)
if uncond_ehs is not None:
# Get underlying tensor
uncond = uncond_ehs.data if isinstance(uncond_ehs, torch.nn.Parameter) else uncond_ehs
# Try to match encoder_hidden_states shape (excluding batch)
target_tail = tuple(encoder_hidden_states.shape[1:]) if hasattr(encoder_hidden_states, 'shape') else None
if target_tail and uncond.shape != target_tail:
try:
uncond = uncond.view(*target_tail)
except Exception:
# leave as-is; expand best-effort below
pass
# Expand along batch dimension
if uncond.dim() == 0:
uncond = uncond.view(1, 1)
if uncond.dim() == 1:
uncond_ehs = uncond.unsqueeze(0).expand(batch_size, -1)
elif uncond.dim() == 2:
uncond_ehs = uncond.unsqueeze(0).expand(batch_size, -1, -1)
else:
uncond_ehs = uncond.unsqueeze(0).expand(batch_size, *uncond.shape)
def _predict(x, t, ehs):
return self.predict_flow(
model,
x,
times=t,
encoder_hidden_states=ehs,
cond_image=cond_image,
mask=mask,
**model_kwargs,
)
def ode_fn(t, x):
x = maybe_clip(x)
if guidance_scale <= 1.0: # No CFG
flow = _predict(x, t, encoder_hidden_states)
else:
if uncond_ehs is None:
raise ValueError(
"guidance_scale > 1.0 requires a learned null EF embedding. "
"Either this model was not trained for CFG or you need to" \
"Attach `null_ehs` to the flow (e.g., during checkpoint load)."
)
flow_cond = _predict(x, t, encoder_hidden_states)
flow_uncond = _predict(x, t, uncond_ehs)
flow = flow_uncond + guidance_scale * (flow_cond - flow_uncond)
return maybe_clip_flow(flow)
# Start with random gaussian noise - y0
noise = default(noise, torch.randn(batch_size, *data_shape, device=self.device))
# time steps
time_steps = torch.linspace(0., 1., steps, device=self.device)
# ode
trajectory = odeint(ode_fn, noise, time_steps, **odeint_kwargs)
sampled_data = trajectory[-1] # Get the last state as the sampled data
return sampled_data
# Keep model arg in case of ema
def predict_flow(self,
model:Module,
noised,
*,
times,
encoder_hidden_states=None,
cond_image=None,
mask=None,
eps=1e-10,
**model_kwargs
):
batch = noised.shape[0]
# Prepare time conditioning for model
times = rearrange(times, '... -> (...)') # Flattens times
if times.numel() == 1:
times = repeat(times, '1 -> b', b = batch)
# Unet and STDiT forward(x, timestep, encoder_hidden_states=None, cond_image=None, mask=None, return_dict=True)
output = self.model(x=noised, timestep=times, encoder_hidden_states=encoder_hidden_states, cond_image=cond_image, mask=mask, **model_kwargs) # predicted flow / velocity field
if hasattr(output, 'sample'):
return output.sample
return output
def forward(
self,
x,
encoder_hidden_states: torch.Tensor,
noise: Tensor | None = None,
cond_image=None,
mask=None,
loss_mask=None,
**model_kwargs
):
batch, *data_shape = x.shape
self.data_shape = default(self.data_shape, data_shape)
# x0 - gaussian noise, x1 - data
noise = default(noise, torch.randn_like(x))
times = torch.rand(batch, device = self.device)
padded_times = append_dims(times, x.ndim - 1)
def get_noised_and_flows(model, t):
# maybe noise schedule
t = self.noise_schedule(t)
noised = x * t + noise * (1 - t)
flow = x - noise
pred_flow = self.predict_flow(model, noised, times=t, encoder_hidden_states=encoder_hidden_states, cond_image=cond_image, **model_kwargs)
pred_x = noised + pred_flow * (1 - t)
return flow, pred_flow, pred_x
# getting flow and pred flow for main model
flow, pred_flow, pred_x = get_noised_and_flows(self.model, padded_times)
main_loss = self.loss_fn(pred_flow, flow, loss_mask) #, pred_data = pred_x, times = times, data = x)
return main_loss