"""Motor Dynamics Module for MANIFOLD. Neural ODE with fixed Euler solver for modeling motor dynamics. Integrates physics constraints to enforce human biomechanical limits. """ from __future__ import annotations import torch import torch.nn as nn import torch.nn.functional as F from typing import Optional, Dict, Any, TYPE_CHECKING from manifold.models.layers.physics import PhysicsConstraintLayer if TYPE_CHECKING: from manifold.config import ModelConfig class NeuralODEFunc(nn.Module): """Neural network defining the ODE dynamics dz/dt = f(z, t). Implements a 3-layer MLP with GELU activations to learn continuous-time dynamics in the latent space. """ def __init__(self, hidden_dim: int = 512, input_dim: int = 256): super().__init__() self.net = nn.Sequential( nn.Linear(input_dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, input_dim), ) def forward(self, t: torch.Tensor, z: torch.Tensor) -> torch.Tensor: """Compute dz/dt. Args: t: Current time (scalar tensor, unused but required for ODE interface) z: Current state [batch, seq, input_dim] or [batch, input_dim] Returns: dz/dt with same shape as z """ return self.net(z) def fixed_euler_solve( func: NeuralODEFunc, z0: torch.Tensor, t_span: tuple = (0.0, 1.0), num_steps: int = 4, ) -> torch.Tensor: """Fixed-step Euler ODE solver. Much more memory efficient than adaptive solvers like dopri5. Uses simple forward Euler: z_{n+1} = z_n + dt * f(t_n, z_n) Args: func: Neural ODE function computing dz/dt z0: Initial state [batch, ...] t_span: Integration interval (t0, t1) num_steps: Number of Euler steps (default 4 for memory efficiency) Returns: Final state z(t1) with same shape as z0 """ dt = (t_span[1] - t_span[0]) / num_steps z = z0 t = t_span[0] for _ in range(num_steps): t_tensor = torch.tensor(t, device=z.device, dtype=z.dtype) dz = func(t_tensor, z) z = z + dt * dz t += dt return z class MotorDynamicsModule(nn.Module): """Neural ODE module for motor dynamics with physics constraints. Uses fixed 4-step Euler solver for memory efficiency during training. Integrates physics constraints (jerk, turn rate, acceleration limits) to ensure learned dynamics respect human biomechanical limits. Args: input_dim: Dimension of input features (default 256) hidden_dim: Hidden dimension for ODE function (default 512) num_steps: Number of Euler integration steps (default 4) use_physics_constraints: Whether to apply physics constraints (default True) """ def __init__( self, input_dim: int = 256, hidden_dim: int = 512, num_steps: int = 4, use_physics_constraints: bool = True, ): super().__init__() self.input_dim = input_dim self.hidden_dim = hidden_dim self.num_steps = num_steps self.use_physics_constraints = use_physics_constraints self.ode_func = NeuralODEFunc(hidden_dim=hidden_dim, input_dim=input_dim) if use_physics_constraints: self.physics = PhysicsConstraintLayer(learnable=True) else: self.physics = None self.output_proj = nn.Linear(input_dim, input_dim) def forward( self, x: torch.Tensor, trajectory: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: """Forward pass through motor dynamics module. Args: x: Input tensor [batch, seq, input_dim] trajectory: Optional trajectory for physics violation computation [batch, seq, 2] mouse deltas (dx, dy) Returns: Dict containing: - "output": Transformed features [batch, seq, input_dim] - "physics_violations": Physics constraint violations (if applicable) """ z = fixed_euler_solve( func=self.ode_func, z0=x, t_span=(0.0, 1.0), num_steps=self.num_steps, ) output = self.output_proj(z) result = {"output": output} if self.use_physics_constraints and self.physics is not None: if trajectory is not None: physics_violations = self.physics(trajectory) else: dummy_trajectory = output[..., :2].detach() physics_violations = self.physics(dummy_trajectory) result["physics_violations"] = physics_violations else: result["physics_violations"] = {"total_violation": torch.tensor(0.0, device=x.device)} return result @classmethod def from_config(cls, config: "ModelConfig") -> "MotorDynamicsModule": """Create MotorDynamicsModule from ModelConfig. Args: config: Model configuration object Returns: Configured MotorDynamicsModule instance """ return cls( input_dim=config.embed_dim, hidden_dim=config.mdm_hidden, num_steps=config.mdm_steps, use_physics_constraints=True, )