# minigenie — src/training/train_dynamics.py
# Deploy MiniGenie — interactive flow matching world model demo (commit f805fb3)
"""
Flow matching dynamics model — inference-only subset for HuggingFace Spaces.
Contains only the generate_next_frame function needed for the demo.
"""
import torch
from src.models.unet import UNet
@torch.no_grad()
def generate_next_frame(
    model: "UNet",
    context_frames: torch.Tensor,
    action: torch.Tensor,
    num_steps: int = 15,
    cfg_scale: float = 2.0,
    generator: "torch.Generator | None" = None,
) -> torch.Tensor:
    """Generate the next frame via Euler integration of the flow-matching ODE.

    Starts from pure Gaussian noise at t=0 and integrates the learned
    velocity field to t=1, optionally applying classifier-free guidance
    (CFG) by blending conditional and null-action predictions.

    Args:
        model: Trained U-Net dynamics model (in eval mode). Called as
            ``model(x_and_context, t, action)`` and expected to return a
            velocity of shape [B, 3, h, w].
        context_frames: [B, H*3, h, w] stacked RGB context frames.
        action: [B] integer action indices. A zero action is used as the
            null condition for the unconditional CFG branch — assumed to
            match the training convention.
        num_steps: Number of Euler integration steps (must be >= 1).
        cfg_scale: Classifier-free guidance scale. 1.0 disables guidance
            (one forward pass per step instead of two).
        generator: Optional ``torch.Generator`` seeding the initial noise,
            for reproducible generation. Must live on a device compatible
            with ``context_frames``.

    Returns:
        Predicted next frame [B, 3, h, w], clamped to [0, 1].

    Raises:
        ValueError: If ``num_steps`` is less than 1.
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")

    B, _, h, w = context_frames.shape
    device = context_frames.device

    # Flow matching integrates from noise (t=0) toward data (t=1).
    x = torch.randn(B, 3, h, w, device=device, generator=generator)
    dt = 1.0 / num_steps

    for i in range(num_steps):
        t = torch.full((B,), i * dt, device=device)
        # Current state concatenated with context along channels:
        # [B, 3 + H*3, h, w].
        model_input = torch.cat([x, context_frames], dim=1)

        # Conditional velocity (with the real action).
        v_cond = model(model_input, t, action)
        if cfg_scale != 1.0:
            # Unconditional velocity (null action — zeros).
            v_uncond = model(model_input, t, torch.zeros_like(action))
            # CFG: extrapolate from unconditional toward conditional.
            v = v_uncond + cfg_scale * (v_cond - v_uncond)
        else:
            v = v_cond

        # Euler step along the velocity field.
        x = x + dt * v

    return x.clamp(0, 1)