Spaces:
Running on Zero
Running on Zero
| """Motor Dynamics Module for MANIFOLD. | |
| Neural ODE with fixed Euler solver for modeling motor dynamics. | |
| Integrates physics constraints to enforce human biomechanical limits. | |
| """ | |
| from __future__ import annotations | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from typing import Optional, Dict, Any, TYPE_CHECKING | |
| from manifold.models.layers.physics import PhysicsConstraintLayer | |
| if TYPE_CHECKING: | |
| from manifold.config import ModelConfig | |
class NeuralODEFunc(nn.Module):
    """Learned vector field ``dz/dt = f(z, t)`` for the latent ODE.

    The field is a three-layer perceptron (Linear -> GELU -> Linear -> GELU
    -> Linear) mapping ``input_dim`` features to ``input_dim`` features
    through two hidden layers of width ``hidden_dim``.
    """

    def __init__(self, hidden_dim: int = 512, input_dim: int = 256):
        super().__init__()
        # Layer widths from input to output; consecutive pairs define each
        # Linear layer, with a GELU inserted between neighbouring Linears.
        widths = (input_dim, hidden_dim, hidden_dim, input_dim)
        stack = []
        for position, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            if position > 0:
                stack.append(nn.GELU())
            stack.append(nn.Linear(fan_in, fan_out))
        self.net = nn.Sequential(*stack)

    def forward(self, t: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """Evaluate the time derivative of the state.

        Args:
            t: Current time. Accepted only to satisfy the ``f(t, z)`` ODE
                calling convention; the value is never read.
            z: Current state whose last dimension is ``input_dim``
                (e.g. [batch, seq, input_dim] or [batch, input_dim]).

        Returns:
            dz/dt, with the same shape as ``z``.
        """
        derivative = self.net(z)
        return derivative
def fixed_euler_solve(
    func: "NeuralODEFunc",
    z0: torch.Tensor,
    t_span: tuple = (0.0, 1.0),
    num_steps: int = 4,
) -> torch.Tensor:
    """Integrate ``dz/dt = func(t, z)`` with fixed-step forward Euler.

    Applies ``z_{n+1} = z_n + dt * func(t_n, z_n)`` for ``num_steps`` equal
    steps across ``t_span``. Intended as a cheap alternative to adaptive
    solvers such as dopri5: the number of function evaluations is fixed.

    Args:
        func: Callable invoked as ``func(t, z)`` returning dz/dt with the
            same shape as ``z``.
        z0: Initial state.
        t_span: Integration interval ``(t0, t1)``.
        num_steps: Number of Euler steps; must be at least 1.

    Returns:
        The state after the final step, i.e. the approximation of z(t1).

    Raises:
        ValueError: If ``num_steps`` is less than 1. Zero would divide by
            zero when computing the step size and a negative count would
            silently skip integration and return ``z0``.
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be a positive integer, got {num_steps}")

    t0, t1 = t_span[0], t_span[1]
    dt = (t1 - t0) / num_steps
    z = z0
    for step in range(num_steps):
        # Derive the time from the step index instead of accumulating
        # ``t += dt`` so rounding error does not build up across steps.
        t_tensor = torch.tensor(t0 + step * dt, device=z.device, dtype=z.dtype)
        z = z + dt * func(t_tensor, z)
    return z
class MotorDynamicsModule(nn.Module):
    """Neural ODE module for motor dynamics with optional physics constraints.

    The input is evolved over the interval (0, 1) by ``fixed_euler_solve``
    using a ``NeuralODEFunc`` vector field and then passed through a linear
    output projection. When enabled, a ``PhysicsConstraintLayer`` is
    evaluated on a trajectory and its result is returned alongside the
    features.

    Args:
        input_dim: Dimension of input features (default 256).
        hidden_dim: Hidden dimension for the ODE function (default 512).
        num_steps: Number of Euler integration steps (default 4).
        use_physics_constraints: Whether to build and apply the physics
            constraint layer (default True).
    """

    def __init__(
        self,
        input_dim: int = 256,
        hidden_dim: int = 512,
        num_steps: int = 4,
        use_physics_constraints: bool = True,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_steps = num_steps
        self.use_physics_constraints = use_physics_constraints
        self.ode_func = NeuralODEFunc(hidden_dim=hidden_dim, input_dim=input_dim)
        # The constraint layer is only instantiated when requested; forward()
        # checks both the flag and the attribute before using it.
        if use_physics_constraints:
            self.physics = PhysicsConstraintLayer(learnable=True)
        else:
            self.physics = None
        self.output_proj = nn.Linear(input_dim, input_dim)

    def forward(
        self,
        x: torch.Tensor,
        trajectory: Optional[torch.Tensor] = None,
    ) -> Dict[str, Any]:
        """Run the ODE solve, project the result, and report violations.

        Args:
            x: Input tensor [batch, seq, input_dim].
            trajectory: Optional trajectory for physics violation
                computation, [batch, seq, 2] mouse deltas (dx, dy). When
                omitted, the first two channels of the projected output
                (detached from the graph) are used instead.

        Returns:
            Dict containing:
                - "output": Transformed features [batch, seq, input_dim].
                - "physics_violations": Whatever the physics layer returns
                  when constraints are enabled; otherwise a dict whose
                  "total_violation" entry is a zero scalar on ``x``'s device.
        """
        z = fixed_euler_solve(
            func=self.ode_func,
            z0=x,
            t_span=(0.0, 1.0),
            num_steps=self.num_steps,
        )
        output = self.output_proj(z)
        result = {"output": output}

        if self.use_physics_constraints and self.physics is not None:
            if trajectory is not None:
                physics_violations = self.physics(trajectory)
            else:
                # No trajectory supplied: fall back to the first two output
                # channels, detached so this path adds no gradient to them.
                dummy_trajectory = output[..., :2].detach()
                physics_violations = self.physics(dummy_trajectory)
            result["physics_violations"] = physics_violations
        else:
            result["physics_violations"] = {"total_violation": torch.tensor(0.0, device=x.device)}
        return result

    # Alternate constructor: must be a classmethod so that
    # ``MotorDynamicsModule.from_config(config)`` binds ``cls`` to the class
    # rather than treating ``config`` as the first positional argument.
    @classmethod
    def from_config(cls, config: "ModelConfig") -> "MotorDynamicsModule":
        """Create a MotorDynamicsModule from a ModelConfig.

        Reads ``config.embed_dim``, ``config.mdm_hidden`` and
        ``config.mdm_steps``; physics constraints are always enabled.

        Args:
            config: Model configuration object.

        Returns:
            Configured instance of ``cls``.
        """
        return cls(
            input_dim=config.embed_dim,
            hidden_dim=config.mdm_hidden,
            num_steps=config.mdm_steps,
            use_physics_constraints=True,
        )