# EXOKERN1 — EXOKERN Skill v0.1.1: Diffusion Policy trained on DR dataset v0.1.1
# Revision: 95f0c27 (verified)
#!/usr/bin/env python3
"""
EXOKERN Skill v0 — Diffusion Policy Inference
===============================================
Standalone inference script for the EXOKERN Peg Insertion Diffusion Policy.
Loads a trained checkpoint and provides a clean API for action generation.
Usage:
from inference import DiffusionPolicyInference
policy = DiffusionPolicyInference("full_ft_best_model.pt", device="cuda")
policy.add_observation(obs) # call each timestep
actions = policy.get_actions() # returns action chunk
"""
import math
from collections import deque, OrderedDict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# ═══════════════════════════════════════════════════════════
# MODEL (identical to training — self-contained)
# ═══════════════════════════════════════════════════════════
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of diffusion timesteps.

    For a 1-D tensor of timesteps ``t`` of shape ``(batch,)`` it returns a
    ``(batch, dim)`` float tensor.
    - The first ``dim // 2`` columns are ``sin(t * f_k)`` and the next
      ``dim // 2`` are ``cos(t * f_k)``.
    - The frequencies ``f_k`` are spaced geometrically from 1 down to 1/10000.
    - For odd ``dim`` a single zero column is appended, so the output width
      always equals ``dim``.

    Args:
        dim: Output embedding width. Must be >= 4 so that at least two
            frequencies exist; the spacing divides by ``dim // 2 - 1``.

    Raises:
        ValueError: If ``dim < 4``.
    """

    def __init__(self, dim):
        super().__init__()
        if dim < 4:
            raise ValueError(f"SinusoidalPosEmb requires dim >= 4, got {dim}")
        self.dim = dim

    def forward(self, t):
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
        emb = t.float().unsqueeze(-1) * emb.unsqueeze(0)
        emb = torch.cat([emb.sin(), emb.cos()], dim=-1)
        if self.dim % 2 == 1:
            # sin/cos halves cover only 2 * (dim // 2) columns; pad to `dim`.
            emb = F.pad(emb, (0, 1))
        return emb
class ConditionalResBlock1D(nn.Module):
    """1-D residual conv block whose first normalized activation is FiLM-modulated.

    The conditioning vector is projected to a per-channel ``(scale, shift)``
    pair and applied as ``h * (1 + scale) + shift`` right after the first
    GroupNorm. When the channel counts differ, a 1x1 conv maps the input to
    ``out_channels`` for the residual path.

    Args:
        in_channels: Channels of the input ``x``.
        out_channels: Output channels. Also the GroupNorm channel count with
            8 groups, so it must be divisible by 8.
        cond_dim: Width of the conditioning vector.
        kernel_size: Conv kernel size; padding is ``kernel_size // 2``.
    """

    def __init__(self, in_channels, out_channels, cond_dim, kernel_size=3):
        super().__init__()
        pad = kernel_size // 2
        # Keep submodule creation order unchanged so that seeded initialization
        # and state_dict keys match the training code.
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size, padding=pad)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size, padding=pad)
        self.norm1 = nn.GroupNorm(8, out_channels)
        self.norm2 = nn.GroupNorm(8, out_channels)
        self.cond_proj = nn.Linear(cond_dim, out_channels * 2)
        if in_channels == out_channels:
            self.residual_conv = nn.Identity()
        else:
            self.residual_conv = nn.Conv1d(in_channels, out_channels, 1)

    def forward(self, x, cond):
        """Apply the block.

        Args:
            x: Feature map, assumed ``(batch, in_channels, length)`` as
                consumed by ``nn.Conv1d``.
            cond: Conditioning tensor, assumed ``(batch, cond_dim)``. Its
                projection is broadcast over the length axis.
        """
        film = self.cond_proj(cond)
        scale, shift = film.chunk(2, dim=-1)
        hidden = self.norm1(self.conv1(x))
        hidden = F.mish(hidden * (1 + scale.unsqueeze(-1)) + shift.unsqueeze(-1))
        hidden = F.mish(self.norm2(self.conv2(hidden)))
        return hidden + self.residual_conv(x)
class TemporalUNet1D(nn.Module):
    """Conditional 1-D U-Net that predicts diffusion noise for an action sequence.

    The network is conditioned on a diffusion timestep and a flattened window
    of observations. Both are embedded to ``cond_dim`` and fused into one
    conditioning vector that is passed to every ``ConditionalResBlock1D``.

    Args:
        action_dim: Per-step action width (input and output channels).
        obs_dim: Per-step observation width.
        obs_horizon: Number of observation steps flattened into the condition.
        base_channels: Channel width produced by the input projection.
        channel_mults: Per-level channel multipliers for the encoder; the
            decoder mirrors them in reverse.
        cond_dim: Width of the timestep/observation conditioning vector.

    Note:
        Submodule names and construction order must stay identical to the
        training code so checkpoints load with ``load_state_dict``.
    """

    def __init__(self, action_dim, obs_dim, obs_horizon,
                 base_channels=256, channel_mults=(1, 2, 4), cond_dim=256):
        super().__init__()
        self.action_dim = action_dim
        self.obs_dim = obs_dim
        self.obs_horizon = obs_horizon
        # Timestep -> cond_dim embedding.
        self.time_embed = nn.Sequential(
            SinusoidalPosEmb(cond_dim),
            nn.Linear(cond_dim, cond_dim), nn.Mish(),
            nn.Linear(cond_dim, cond_dim),
        )
        # Flattened observation window -> cond_dim embedding.
        self.obs_encoder = nn.Sequential(
            nn.Linear(obs_horizon * obs_dim, cond_dim), nn.Mish(),
            nn.Linear(cond_dim, cond_dim), nn.Mish(),
            nn.Linear(cond_dim, cond_dim),
        )
        # Fuse the concatenated [time, obs] embeddings into one vector.
        self.cond_proj = nn.Sequential(
            nn.Linear(cond_dim * 2, cond_dim), nn.Mish(),
        )
        self.input_proj = nn.Conv1d(action_dim, base_channels, 1)

        self.encoder_blocks = nn.ModuleList()
        self.downsamples = nn.ModuleList()
        ch = base_channels
        for mult in channel_mults:
            out_ch = base_channels * mult
            self.encoder_blocks.append(nn.ModuleList([
                ConditionalResBlock1D(ch, out_ch, cond_dim),
                ConditionalResBlock1D(out_ch, out_ch, cond_dim),
            ]))
            # kernel 3 / stride 2 / padding 1: length L -> ceil(L / 2).
            self.downsamples.append(nn.Conv1d(out_ch, out_ch, 3, stride=2, padding=1))
            ch = out_ch

        self.mid_block1 = ConditionalResBlock1D(ch, ch, cond_dim)
        self.mid_block2 = ConditionalResBlock1D(ch, ch, cond_dim)

        self.decoder_blocks = nn.ModuleList()
        self.upsamples = nn.ModuleList()
        for mult in reversed(channel_mults):
            out_ch = base_channels * mult
            # kernel 4 / stride 2 / padding 1: length L -> 2 * L.
            self.upsamples.append(nn.ConvTranspose1d(ch, ch, 4, stride=2, padding=1))
            # Input = upsampled features (ch) concatenated with the matching
            # encoder skip, which has out_ch channels.
            self.decoder_blocks.append(nn.ModuleList([
                ConditionalResBlock1D(ch + out_ch, out_ch, cond_dim),
                ConditionalResBlock1D(out_ch, out_ch, cond_dim),
            ]))
            ch = out_ch

        # After the decoder, `ch` equals the width of the last decoder block
        # (base_channels * channel_mults[0] when channel_mults is non-empty).
        # Sizing the head with `ch` rather than base_channels keeps it correct
        # when channel_mults[0] != 1.
        self.output_proj = nn.Sequential(
            nn.GroupNorm(8, ch), nn.Mish(),
            nn.Conv1d(ch, action_dim, 1),
        )

    def forward(self, noisy_actions, timestep, obs_cond):
        """Predict the noise contained in ``noisy_actions``.

        Args:
            noisy_actions: ``(batch, horizon, action_dim)``. It is permuted to
                channels-first for the 1-D convolutions.
            timestep: Diffusion timesteps of shape ``(batch,)``. The time
                embedding is concatenated with the ``(batch, cond_dim)``
                observation embedding.
            obs_cond: Observation window with ``obs_horizon * obs_dim``
                elements per batch item; it is flattened before encoding.

        Returns:
            Tensor in the same ``(batch, horizon, action_dim)`` layout.
        """
        batch_size = noisy_actions.shape[0]
        t_emb = self.time_embed(timestep)
        obs_flat = obs_cond.reshape(batch_size, -1)
        obs_emb = self.obs_encoder(obs_flat)
        cond = self.cond_proj(torch.cat([t_emb, obs_emb], dim=-1))

        x = noisy_actions.permute(0, 2, 1)
        x = self.input_proj(x)

        skip_connections = []
        for (res1, res2), downsample in zip(self.encoder_blocks, self.downsamples):
            x = res1(x, cond)
            x = res2(x, cond)
            skip_connections.append(x)
            x = downsample(x)

        x = self.mid_block1(x, cond)
        x = self.mid_block2(x, cond)

        for (res1, res2), upsample in zip(self.decoder_blocks, self.upsamples):
            x = upsample(x)
            skip = skip_connections.pop()
            # Odd lengths round up on the way down (ceil), so the 2x upsample
            # can overshoot by one step; crop to the skip's length.
            if x.shape[-1] != skip.shape[-1]:
                x = x[:, :, :skip.shape[-1]]
            x = torch.cat([x, skip], dim=1)
            x = res1(x, cond)
            x = res2(x, cond)

        x = self.output_proj(x)
        return x.permute(0, 2, 1)
# ═══════════════════════════════════════════════════════════
# DDIM SAMPLER
# ═══════════════════════════════════════════════════════════
def cosine_beta_schedule(num_steps, s=0.008):
steps = torch.arange(num_steps + 1, dtype=torch.float64)
alphas_cumprod = torch.cos((steps / num_steps + s) / (1 + s) * math.pi / 2) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
return torch.clamp(betas, 0.0001, 0.999).float()
class DDIMSampler:
def __init__(self, num_train_steps, num_inference_steps, device):
betas = cosine_beta_schedule(num_train_steps)
alphas = 1.0 - betas
self.alphas_cumprod = torch.cumprod(alphas, dim=0).to(device)
step_ratio = num_train_steps // num_inference_steps
self.timesteps = (torch.arange(0, num_inference_steps) * step_ratio).long().to(device)
self.device = device
@torch.no_grad()
def sample(self, model, obs_cond, shape):
x = torch.randn(shape, device=self.device)
timesteps = self.timesteps.flip(0)
for i, t in enumerate(timesteps):
t_batch = t.expand(shape[0])
noise_pred = model(x, t_batch, obs_cond)
alpha_t = self.alphas_cumprod[t].view(-1, 1, 1)
t_prev = timesteps[i + 1] if i + 1 < len(timesteps) else torch.tensor(-1, device=self.device)
alpha_prev = self.alphas_cumprod[t_prev].view(-1, 1, 1) if t_prev >= 0 else torch.ones_like(alpha_t)
x_0_pred = (x - torch.sqrt(1 - alpha_t) * noise_pred) / torch.sqrt(alpha_t)
x_0_pred = torch.clamp(x_0_pred, -1, 1)
x = torch.sqrt(alpha_prev) * x_0_pred + torch.sqrt(1 - alpha_prev) * noise_pred
return x
# ═══════════════════════════════════════════════════════════
# INFERENCE API
# ═══════════════════════════════════════════════════════════
class DiffusionPolicyInference:
    """
    Clean inference API for the EXOKERN Diffusion Policy.

    Each call to ``get_actions`` samples a chunk of ``pred_horizon`` actions
    and returns the first ``action_horizon`` of them. The returned chunk minus
    its first action is also stored in an internal buffer that ``pop_action``
    drains, so the policy can be driven one environment step at a time.

    Usage (step-wise, via the internal buffer):
        policy = DiffusionPolicyInference("best_model.pt", device="cuda")
        policy.reset()
        # Each timestep:
        policy.add_observation(obs_vector)  # numpy or torch, shape (obs_dim,)
        if policy.needs_new_actions():
            action = policy.get_actions()[0]
        else:
            action = policy.pop_action()
        env.step(action)

    Usage (open-loop chunks, ignoring the buffer):
        policy.add_observation(obs_vector)
        for action in policy.get_actions():
            env.step(action)
    """

    def __init__(self, checkpoint_path, device="cuda"):
        """Load a checkpoint and build the model, sampler, and buffers.

        Args:
            checkpoint_path: Path to a ``.pt`` checkpoint.
                - Required keys: ``obs_dim``, ``action_dim``, ``condition``,
                  ``stats`` and ``model_state_dict``.
                - Optional keys: ``args`` (hyperparameters) and ``val_loss``.
            device: Torch device string for the model and all tensors.
        """
        self.device = torch.device(device)
        # Keep the try minimal: only a missing helper module should trigger
        # the fallback, not an ImportError raised while loading.
        try:
            from safe_load import safe_load_checkpoint
        except ImportError:
            safe_load_checkpoint = None
        if safe_load_checkpoint is not None:
            ckpt = safe_load_checkpoint(checkpoint_path, device=str(self.device))
        else:
            # SECURITY: weights_only=False unpickles arbitrary Python objects
            # and can execute code; only load checkpoints from trusted sources.
            ckpt = torch.load(checkpoint_path, map_location=self.device, weights_only=False)

        # Extract config
        self.obs_dim = ckpt["obs_dim"]
        self.action_dim = ckpt["action_dim"]
        self.condition = ckpt["condition"]
        self.stats = ckpt["stats"]
        args = ckpt.get("args", {})
        self.obs_horizon = args.get("obs_horizon", 10)
        self.pred_horizon = args.get("pred_horizon", 16)
        self.action_horizon = args.get("action_horizon", 8)

        # Build model
        self.model = TemporalUNet1D(
            action_dim=self.action_dim,
            obs_dim=self.obs_dim,
            obs_horizon=self.obs_horizon,
            base_channels=args.get("base_channels", 256),
            channel_mults=(1, 2, 4),
            cond_dim=args.get("cond_dim", 256),
        ).to(self.device)
        self.model.load_state_dict(ckpt["model_state_dict"])
        self.model.eval()

        # DDIM sampler
        self.sampler = DDIMSampler(
            num_train_steps=args.get("num_diffusion_steps", 100),
            num_inference_steps=args.get("num_inference_steps", 16),
            device=self.device,
        )

        # Min-max normalization statistics, moved to the policy device once.
        self.obs_min = torch.tensor(self.stats["obs_min"], dtype=torch.float32, device=self.device)
        self.obs_range = torch.tensor(self.stats["obs_range"], dtype=torch.float32, device=self.device)
        self.action_min = torch.tensor(self.stats["action_min"], dtype=torch.float32, device=self.device)
        self.action_range = torch.tensor(self.stats["action_range"], dtype=torch.float32, device=self.device)

        # Rolling window of the last obs_horizon normalized observations.
        self.obs_window = deque(maxlen=self.obs_horizon)
        self.action_buffer = []

        print(f"Loaded EXOKERN Diffusion Policy: {self.condition}")
        print(f" obs_dim={self.obs_dim}, action_dim={self.action_dim}")
        val_loss = ckpt.get("val_loss")
        if val_loss is not None:
            print(f" val_loss={val_loss:.6f}")
        else:
            # Checkpoints without a recorded val_loss used to crash here.
            print(" val_loss=n/a")

    def _normalize_obs(self, obs):
        """Map raw observations to [-1, 1] using the checkpoint's min/range stats."""
        return 2.0 * (obs - self.obs_min) / self.obs_range - 1.0

    def _denormalize_action(self, action_norm):
        """Invert the [-1, 1] action normalization back to raw action units."""
        return (action_norm + 1.0) / 2.0 * self.action_range + self.action_min

    def add_observation(self, obs):
        """Normalize and append one observation. Call this every timestep.

        Args:
            obs: numpy array, torch tensor, or sequence of floats with shape
                ``(obs_dim,)`` or ``(1, obs_dim)``. Converted to a float32
                tensor on the policy's device; torch tensors are converted too,
                so inputs on another device or in float64 are accepted.

        Raises:
            ValueError: If ``obs`` does not have shape ``(obs_dim,)`` or
                ``(1, obs_dim)``.
        """
        obs = torch.as_tensor(obs, dtype=torch.float32, device=self.device)
        if obs.dim() == 1:
            obs = obs.unsqueeze(0)
        if obs.shape != (1, self.obs_dim):
            raise ValueError(
                f"expected observation of shape ({self.obs_dim},) or "
                f"(1, {self.obs_dim}), got {tuple(obs.shape)}"
            )
        self.obs_window.append(self._normalize_obs(obs))

    def needs_new_actions(self):
        """Returns True if the action buffer is empty and new actions should be generated."""
        return len(self.action_buffer) == 0

    @torch.no_grad()
    def get_actions(self):
        """
        Generate new actions via DDIM sampling.

        If fewer than ``obs_horizon`` observations have been added, the window
        is left-padded in place with copies of its oldest observation.

        Returns:
            List of numpy arrays, each shape (action_dim,): the first
            ``action_horizon`` actions of the sampled chunk, in execution order.
            All but the first are also stored in the buffer for ``pop_action``.

        Raises:
            RuntimeError: If no observation has been added since the last reset.
        """
        if not self.obs_window:
            raise RuntimeError(
                "get_actions() requires at least one observation; "
                "call add_observation() first"
            )
        # Pad observation window if not full
        while len(self.obs_window) < self.obs_horizon:
            self.obs_window.appendleft(self.obs_window[0])

        # Each entry is (1, obs_dim) -> (1, obs_horizon, obs_dim)
        obs_seq = torch.stack(list(self.obs_window), dim=1)

        # DDIM sampling
        shape = (1, self.pred_horizon, self.action_dim)
        action_traj_norm = self.sampler.sample(self.model, obs_seq, shape)

        # Denormalize and take action_horizon steps
        action_traj = self._denormalize_action(action_traj_norm[0])  # (pred_horizon, action_dim)
        actions = [a.cpu().numpy() for a in action_traj[:self.action_horizon]]
        # actions[0] is meant for the current step; the rest are served by pop_action().
        self.action_buffer = actions[1:]
        return actions

    def pop_action(self):
        """Pop and return the next action from the buffer. Returns None if empty."""
        if self.action_buffer:
            return self.action_buffer.pop(0)
        return None

    def reset(self):
        """Reset observation window and action buffer for a new episode."""
        self.obs_window.clear()
        self.action_buffer = []
if __name__ == "__main__":
    import argparse

    # Command-line smoke test: load a checkpoint, feed random observations,
    # and print one sampled action chunk.
    cli = argparse.ArgumentParser()
    cli.add_argument("checkpoint", help="Path to checkpoint .pt file")
    cli.add_argument("--device", default="cuda")
    cli_args = cli.parse_args()

    policy = DiffusionPolicyInference(cli_args.checkpoint, device=cli_args.device)

    # Quick sanity check: 20 random observations fill (and roll) the window.
    for _ in range(20):
        policy.add_observation(np.random.randn(policy.obs_dim).astype(np.float32))

    chunk = policy.get_actions()
    print(f"\nGenerated {len(chunk)} actions:")
    for idx, action in enumerate(chunk):
        print(f" Action {idx}: {action}")