#!/usr/bin/env python3
"""
EXOKERN Skill v0 — Diffusion Policy Inference
===============================================
Standalone inference script for the EXOKERN Peg Insertion Diffusion Policy.
Loads a trained checkpoint and provides a clean API for action generation.

Usage:
    from inference import DiffusionPolicyInference
    policy = DiffusionPolicyInference("full_ft_best_model.pt", device="cuda")
    policy.add_observation(obs)       # call each timestep
    actions = policy.get_actions()    # returns action chunk
"""

import math
from collections import deque, OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


# ═══════════════════════════════════════════════════════════
# MODEL (identical to training — self-contained)
# ═══════════════════════════════════════════════════════════

class SinusoidalPosEmb(nn.Module):
    """Classic transformer-style sinusoidal embedding for diffusion timesteps.

    Maps a batch of scalar timesteps ``t`` of shape ``(B,)`` to embeddings of
    shape ``(B, dim)`` — the first half of each vector holds sines, the second
    half cosines, over log-spaced frequencies.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        half = self.dim // 2
        # Log-spaced frequency scale; identical to the standard transformer recipe.
        log_scale = math.log(10000) / (half - 1)
        freqs = torch.exp(-log_scale * torch.arange(half, device=t.device))
        # (B, 1) * (1, half) -> (B, half) phase matrix.
        phase = t.float().unsqueeze(-1) * freqs.unsqueeze(0)
        return torch.cat([phase.sin(), phase.cos()], dim=-1)


class ConditionalResBlock1D(nn.Module):
    """1-D residual conv block with FiLM-style conditioning.

    The conditioning vector is projected to per-channel (scale, shift) pairs
    and applied after the first normalization: ``h * (1 + scale) + shift``.
    A 1x1 conv adapts the residual path when channel counts differ.
    """

    def __init__(self, in_channels, out_channels, cond_dim, kernel_size=3):
        super().__init__()
        pad = kernel_size // 2
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size, padding=pad)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size, padding=pad)
        self.norm1 = nn.GroupNorm(8, out_channels)
        self.norm2 = nn.GroupNorm(8, out_channels)
        # Single linear produces both scale and shift (chunked in forward).
        self.cond_proj = nn.Linear(cond_dim, out_channels * 2)
        if in_channels != out_channels:
            self.residual_conv = nn.Conv1d(in_channels, out_channels, 1)
        else:
            self.residual_conv = nn.Identity()

    def forward(self, x, cond):
        h = self.norm1(self.conv1(x))
        scale, shift = self.cond_proj(cond).chunk(2, dim=-1)
        # Broadcast (B, C) conditioning across the temporal axis.
        h = h * (1 + scale.unsqueeze(-1)) + shift.unsqueeze(-1)
        h = F.mish(h)
        h = F.mish(self.norm2(self.conv2(h)))
        return h + self.residual_conv(x)
class TemporalUNet1D(nn.Module):
    """Conditional temporal U-Net denoiser over action trajectories.

    Input/output tensors are ``(B, pred_horizon, action_dim)``; internally the
    network operates channels-first ``(B, C, T)``. Conditioning is the
    concatenation of a sinusoidal timestep embedding and an MLP encoding of the
    flattened observation history, fused into a single FiLM vector.
    """

    def __init__(self, action_dim, obs_dim, obs_horizon, base_channels=256,
                 channel_mults=(1, 2, 4), cond_dim=256):
        super().__init__()
        self.action_dim = action_dim
        self.obs_dim = obs_dim
        self.obs_horizon = obs_horizon

        # Diffusion timestep embedding.
        self.time_embed = nn.Sequential(
            SinusoidalPosEmb(cond_dim),
            nn.Linear(cond_dim, cond_dim),
            nn.Mish(),
            nn.Linear(cond_dim, cond_dim),
        )
        # MLP over the flattened (obs_horizon * obs_dim) observation window.
        self.obs_encoder = nn.Sequential(
            nn.Linear(obs_horizon * obs_dim, cond_dim),
            nn.Mish(),
            nn.Linear(cond_dim, cond_dim),
            nn.Mish(),
            nn.Linear(cond_dim, cond_dim),
        )
        # Fuse [time_emb, obs_emb] into the single conditioning vector that
        # every ConditionalResBlock1D consumes.
        self.cond_proj = nn.Sequential(
            nn.Linear(cond_dim * 2, cond_dim),
            nn.Mish(),
        )

        self.input_proj = nn.Conv1d(action_dim, base_channels, 1)

        # Encoder: two res-blocks + strided-conv downsample per resolution.
        self.encoder_blocks = nn.ModuleList()
        self.downsamples = nn.ModuleList()
        ch = base_channels
        for mult in channel_mults:
            out_ch = base_channels * mult
            self.encoder_blocks.append(nn.ModuleList([
                ConditionalResBlock1D(ch, out_ch, cond_dim),
                ConditionalResBlock1D(out_ch, out_ch, cond_dim),
            ]))
            self.downsamples.append(nn.Conv1d(out_ch, out_ch, 3, stride=2, padding=1))
            ch = out_ch

        self.mid_block1 = ConditionalResBlock1D(ch, ch, cond_dim)
        self.mid_block2 = ConditionalResBlock1D(ch, ch, cond_dim)

        # Decoder mirrors the encoder: transposed-conv upsample, then two
        # res-blocks; the first one also consumes the skip connection.
        self.decoder_blocks = nn.ModuleList()
        self.upsamples = nn.ModuleList()
        for mult in reversed(channel_mults):
            out_ch = base_channels * mult
            self.upsamples.append(nn.ConvTranspose1d(ch, ch, 4, stride=2, padding=1))
            self.decoder_blocks.append(nn.ModuleList([
                ConditionalResBlock1D(ch + out_ch, out_ch, cond_dim),
                ConditionalResBlock1D(out_ch, out_ch, cond_dim),
            ]))
            ch = out_ch

        self.output_proj = nn.Sequential(
            nn.GroupNorm(8, base_channels),
            nn.Mish(),
            nn.Conv1d(base_channels, action_dim, 1),
        )

    def forward(self, noisy_actions, timestep, obs_cond):
        """Predict the noise component of ``noisy_actions``.

        Args:
            noisy_actions: (B, pred_horizon, action_dim) noised trajectory.
            timestep: (B,) diffusion timesteps.
            obs_cond: (B, obs_horizon, obs_dim) normalized observation window.

        Returns:
            (B, pred_horizon, action_dim) predicted noise.
        """
        batch_size = noisy_actions.shape[0]
        t_emb = self.time_embed(timestep)
        obs_emb = self.obs_encoder(obs_cond.reshape(batch_size, -1))
        cond = self.cond_proj(torch.cat([t_emb, obs_emb], dim=-1))

        x = noisy_actions.permute(0, 2, 1)  # (B, action_dim, T)
        x = self.input_proj(x)

        skip_connections = []
        for (res1, res2), downsample in zip(self.encoder_blocks, self.downsamples):
            x = res1(x, cond)
            x = res2(x, cond)
            skip_connections.append(x)
            x = downsample(x)

        x = self.mid_block1(x, cond)
        x = self.mid_block2(x, cond)

        for (res1, res2), upsample in zip(self.decoder_blocks, self.upsamples):
            x = upsample(x)
            skip = skip_connections.pop()
            # ConvTranspose1d may overshoot odd lengths by one — trim to skip.
            if x.shape[-1] != skip.shape[-1]:
                x = x[:, :, :skip.shape[-1]]
            x = torch.cat([x, skip], dim=1)
            x = res1(x, cond)
            x = res2(x, cond)

        x = self.output_proj(x)
        return x.permute(0, 2, 1)


# ═══════════════════════════════════════════════════════════
# DDIM SAMPLER
# ═══════════════════════════════════════════════════════════

def cosine_beta_schedule(num_steps, s=0.008):
    """Cosine noise schedule (Nichol & Dhariwal) — returns (num_steps,) betas.

    Computed in float64 for numerical stability, clamped to [1e-4, 0.999],
    and returned as float32.
    """
    steps = torch.arange(num_steps + 1, dtype=torch.float64)
    alphas_cumprod = torch.cos((steps / num_steps + s) / (1 + s) * math.pi / 2) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clamp(betas, 0.0001, 0.999).float()


class DDIMSampler:
    """Deterministic DDIM sampler (eta = 0) over a cosine beta schedule."""

    def __init__(self, num_train_steps, num_inference_steps, device):
        betas = cosine_beta_schedule(num_train_steps)
        alphas = 1.0 - betas
        self.alphas_cumprod = torch.cumprod(alphas, dim=0).to(device)
        # Evenly strided subset of the training timesteps, ascending.
        step_ratio = num_train_steps // num_inference_steps
        self.timesteps = (torch.arange(0, num_inference_steps) * step_ratio).long().to(device)
        self.device = device

    @torch.no_grad()
    def sample(self, model, obs_cond, shape):
        """Run the reverse process from pure noise.

        Args:
            model: callable ``model(x, t_batch, obs_cond) -> noise_pred``.
            obs_cond: conditioning passed through to the model unchanged.
            shape: output shape, e.g. (B, pred_horizon, action_dim).

        Returns:
            Sampled trajectory in normalized [-1, 1] space.
        """
        x = torch.randn(shape, device=self.device)
        timesteps = self.timesteps.flip(0)  # descend from most-noised step
        for i, t in enumerate(timesteps):
            t_batch = t.expand(shape[0])
            noise_pred = model(x, t_batch, obs_cond)
            alpha_t = self.alphas_cumprod[t].view(-1, 1, 1)
            # Sentinel -1 marks "fully denoised" (alpha_prev = 1) at the end.
            t_prev = timesteps[i + 1] if i + 1 < len(timesteps) else torch.tensor(-1, device=self.device)
            alpha_prev = self.alphas_cumprod[t_prev].view(-1, 1, 1) if t_prev >= 0 else torch.ones_like(alpha_t)
            # DDIM update: recover x_0 estimate, clamp, re-noise to t_prev.
            x_0_pred = (x - torch.sqrt(1 - alpha_t) * noise_pred) / torch.sqrt(alpha_t)
            x_0_pred = torch.clamp(x_0_pred, -1, 1)
            x = torch.sqrt(alpha_prev) * x_0_pred + torch.sqrt(1 - alpha_prev) * noise_pred
        return x


# ═══════════════════════════════════════════════════════════
# INFERENCE API
# ═══════════════════════════════════════════════════════════

class DiffusionPolicyInference:
    """
    Clean inference API for the EXOKERN Diffusion Policy.

    Usage:
        policy = DiffusionPolicyInference("best_model.pt", device="cuda")

        # Each timestep:
        policy.add_observation(obs_vector)  # numpy or torch, shape (obs_dim,)

        # When action buffer is empty:
        actions = policy.get_actions()  # returns list of numpy arrays
        for action in actions:
            env.step(action)
    """

    def __init__(self, checkpoint_path, device="cuda"):
        self.device = torch.device(device)

        # Prefer the project's hardened loader when available; otherwise fall
        # back to torch.load. NOTE(review): weights_only=False unpickles
        # arbitrary objects — only load checkpoints from trusted sources.
        try:
            from safe_load import safe_load_checkpoint
            ckpt = safe_load_checkpoint(checkpoint_path, device=str(self.device))
        except ImportError:
            ckpt = torch.load(checkpoint_path, map_location=self.device, weights_only=False)

        # Extract config (schema assumed from the training script — keys
        # obs_dim/action_dim/condition/stats/args/model_state_dict/val_loss).
        self.obs_dim = ckpt["obs_dim"]
        self.action_dim = ckpt["action_dim"]
        self.condition = ckpt["condition"]
        self.stats = ckpt["stats"]
        args = ckpt.get("args", {})
        self.obs_horizon = args.get("obs_horizon", 10)
        self.pred_horizon = args.get("pred_horizon", 16)
        self.action_horizon = args.get("action_horizon", 8)

        # Build model
        self.model = TemporalUNet1D(
            action_dim=self.action_dim,
            obs_dim=self.obs_dim,
            obs_horizon=self.obs_horizon,
            base_channels=args.get("base_channels", 256),
            channel_mults=(1, 2, 4),
            cond_dim=args.get("cond_dim", 256),
        ).to(self.device)
        self.model.load_state_dict(ckpt["model_state_dict"])
        self.model.eval()

        # DDIM sampler
        self.sampler = DDIMSampler(
            num_train_steps=args.get("num_diffusion_steps", 100),
            num_inference_steps=args.get("num_inference_steps", 16),
            device=self.device,
        )

        # Min/range normalization tensors (training-set statistics).
        self.obs_min = torch.tensor(self.stats["obs_min"], dtype=torch.float32, device=self.device)
        self.obs_range = torch.tensor(self.stats["obs_range"], dtype=torch.float32, device=self.device)
        self.action_min = torch.tensor(self.stats["action_min"], dtype=torch.float32, device=self.device)
        self.action_range = torch.tensor(self.stats["action_range"], dtype=torch.float32, device=self.device)

        # Observation buffer
        self.obs_window = deque(maxlen=self.obs_horizon)
        self.action_buffer = []

        print(f"Loaded EXOKERN Diffusion Policy: {self.condition}")
        print(f" obs_dim={self.obs_dim}, action_dim={self.action_dim}")
        print(f" val_loss={ckpt['val_loss']:.6f}")

    def _normalize_obs(self, obs):
        """Map raw observation into the [-1, 1] range used in training."""
        return 2.0 * (obs - self.obs_min) / self.obs_range - 1.0

    def _denormalize_action(self, action_norm):
        """Inverse of the training-time action normalization ([-1, 1] -> raw)."""
        return (action_norm + 1.0) / 2.0 * self.action_range + self.action_min

    def add_observation(self, obs):
        """Add a new observation. Call this every timestep."""
        if isinstance(obs, np.ndarray):
            obs = torch.tensor(obs, dtype=torch.float32, device=self.device)
        if obs.dim() == 1:
            obs = obs.unsqueeze(0)  # -> (1, obs_dim) so batching is uniform
        self.obs_window.append(self._normalize_obs(obs))

    def needs_new_actions(self):
        """Returns True if the action buffer is empty and new actions should be generated."""
        return len(self.action_buffer) == 0

    @torch.no_grad()
    def get_actions(self):
        """
        Generate new actions via DDIM sampling.

        Returns:
            List of numpy arrays, each shape (action_dim,). Execute them in
            order, one per timestep.

        Raises:
            RuntimeError: if called before any observation has been added.
        """
        # FIX: previously an empty window raised an opaque IndexError from
        # self.obs_window[0]; fail with a clear message instead.
        if not self.obs_window:
            raise RuntimeError(
                "Observation window is empty — call add_observation() at least "
                "once before get_actions()."
            )

        # Pad observation window if not full (repeat oldest observation).
        while len(self.obs_window) < self.obs_horizon:
            self.obs_window.appendleft(self.obs_window[0])

        # Build conditioning
        obs_seq = torch.stack(list(self.obs_window), dim=1)  # (1, obs_horizon, obs_dim)

        # DDIM sampling
        shape = (1, self.pred_horizon, self.action_dim)
        action_traj_norm = self.sampler.sample(self.model, obs_seq, shape)

        # Denormalize and take action_horizon steps
        action_traj = self._denormalize_action(action_traj_norm[0])  # (pred_horizon, action_dim)
        actions = [a.cpu().numpy() for a in action_traj[:self.action_horizon]]
        self.action_buffer = actions[1:]  # Save remaining for pop
        return actions

    def pop_action(self):
        """Pop and return the next action from the buffer. Returns None if empty."""
        if self.action_buffer:
            return self.action_buffer.pop(0)
        return None

    def reset(self):
        """Reset observation window and action buffer for a new episode."""
        self.obs_window.clear()
        self.action_buffer = []


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("checkpoint", help="Path to checkpoint .pt file")
    parser.add_argument("--device", default="cuda")
    args = parser.parse_args()

    policy = DiffusionPolicyInference(args.checkpoint, device=args.device)

    # Quick sanity check: generate actions from random observations
    for step in range(20):
        dummy_obs = np.random.randn(policy.obs_dim).astype(np.float32)
        policy.add_observation(dummy_obs)
    actions = policy.get_actions()
    print(f"\nGenerated {len(actions)} actions:")
    for i, a in enumerate(actions):
        print(f" Action {i}: {a}")