Spaces:
Running on Zero
Running on Zero
| import torch | |
| import torch.nn as nn | |
| from torch.nn import Module | |
| from torch import Tensor | |
| from torchdiffeq import odeint | |
| from einops import rearrange, repeat | |
| from src.custom_loss import MaskedMSELoss | |
| # Code adapted from https://github.com/lucidrains/rectified-flow-pytorch/blob/main/rectified_flow_pytorch/rectified_flow.py | |
| def identity(t): | |
| return t | |
| def exists(v): | |
| return v is not None | |
| def default(v, d): | |
| return v if exists(v) else d | |
| # tensor helpers | |
| def append_dims(t, ndims): | |
| shape = t.shape | |
| return t.reshape(*shape, *((1,) * ndims)) | |
| class LinearFlow(Module): | |
| def __init__( | |
| self, | |
| model, | |
| data_shape: tuple[int, ...] | None = None, | |
| clip_values: tuple[float, float] | None = None, | |
| clip_flow_values: tuple[float, float] | None = None, | |
| **kwargs | |
| ): | |
| super().__init__() | |
| self.model = model | |
| self.data_shape = data_shape | |
| self.noise_schedule = lambda x: x | |
| self.clip_values = clip_values | |
| self.clip_flow_values = clip_flow_values | |
| self.loss_fn = MaskedMSELoss() | |
| # objective - either flow or noise. CHOSE TO PREDICT FLOW | |
| # self.predict = predict | |
| def device(self): | |
| return next(self.model.parameters()).device | |
| def sample_times(self, batch): | |
| pass | |
| def sample( | |
| self, | |
| encoder_hidden_states: torch.Tensor, | |
| batch_size=1, | |
| steps=16, | |
| noise=None, | |
| data_shape: tuple[int, ...] | None = None, | |
| cond_image=None, | |
| mask=None, | |
| guidance_scale: float = 1.0, | |
| odeint_kwargs: dict = dict( | |
| atol = 1e-5, | |
| rtol = 1e-5, | |
| method = 'midpoint' | |
| ), | |
| use_ema: bool = False, | |
| **model_kwargs | |
| ): | |
| model = self.model | |
| data_shape = default(data_shape, self.data_shape) | |
| print(f'Sampling with steps={steps}, batch_size={batch_size}, guidance_scale={guidance_scale}') | |
| maybe_clip = (lambda t: t.clamp_(*self.clip_values)) if self.clip_values is not None else identity | |
| maybe_clip_flow = (lambda t: t.clamp_(*self.clip_flow_values)) if self.clip_flow_values is not None else identity | |
| # Backward-compatible lookup for learned null embedding: prefer flow.null_ehs, fallback to base model.null_ehs | |
| uncond_ehs = getattr(self, "null_ehs", None) | |
| if uncond_ehs is None: | |
| uncond_ehs = getattr(self.model, "null_ehs", None) | |
| if uncond_ehs is not None: | |
| # Get underlying tensor | |
| uncond = uncond_ehs.data if isinstance(uncond_ehs, torch.nn.Parameter) else uncond_ehs | |
| # Try to match encoder_hidden_states shape (excluding batch) | |
| target_tail = tuple(encoder_hidden_states.shape[1:]) if hasattr(encoder_hidden_states, 'shape') else None | |
| if target_tail and uncond.shape != target_tail: | |
| try: | |
| uncond = uncond.view(*target_tail) | |
| except Exception: | |
| # leave as-is; expand best-effort below | |
| pass | |
| # Expand along batch dimension | |
| if uncond.dim() == 0: | |
| uncond = uncond.view(1, 1) | |
| if uncond.dim() == 1: | |
| uncond_ehs = uncond.unsqueeze(0).expand(batch_size, -1) | |
| elif uncond.dim() == 2: | |
| uncond_ehs = uncond.unsqueeze(0).expand(batch_size, -1, -1) | |
| else: | |
| uncond_ehs = uncond.unsqueeze(0).expand(batch_size, *uncond.shape) | |
| def _predict(x, t, ehs): | |
| return self.predict_flow( | |
| model, | |
| x, | |
| times=t, | |
| encoder_hidden_states=ehs, | |
| cond_image=cond_image, | |
| mask=mask, | |
| **model_kwargs, | |
| ) | |
| def ode_fn(t, x): | |
| x = maybe_clip(x) | |
| if guidance_scale <= 1.0: # No CFG | |
| flow = _predict(x, t, encoder_hidden_states) | |
| else: | |
| if uncond_ehs is None: | |
| raise ValueError( | |
| "guidance_scale > 1.0 requires a learned null EF embedding. " | |
| "Either this model was not trained for CFG or you need to" \ | |
| "Attach `null_ehs` to the flow (e.g., during checkpoint load)." | |
| ) | |
| flow_cond = _predict(x, t, encoder_hidden_states) | |
| flow_uncond = _predict(x, t, uncond_ehs) | |
| flow = flow_uncond + guidance_scale * (flow_cond - flow_uncond) | |
| return maybe_clip_flow(flow) | |
| # Start with random gaussian noise - y0 | |
| noise = default(noise, torch.randn(batch_size, *data_shape, device=self.device)) | |
| # time steps | |
| time_steps = torch.linspace(0., 1., steps, device=self.device) | |
| # ode | |
| trajectory = odeint(ode_fn, noise, time_steps, **odeint_kwargs) | |
| sampled_data = trajectory[-1] # Get the last state as the sampled data | |
| return sampled_data | |
| # Keep model arg in case of ema | |
| def predict_flow(self, | |
| model:Module, | |
| noised, | |
| *, | |
| times, | |
| encoder_hidden_states=None, | |
| cond_image=None, | |
| mask=None, | |
| eps=1e-10, | |
| **model_kwargs | |
| ): | |
| batch = noised.shape[0] | |
| # Prepare time conditioning for model | |
| times = rearrange(times, '... -> (...)') # Flattens times | |
| if times.numel() == 1: | |
| times = repeat(times, '1 -> b', b = batch) | |
| # Unet and STDiT forward(x, timestep, encoder_hidden_states=None, cond_image=None, mask=None, return_dict=True) | |
| output = self.model(x=noised, timestep=times, encoder_hidden_states=encoder_hidden_states, cond_image=cond_image, mask=mask, **model_kwargs) # predicted flow / velocity field | |
| if hasattr(output, 'sample'): | |
| return output.sample | |
| return output | |
| def forward( | |
| self, | |
| x, | |
| encoder_hidden_states: torch.Tensor, | |
| noise: Tensor | None = None, | |
| cond_image=None, | |
| mask=None, | |
| loss_mask=None, | |
| **model_kwargs | |
| ): | |
| batch, *data_shape = x.shape | |
| self.data_shape = default(self.data_shape, data_shape) | |
| # x0 - gaussian noise, x1 - data | |
| noise = default(noise, torch.randn_like(x)) | |
| times = torch.rand(batch, device = self.device) | |
| padded_times = append_dims(times, x.ndim - 1) | |
| def get_noised_and_flows(model, t): | |
| # maybe noise schedule | |
| t = self.noise_schedule(t) | |
| noised = x * t + noise * (1 - t) | |
| flow = x - noise | |
| pred_flow = self.predict_flow(model, noised, times=t, encoder_hidden_states=encoder_hidden_states, cond_image=cond_image, **model_kwargs) | |
| pred_x = noised + pred_flow * (1 - t) | |
| return flow, pred_flow, pred_x | |
| # getting flow and pred flow for main model | |
| flow, pred_flow, pred_x = get_noised_and_flows(self.model, padded_times) | |
| main_loss = self.loss_fn(pred_flow, flow, loss_mask) #, pred_data = pred_x, times = times, data = x) | |
| return main_loss |