Spaces:
Runtime error
Runtime error
| """ | |
| Flow matching dynamics model — inference-only subset for HuggingFace Spaces. | |
| Contains only the generate_next_frame function needed for the demo. | |
| """ | |
| import torch | |
| from src.models.unet import UNet | |
@torch.no_grad()
def generate_next_frame(
    model: "UNet",
    context_frames: torch.Tensor,
    action: torch.Tensor,
    num_steps: int = 15,
    cfg_scale: float = 2.0,
    generator: "torch.Generator | None" = None,
) -> torch.Tensor:
    """Generate the next frame using flow matching ODE integration.

    Starts from Gaussian noise at t=0 and takes ``num_steps`` explicit Euler
    steps of size ``1 / num_steps`` along the velocity predicted by ``model``,
    optionally combined with classifier-free guidance (CFG). Runs under
    ``torch.no_grad()`` so no autograd graph is built across the steps.

    Args:
        model: Trained U-Net dynamics model (in eval mode). Called as
            ``model(model_input, t, action)`` and expected to return a
            velocity shaped like the noisy frame, ``[B, 3, h, w]``.
        context_frames: [B, H*3, h, w] context frames (H stacked 3-channel
            frames).
        action: [B] action indices.
        num_steps: Number of Euler integration steps; must be >= 1.
        cfg_scale: Classifier-free guidance scale. 1.0 = no guidance (one
            model call per step); any other value costs two calls per step.
        generator: Optional ``torch.Generator`` for the initial noise, for
            reproducible sampling. Must live on the same device as
            ``context_frames``.

    Returns:
        Predicted next frame [B, 3, h, w], clamped to [0, 1].

    Raises:
        ValueError: If ``num_steps`` is less than 1.
    """
    if num_steps < 1:
        # Guard: 0 would divide by zero below, negatives would skip the loop
        # and silently return clamped noise.
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")

    batch = context_frames.shape[0]
    h, w = context_frames.shape[2], context_frames.shape[3]
    device = context_frames.device
    use_cfg = cfg_scale != 1.0

    # Unconditional-branch action, built once instead of on every step.
    # NOTE(review): assumes index 0 is the null/dropped action used during
    # training; if 0 is also a real action, CFG conflates the two — confirm.
    null_action = torch.zeros_like(action) if use_cfg else None

    # Start from pure noise (t = 0) and integrate towards the data (t = 1).
    x = torch.randn(batch, 3, h, w, device=device, generator=generator)
    dt = 1.0 / num_steps
    for i in range(num_steps):
        t = torch.full((batch,), i * dt, device=device)
        # Noisy frame first, then context: [B, 3 + H*3, h, w].
        model_input = torch.cat([x, context_frames], dim=1)

        # Conditional velocity (with the real action).
        v = model(model_input, t, action)
        if use_cfg:
            # CFG: extrapolate from the unconditional towards the
            # conditional prediction by cfg_scale.
            v_uncond = model(model_input, t, null_action)
            v = v_uncond + cfg_scale * (v - v_uncond)

        # Explicit Euler step.
        x = x + dt * v
    return x.clamp(0, 1)