| | |
| | """ |
| | EXOKERN Skill v0 — Diffusion Policy Inference |
| | =============================================== |
| | |
| | Standalone inference script for the EXOKERN Peg Insertion Diffusion Policy. |
| | Loads a trained checkpoint and provides a clean API for action generation. |
| | |
| | Usage: |
| | from inference import DiffusionPolicyInference |
| | |
| | policy = DiffusionPolicyInference("full_ft_best_model.pt", device="cuda") |
| | policy.add_observation(obs) # call each timestep |
| | actions = policy.get_actions() # returns action chunk |
| | """ |
| |
|
| | import math |
| | from collections import deque, OrderedDict |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| |
|
| | |
| | |
| | |
| |
|
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a scalar value (used here for diffusion timesteps).

    Each input value is multiplied by ``dim // 2`` geometrically spaced
    frequencies; the sines of those products are concatenated with the
    matching cosines.

    Args:
        dim: Requested embedding width. The output has ``2 * (dim // 2)``
            features, i.e. ``dim`` for even ``dim`` and ``dim - 1`` for odd.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """Embed ``t``.

        Args:
            t: Tensor of values to embed; it is cast to float before use.

        Returns:
            For a 1-D input of length ``B``, a ``(B, 2 * (dim // 2))`` tensor
            laid out as ``[sin(t * f_0), ..., cos(t * f_0), ...]``.
        """
        half_dim = self.dim // 2
        # Frequencies run geometrically from 1 down to 1/10000. With a single
        # frequency (dim == 2) the divisor ``half_dim - 1`` would be zero;
        # clamping it to 1 is harmless because the only exponent is then 0.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=t.device) * -log_step)
        angles = t.float().unsqueeze(-1) * freqs.unsqueeze(0)
        return torch.cat([angles.sin(), angles.cos()], dim=-1)
| |
|
| |
|
class ConditionalResBlock1D(nn.Module):
    """Residual 1-D convolution block modulated by a conditioning vector.

    The input passes through conv -> GroupNorm, is scaled and shifted per
    channel by a linear projection of ``cond``, then goes through Mish,
    a second conv -> GroupNorm -> Mish, and is added to a (possibly
    1x1-projected) copy of the input.

    Args:
        in_channels: Channel count of the input feature map.
        out_channels: Channel count of the output; GroupNorm uses 8 groups,
            so this must be divisible by 8.
        cond_dim: Width of the conditioning vector.
        kernel_size: Convolution kernel size; padding is ``kernel_size // 2``.
    """

    def __init__(self, in_channels, out_channels, cond_dim, kernel_size=3):
        super().__init__()
        pad = kernel_size // 2
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size, padding=pad)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size, padding=pad)
        self.norm1 = nn.GroupNorm(8, out_channels)
        self.norm2 = nn.GroupNorm(8, out_channels)
        # One projection yields both the per-channel scale and shift.
        self.cond_proj = nn.Linear(cond_dim, out_channels * 2)
        # The skip path needs a 1x1 conv only when the channel count changes.
        if in_channels == out_channels:
            self.residual_conv = nn.Identity()
        else:
            self.residual_conv = nn.Conv1d(in_channels, out_channels, 1)

    def forward(self, x, cond):
        """Apply the block.

        Args:
            x: Feature map of shape ``(batch, in_channels, length)``.
            cond: Conditioning tensor whose last axis has size ``cond_dim``;
                its projection is broadcast over the length axis.

        Returns:
            Tensor of shape ``(batch, out_channels, length)`` for odd
            ``kernel_size``.
        """
        scale, shift = self.cond_proj(cond).chunk(2, dim=-1)
        hidden = self.norm1(self.conv1(x))
        hidden = F.mish(hidden * (1 + scale.unsqueeze(-1)) + shift.unsqueeze(-1))
        hidden = F.mish(self.norm2(self.conv2(hidden)))
        return hidden + self.residual_conv(x)
| |
|
| |
|
class TemporalUNet1D(nn.Module):
    """Conditional 1-D U-Net that predicts noise for an action sequence.

    The network is conditioned on a diffusion timestep and on a flattened
    window of observations. Both are embedded to ``cond_dim`` features,
    concatenated, fused by ``cond_proj`` and fed to every residual block.

    Args:
        action_dim: Number of action features (input and output channels).
        obs_dim: Number of features in one observation.
        obs_horizon: Number of observations in the conditioning window.
        base_channels: Channel width of the first level; GroupNorm uses
            8 groups, so every level width must be divisible by 8.
        channel_mults: Per-level multipliers applied to ``base_channels``.
        cond_dim: Width of the timestep/observation conditioning vector.
    """

    def __init__(self, action_dim, obs_dim, obs_horizon,
                 base_channels=256, channel_mults=(1, 2, 4), cond_dim=256):
        super().__init__()
        self.action_dim = action_dim
        self.obs_dim = obs_dim
        self.obs_horizon = obs_horizon

        self.time_embed = nn.Sequential(
            SinusoidalPosEmb(cond_dim),
            nn.Linear(cond_dim, cond_dim), nn.Mish(),
            nn.Linear(cond_dim, cond_dim),
        )
        self.obs_encoder = nn.Sequential(
            nn.Linear(obs_horizon * obs_dim, cond_dim), nn.Mish(),
            nn.Linear(cond_dim, cond_dim), nn.Mish(),
            nn.Linear(cond_dim, cond_dim),
        )
        self.cond_proj = nn.Sequential(
            nn.Linear(cond_dim * 2, cond_dim), nn.Mish(),
        )
        self.input_proj = nn.Conv1d(action_dim, base_channels, 1)

        # Encoder: two residual blocks per level, then a stride-2 conv that
        # halves the sequence length (also after the last level).
        self.encoder_blocks = nn.ModuleList()
        self.downsamples = nn.ModuleList()
        ch = base_channels
        for mult in channel_mults:
            out_ch = base_channels * mult
            self.encoder_blocks.append(nn.ModuleList([
                ConditionalResBlock1D(ch, out_ch, cond_dim),
                ConditionalResBlock1D(out_ch, out_ch, cond_dim),
            ]))
            self.downsamples.append(nn.Conv1d(out_ch, out_ch, 3, stride=2, padding=1))
            ch = out_ch

        self.mid_block1 = ConditionalResBlock1D(ch, ch, cond_dim)
        self.mid_block2 = ConditionalResBlock1D(ch, ch, cond_dim)

        # Decoder mirrors the encoder: upsample, concatenate the skip of the
        # matching level (``out_ch`` channels), then two residual blocks.
        self.decoder_blocks = nn.ModuleList()
        self.upsamples = nn.ModuleList()
        for mult in reversed(channel_mults):
            out_ch = base_channels * mult
            self.upsamples.append(nn.ConvTranspose1d(ch, ch, 4, stride=2, padding=1))
            self.decoder_blocks.append(nn.ModuleList([
                ConditionalResBlock1D(ch + out_ch, out_ch, cond_dim),
                ConditionalResBlock1D(out_ch, out_ch, cond_dim),
            ]))
            ch = out_ch

        self.output_proj = nn.Sequential(
            nn.GroupNorm(8, base_channels), nn.Mish(),
            nn.Conv1d(base_channels, action_dim, 1),
        )

    def forward(self, noisy_actions, timestep, obs_cond):
        """Predict the noise contained in ``noisy_actions``.

        Args:
            noisy_actions: Tensor of shape ``(batch, horizon, action_dim)``.
            timestep: Diffusion timestep per batch element, shape ``(batch,)``.
            obs_cond: Observation window; it is flattened per batch element
                and must hold ``obs_horizon * obs_dim`` values.

        Returns:
            Tensor of shape ``(batch, horizon, action_dim)``.
        """
        batch_size = noisy_actions.shape[0]
        t_emb = self.time_embed(timestep)
        obs_flat = obs_cond.reshape(batch_size, -1)
        obs_emb = self.obs_encoder(obs_flat)
        cond = self.cond_proj(torch.cat([t_emb, obs_emb], dim=-1))

        # Conv1d expects channels first: (batch, action_dim, horizon).
        x = noisy_actions.permute(0, 2, 1)
        x = self.input_proj(x)

        skip_connections = []
        for (res1, res2), downsample in zip(self.encoder_blocks, self.downsamples):
            x = res1(x, cond)
            x = res2(x, cond)
            skip_connections.append(x)
            x = downsample(x)

        x = self.mid_block1(x, cond)
        x = self.mid_block2(x, cond)

        for (res1, res2), upsample in zip(self.decoder_blocks, self.upsamples):
            x = upsample(x)
            skip = skip_connections.pop()
            # The transposed conv can overshoot the skip length when the
            # encoder input length was odd; crop the surplus samples.
            if x.shape[-1] != skip.shape[-1]:
                x = x[:, :, :skip.shape[-1]]
            x = torch.cat([x, skip], dim=1)
            x = res1(x, cond)
            x = res2(x, cond)

        x = self.output_proj(x)
        return x.permute(0, 2, 1)
| |
|
| |
|
| | |
| | |
| | |
| |
|
def cosine_beta_schedule(num_steps, s=0.008):
    """Return the per-step betas of a squared-cosine noise schedule.

    The cumulative alpha product follows ``cos^2`` of a linearly growing
    phase (offset by ``s``), normalised to 1 at step 0. Betas are derived
    from the ratio of consecutive cumulative products.

    Args:
        num_steps: Number of diffusion steps.
        s: Phase offset of the schedule.

    Returns:
        float32 tensor of shape ``(num_steps,)`` clamped to [0.0001, 0.999].
    """
    # Computed in float64 and cast at the end.
    grid = torch.arange(num_steps + 1, dtype=torch.float64)
    phase = (grid / num_steps + s) / (1 + s) * math.pi / 2
    alpha_bar = torch.cos(phase) ** 2
    alpha_bar = alpha_bar / alpha_bar[0]
    step_ratio = alpha_bar[1:] / alpha_bar[:-1]
    return torch.clamp(1 - step_ratio, 0.0001, 0.999).float()
| |
|
| |
|
class DDIMSampler:
    """Deterministic DDIM-style sampler over a cosine noise schedule.

    Args:
        num_train_steps: Number of diffusion steps the schedule is built for.
        num_inference_steps: Number of denoising steps to run; must satisfy
            ``1 <= num_inference_steps <= num_train_steps``.
        device: Device on which the schedule tensors and samples live.

    Raises:
        ValueError: If ``num_inference_steps`` is outside the valid range.
    """

    def __init__(self, num_train_steps, num_inference_steps, device):
        # Checked up front: 0 inference steps would divide by zero below, and
        # more inference than training steps gives a step ratio of 0, which
        # collapses every timestep to 0 and silently yields garbage samples.
        if not 1 <= num_inference_steps <= num_train_steps:
            raise ValueError(
                "num_inference_steps must be between 1 and num_train_steps "
                f"({num_train_steps}), got {num_inference_steps}"
            )
        betas = cosine_beta_schedule(num_train_steps)
        alphas = 1.0 - betas
        self.alphas_cumprod = torch.cumprod(alphas, dim=0).to(device)
        # Evenly spaced timesteps starting at 0, in ascending order.
        step_ratio = num_train_steps // num_inference_steps
        self.timesteps = (torch.arange(0, num_inference_steps) * step_ratio).long().to(device)
        self.device = device

    @torch.no_grad()
    def sample(self, model, obs_cond, shape):
        """Denoise Gaussian noise into a sample.

        Args:
            model: Callable ``model(x, t_batch, obs_cond)`` returning the
                predicted noise with the same shape as ``x``.
            obs_cond: Conditioning passed through to ``model`` unchanged.
            shape: Shape of the sample; ``shape[0]`` is the batch size.

        Returns:
            Tensor of the requested shape on ``self.device``.
        """
        x = torch.randn(shape, device=self.device)
        timesteps = self.timesteps.flip(0)  # iterate from noisiest to cleanest
        num_steps = len(timesteps)
        for i, t in enumerate(timesteps):
            t_batch = t.expand(shape[0])
            noise_pred = model(x, t_batch, obs_cond)
            alpha_t = self.alphas_cumprod[t].view(-1, 1, 1)
            if i + 1 < num_steps:
                alpha_prev = self.alphas_cumprod[timesteps[i + 1]].view(-1, 1, 1)
            else:
                # After the last step the sample is treated as noise-free.
                alpha_prev = torch.ones_like(alpha_t)
            x_0_pred = (x - torch.sqrt(1 - alpha_t) * noise_pred) / torch.sqrt(alpha_t)
            x_0_pred = torch.clamp(x_0_pred, -1, 1)
            # NOTE(review): noise_pred is reused as-is after clamping x_0_pred;
            # some DDIM implementations re-derive the noise from the clipped
            # estimate. Left unchanged so sampling matches existing behaviour.
            x = torch.sqrt(alpha_prev) * x_0_pred + torch.sqrt(1 - alpha_prev) * noise_pred
        return x
| |
|
| |
|
| | |
| | |
| | |
| |
|
class DiffusionPolicyInference:
    """
    Clean inference API for the EXOKERN Diffusion Policy.

    Usage:
        policy = DiffusionPolicyInference("best_model.pt", device="cuda")

        # Each timestep:
        policy.add_observation(obs_vector)  # numpy or torch, shape (obs_dim,)

        # When action buffer is empty:
        actions = policy.get_actions()  # returns list of numpy arrays
        for action in actions:
            env.step(action)

    ``get_actions()`` returns the whole action chunk and also stores every
    action except the first in ``action_buffer``. Callers that execute only
    the first returned action can fetch the rest with ``pop_action()``; do
    not both iterate the returned list and drain ``pop_action()``, or the
    same actions are executed twice.
    """

    def __init__(self, checkpoint_path, device="cuda"):
        """Load a checkpoint and build the model, sampler and buffers.

        Args:
            checkpoint_path: Path to the checkpoint file. It must contain
                ``obs_dim``, ``action_dim``, ``condition``, ``stats`` and
                ``model_state_dict``; ``args`` and ``val_loss`` are optional.
            device: Torch device string used for the model and all tensors.
        """
        self.device = torch.device(device)
        try:
            from safe_load import safe_load_checkpoint
            ckpt = safe_load_checkpoint(checkpoint_path, device=str(self.device))
        except ImportError:
            # SECURITY: weights_only=False unpickles arbitrary objects and can
            # execute code embedded in the file. Only load checkpoints from a
            # trusted source when the safe_load helper is unavailable.
            ckpt = torch.load(checkpoint_path, map_location=self.device, weights_only=False)

        # Metadata stored alongside the weights.
        self.obs_dim = ckpt["obs_dim"]
        self.action_dim = ckpt["action_dim"]
        self.condition = ckpt["condition"]
        self.stats = ckpt["stats"]
        args = ckpt.get("args", {})

        self.obs_horizon = args.get("obs_horizon", 10)
        self.pred_horizon = args.get("pred_horizon", 16)
        self.action_horizon = args.get("action_horizon", 8)

        # Rebuild the network with the training hyper-parameters.
        self.model = TemporalUNet1D(
            action_dim=self.action_dim,
            obs_dim=self.obs_dim,
            obs_horizon=self.obs_horizon,
            base_channels=args.get("base_channels", 256),
            channel_mults=(1, 2, 4),
            cond_dim=args.get("cond_dim", 256),
        ).to(self.device)
        self.model.load_state_dict(ckpt["model_state_dict"])
        self.model.eval()

        self.sampler = DDIMSampler(
            num_train_steps=args.get("num_diffusion_steps", 100),
            num_inference_steps=args.get("num_inference_steps", 16),
            device=self.device,
        )

        # Min/range statistics used to map observations into [-1, 1] and
        # actions back out of it.
        self.obs_min = torch.tensor(self.stats["obs_min"], dtype=torch.float32, device=self.device)
        self.obs_range = torch.tensor(self.stats["obs_range"], dtype=torch.float32, device=self.device)
        self.action_min = torch.tensor(self.stats["action_min"], dtype=torch.float32, device=self.device)
        self.action_range = torch.tensor(self.stats["action_range"], dtype=torch.float32, device=self.device)

        # Sliding window of normalized observations and pending actions.
        self.obs_window = deque(maxlen=self.obs_horizon)
        self.action_buffer = []

        print(f"Loaded EXOKERN Diffusion Policy: {self.condition}")
        print(f" obs_dim={self.obs_dim}, action_dim={self.action_dim}")
        # val_loss is purely informational; a checkpoint without it should
        # still load.
        val_loss = ckpt.get("val_loss")
        if val_loss is not None:
            print(f" val_loss={val_loss:.6f}")

    def _normalize_obs(self, obs):
        """Map raw observations to [-1, 1] using the stored min/range."""
        return 2.0 * (obs - self.obs_min) / self.obs_range - 1.0

    def _denormalize_action(self, action_norm):
        """Map normalized actions in [-1, 1] back to the raw action range."""
        return (action_norm + 1.0) / 2.0 * self.action_range + self.action_min

    def add_observation(self, obs):
        """Add a new observation. Call this every timestep.

        Args:
            obs: Observation as a numpy array, torch tensor or sequence of
                shape ``(obs_dim,)``. It is converted to a float32 tensor on
                ``self.device``, so tensors from another device or dtype are
                accepted as well.
        """
        obs = torch.as_tensor(obs, dtype=torch.float32, device=self.device)
        if obs.dim() == 1:
            obs = obs.unsqueeze(0)  # add the batch axis: (1, obs_dim)
        self.obs_window.append(self._normalize_obs(obs))

    def needs_new_actions(self):
        """Returns True if the action buffer is empty and new actions should be generated."""
        return len(self.action_buffer) == 0

    @torch.no_grad()
    def get_actions(self):
        """
        Generate new actions via DDIM sampling.

        Returns:
            List of numpy arrays, each shape (action_dim,).
            Execute them in order, one per timestep.

        Raises:
            IndexError: If no observation has been added yet.
        """
        if not self.obs_window:
            raise IndexError(
                "get_actions() called before any observation was added; "
                "call add_observation() first"
            )

        # Until the window is full, pad the front with the oldest observation.
        while len(self.obs_window) < self.obs_horizon:
            self.obs_window.appendleft(self.obs_window[0])

        # (1, obs_horizon, obs_dim)
        obs_seq = torch.stack(list(self.obs_window), dim=1)

        shape = (1, self.pred_horizon, self.action_dim)
        action_traj_norm = self.sampler.sample(self.model, obs_seq, shape)

        # Keep only the first action_horizon steps of the predicted chunk.
        action_traj = self._denormalize_action(action_traj_norm[0])
        actions = [a.cpu().numpy() for a in action_traj[:self.action_horizon]]

        # The first action is assumed to be executed by the caller right away.
        self.action_buffer = actions[1:]
        return actions

    def pop_action(self):
        """Pop and return the next action from the buffer. Returns None if empty."""
        if self.action_buffer:
            return self.action_buffer.pop(0)
        return None

    def reset(self):
        """Reset observation window and action buffer for a new episode."""
        self.obs_window.clear()
        self.action_buffer = []
|
| |
|
if __name__ == "__main__":
    # Smoke test: load a checkpoint, feed random observations and print one
    # generated action chunk.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("checkpoint", help="Path to checkpoint .pt file")
    cli.add_argument("--device", default="cuda")
    args = cli.parse_args()

    policy = DiffusionPolicyInference(args.checkpoint, device=args.device)

    # Push 20 random observations through the observation window.
    for _ in range(20):
        policy.add_observation(np.random.randn(policy.obs_dim).astype(np.float32))

    actions = policy.get_actions()
    print(f"\nGenerated {len(actions)} actions:")
    for index, action in enumerate(actions):
        print(f" Action {index}: {action}")
| |
|