"""Inference script for the Push-T diffusion policy checkpoints in this repo. Usage: pip install torch diffusers pygame pymunk shapely scikit-image imageio gym numpy git clone https://github.com/columbia-ai-robotics/streaming_flow_policy.git # original python inference.py --ckpt original.ckpt --stats stats_sfp.npz \\ --sfp-repo streaming_flow_policy/ --n-seeds 50 # edited python inference.py --ckpt edited.pt --stats stats_sfp.npz \\ --sfp-repo streaming_flow_policy/ --n-seeds 50 --save-mp4 out.mp4 """ import argparse, sys, time import numpy as np import torch from diffusers.schedulers.scheduling_ddpm import DDPMScheduler p = argparse.ArgumentParser() p.add_argument("--ckpt", required=True) p.add_argument("--stats", required=True) p.add_argument("--sfp-repo", required=True, help="path to cloned columbia-ai-robotics/streaming_flow_policy repo") p.add_argument("--n-seeds", type=int, default=50) p.add_argument("--seed-base", type=int, default=20000) p.add_argument("--device", default="cuda:0" if torch.cuda.is_available() else "cpu") p.add_argument("--save-mp4", default="", help="optional: mp4 of first successful rollout") args = p.parse_args() sys.path.insert(0, args.sfp_repo) from streaming_flow_policy.pusht.dp_state_notebook.env import PushTEnv from streaming_flow_policy.pusht.dp_state_notebook.network import ConditionalUnet1D OBS_HORIZON, OBS_DIM, ACT_DIM, PRED_HORIZON, ACT_HORIZON = 2, 5, 2, 16, 8 N_DDPM, MAX_STEPS, SUCCESS_TH = 100, 300, 0.95 GOAL = np.array([256.0, 256.0]) def sample_init(seed): rng = np.random.default_rng(seed) D = rng.uniform(27.0, 63.0) ang = np.deg2rad(rng.uniform(42.0, 48.0)) sd = np.array([-np.sin(ang), np.cos(ang)]) block_xy = GOAL + D * sd + rng.uniform(-3.0, 3.0, 2) agent_xy = GOAL - D * sd + rng.uniform(-15.0, 15.0, 2) return np.array([agent_xy[0], agent_xy[1], block_xy[0], block_xy[1], ang], dtype=np.float64) def cog_of_T(bx, by, bth): c, s = np.cos(bth), np.sin(bth) return np.array([bx - s * 45, by + c * 45]) def geom_mode(states_array, cap=50): cap = min(cap, len(states_array)) if cap < 5: return -1 agent = states_array[:cap, 0:2] block = states_array[:cap, 2:4] th_ = states_array[:cap, 4] cogs = np.stack([cog_of_T(block[t, 0], block[t, 1], th_[t]) for t in range(cap)]) rel = agent - cogs ang = np.unwrap(np.arctan2(rel[:, 1], rel[:, 0])) netw = np.degrees(ang - ang[0]) if abs(netw.min()) > abs(netw.max()) + 30.0: return 1 if abs(netw.max()) > abs(netw.min()) + 30.0: return 0 return -1 def main(): device = args.device stats = np.load(args.stats) obs_min, obs_max = stats["obs_min"], stats["obs_max"] act_min = stats["act_min"] if "act_min" in stats.files else stats["action_min"] act_max = stats["act_max"] if "act_max" in stats.files else stats["action_max"] print(f"Loading ckpt: {args.ckpt}", flush=True) model = ConditionalUnet1D(input_dim=ACT_DIM, global_cond_dim=OBS_DIM * OBS_HORIZON) sd = torch.load(args.ckpt, map_location=device, weights_only=False) if isinstance(sd, dict): if "ema_state_dict" in sd: sd = sd["ema_state_dict"] elif "state_dict" in sd: sd = sd["state_dict"] model.load_state_dict(sd); model.to(device).eval() for pp in model.parameters(): pp.requires_grad_(False) sched = DDPMScheduler(num_train_timesteps=N_DDPM, beta_schedule="squaredcos_cap_v2", clip_sample=True, prediction_type="epsilon") def nz(x): return (x - obs_min) / (obs_max - obs_min) * 2 - 1 seeds = list(range(args.seed_base, args.seed_base + args.n_seeds)) t0 = time.time(); n_succ = 0; modes = {0: 0, 1: 0, -1: 0} first_succ_states = None; first_succ_init = None for seed in seeds: init = sample_init(seed) env = PushTEnv(reset_to_state=init.copy()); env.seed(seed); o, _ = env.reset() hist = [o.copy()] * OBS_HORIZON states = [o.copy()]; rewards = [] chunk = None; step = 0 while step < MAX_STEPS: if step % ACT_HORIZON == 0: oc_np = np.stack(hist[-OBS_HORIZON:]) oc = torch.from_numpy(nz(oc_np).astype(np.float32)).to(device).flatten().unsqueeze(0) a = torch.randn((1, PRED_HORIZON, ACT_DIM), generator=torch.Generator(device=device).manual_seed(int(seed + step*1000 + 7919)), device=device) sched.set_timesteps(N_DDPM) with torch.no_grad(): for k in sched.timesteps: eps = model(sample=a, timestep=k, global_cond=oc) a = sched.step(model_output=eps, timestep=k, sample=a).prev_sample chunk = ((a.cpu().numpy()[0] + 1) / 2 * (act_max - act_min) + act_min) cs = step % ACT_HORIZON act = chunk[cs + OBS_HORIZON - 1] o, r, term, trunc, _ = env.step(act); hist.append(o.copy()); states.append(o.copy()); rewards.append(float(r)) step += 1 if term or trunc or r >= SUCCESS_TH: break success = (max(rewards) if rewards else 0.0) >= SUCCESS_TH states_arr = np.asarray(states) mode = geom_mode(states_arr); modes[mode] += 1 if success: n_succ += 1 if success and first_succ_states is None and args.save_mp4: first_succ_states = states_arr; first_succ_init = init mode_str = {0: "LEFT", 1: "RIGHT", -1: "mixed"}[mode] print(f" seed {seed}: {'OK' if success else 'FAIL'} mode={mode_str} steps={len(states_arr)-1}", flush=True) print(f"\n=== {args.n_seeds} seeds ({time.time()-t0:.1f}s) ===") print(f"SR @ coverage>={SUCCESS_TH}: {n_succ}/{args.n_seeds} = {100*n_succ/args.n_seeds:.1f}%") print(f"mode dist: LEFT={modes[0]} RIGHT={modes[1]} mixed={modes[-1]}") if args.save_mp4 and first_succ_states is not None: import imageio.v2 as iio env = PushTEnv(reset_to_state=first_succ_init.copy(), render_size=512) env.seed(0); env.reset() frames = [] for s in first_succ_states: env._set_state(s); frames.append(env.render("rgb_array").copy()) iio.mimsave(args.save_mp4, frames, fps=10) print(f"saved {args.save_mp4} ({len(frames)} frames)") if __name__ == "__main__": main()