| """Inference script for the Push-T diffusion policy checkpoints in this repo. |
| |
| Usage: |
| pip install torch diffusers pygame pymunk shapely scikit-image imageio imageio-ffmpeg gym numpy |
| git clone https://github.com/columbia-ai-robotics/streaming_flow_policy.git |
| |
| # original |
| python inference.py --ckpt original.ckpt --stats stats_sfp.npz \\ |
| --sfp-repo streaming_flow_policy/ --n-seeds 50 |
| |
| # edited |
| python inference.py --ckpt edited.pt --stats stats_sfp.npz \\ |
| --sfp-repo streaming_flow_policy/ --n-seeds 50 --save-mp4 out.mp4 |
| """ |
| import argparse, sys, time |
| import numpy as np |
| import torch |
| from diffusers.schedulers.scheduling_ddpm import DDPMScheduler |
|
|
# Command-line interface. Parsed at import time because ``args`` is used as a
# module-level global (sys.path setup below, and inside main()).
p = argparse.ArgumentParser(
    description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
p.add_argument("--ckpt", required=True,
               help="checkpoint to evaluate: a raw state dict, or a dict holding "
                    "'ema_state_dict' or 'state_dict'")
p.add_argument("--stats", required=True,
               help=".npz with obs_min/obs_max and act_min/act_max "
                    "(or action_min/action_max)")
p.add_argument("--sfp-repo", required=True,
               help="path to cloned columbia-ai-robotics/streaming_flow_policy repo")
p.add_argument("--n-seeds", type=int, default=50,
               help="number of consecutive seeds to roll out (>= 1)")
p.add_argument("--seed-base", type=int, default=20000, help="first seed to roll out")
p.add_argument("--device", default="cuda:0" if torch.cuda.is_available() else "cpu")
p.add_argument("--save-mp4", default="", help="optional: mp4 of first successful rollout")
args = p.parse_args()
if args.n_seeds < 1:
    # With no seeds there is no success rate to report; fail fast with a usage error.
    p.error("--n-seeds must be >= 1")


# Make the cloned repo importable; it provides the Push-T env and the
# ConditionalUnet1D noise-prediction network used by main().
sys.path.insert(0, args.sfp_repo)
from streaming_flow_policy.pusht.dp_state_notebook.env import PushTEnv  # noqa: E402
from streaming_flow_policy.pusht.dp_state_notebook.network import ConditionalUnet1D  # noqa: E402
|
|
# Policy / rollout configuration.
OBS_HORIZON = 2      # number of most recent observations stacked as conditioning
OBS_DIM = 5          # per-observation size (agent xy, block xy, block angle as indexed in geom_mode)
ACT_DIM = 2          # per-step action size
PRED_HORIZON = 16    # length of each denoised action sequence
ACT_HORIZON = 8      # env steps executed between re-plans
N_DDPM = 100         # DDPM training / inference timesteps
MAX_STEPS = 300      # per-episode step cap
SUCCESS_TH = 0.95    # reward at or above which an episode counts as a success
GOAL = np.array([256.0, 256.0])  # point around which sample_init places agent and block
|
|
|
|
def sample_init(seed):
    """Draw a reproducible Push-T initial state for *seed*.

    A distance ``dist ~ U(27, 63)`` and an angle ``theta ~ U(42, 48)`` degrees
    are drawn. The block is placed at ``GOAL + dist * (-sin theta, cos theta)``
    plus ``U(-3, 3)`` jitter per axis, and the agent at the mirrored point
    ``GOAL - dist * (...)`` plus ``U(-15, 15)`` jitter per axis.

    Returns:
        float64 array ``[agent_x, agent_y, block_x, block_y, theta]`` with
        ``theta`` in radians.
    """
    gen = np.random.default_rng(seed)
    dist = gen.uniform(27.0, 63.0)
    theta = np.deg2rad(gen.uniform(42.0, 48.0))
    offset = dist * np.array([-np.sin(theta), np.cos(theta)])
    # Draw order matters for reproducibility: block jitter first, then agent jitter.
    block_pos = GOAL + offset + gen.uniform(-3.0, 3.0, 2)
    agent_pos = GOAL - offset + gen.uniform(-15.0, 15.0, 2)
    return np.array([*agent_pos, *block_pos, theta], dtype=np.float64)
|
|
|
|
def cog_of_T(bx, by, bth):
    """Return the point 45 units from ``(bx, by)`` along ``(-sin bth, cos bth)``.

    Per the name, this is presumably the T-block's centre of gravity given its
    pose (position plus angle in radians). TODO: confirm the 45-unit offset
    against the env's T geometry. Scalars or equal-shape arrays are accepted;
    the x and y results are stacked along a new leading axis.
    """
    arm = 45
    return np.array([bx - arm * np.sin(bth), by + arm * np.cos(bth)])
|
|
|
|
def geom_mode(states_array, cap=50):
    """Classify which way the agent swings around the T-block over an episode prefix.

    Uses the first ``min(cap, len(states_array))`` rows, where columns 0:2 are
    the agent xy, 2:4 the block xy and 4 the block angle. For each row it takes
    the direction from ``cog_of_T(block)`` to the agent, unwraps that angle,
    and measures the net rotation in degrees relative to the first row.

    Returns:
        0 if the positive swing beats the negative one by more than 30 degrees,
        1 in the opposite case, and -1 if neither dominates or fewer than
        5 rows are available. main() prints 0 as LEFT and 1 as RIGHT.
    """
    horizon = min(cap, len(states_array))
    if horizon < 5:
        return -1
    window = states_array[:horizon]
    centres = np.stack([
        cog_of_T(bx, by, bth)
        for bx, by, bth in zip(window[:, 2], window[:, 3], window[:, 4])
    ])
    rel = window[:, 0:2] - centres
    heading = np.unwrap(np.arctan2(rel[:, 1], rel[:, 0]))
    swept = np.degrees(heading - heading[0])
    pos_swing = abs(swept.max())
    neg_swing = abs(swept.min())
    if neg_swing > pos_swing + 30.0:
        return 1
    if pos_swing > neg_swing + 30.0:
        return 0
    return -1
|
|
|
|
def _load_stats(path):
    """Read min/max normalisation bounds from the ``.npz`` archive at *path*.

    Returns ``(obs_min, obs_max, act_min, act_max)``. Action bounds come from
    the ``act_min``/``act_max`` keys when present, otherwise from
    ``action_min``/``action_max``. The archive is closed before returning.
    """
    with np.load(path) as stats:
        keys = stats.files
        act_lo = "act_min" if "act_min" in keys else "action_min"
        act_hi = "act_max" if "act_max" in keys else "action_max"
        # Indexing an NpzFile reads the array into memory, so these remain valid after close.
        return stats["obs_min"], stats["obs_max"], stats[act_lo], stats[act_hi]


def _load_model(ckpt_path, device):
    """Build ConditionalUnet1D, load weights from *ckpt_path*, and freeze it in eval mode.

    The checkpoint may be a raw state dict, or a dict holding ``ema_state_dict``
    (preferred) or ``state_dict``.
    """
    model = ConditionalUnet1D(input_dim=ACT_DIM, global_cond_dim=OBS_DIM * OBS_HORIZON)
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects, so only evaluate checkpoints from a trusted source.
    sd = torch.load(ckpt_path, map_location=device, weights_only=False)
    if isinstance(sd, dict):
        if "ema_state_dict" in sd:
            sd = sd["ema_state_dict"]
        elif "state_dict" in sd:
            sd = sd["state_dict"]
    model.load_state_dict(sd)
    model.to(device).eval()
    for prm in model.parameters():
        prm.requires_grad_(False)
    return model


def _rollout(seed, model, sched, normalize, act_min, act_max, device):
    """Run one closed-loop episode starting from ``sample_init(seed)``.

    Every ACT_HORIZON env steps, a PRED_HORIZON-long action sequence is
    denoised with the full N_DDPM-step DDPM schedule, conditioned on the last
    OBS_HORIZON (normalised) observations. Actions are then executed starting
    at index OBS_HORIZON - 1 of that sequence. The episode stops after
    MAX_STEPS steps, on env termination or truncation, or once the reward
    reaches SUCCESS_TH.

    Returns:
        ``(init, states, success)``. ``states`` stacks the reset observation
        and every post-step observation; ``success`` is whether the maximum
        reward reached SUCCESS_TH.
    """
    init = sample_init(seed)
    env = PushTEnv(reset_to_state=init.copy())
    env.seed(seed)
    obs, _ = env.reset()
    hist = [obs.copy()] * OBS_HORIZON
    states = [obs.copy()]
    rewards = []
    chunk = None
    for step in range(MAX_STEPS):
        if step % ACT_HORIZON == 0:
            cond_np = np.stack(hist[-OBS_HORIZON:])
            cond = torch.from_numpy(normalize(cond_np).astype(np.float32)).to(device).flatten().unsqueeze(0)
            # One generator per re-plan, seeded from (seed, step). It drives both
            # the initial noise AND the per-step DDPM variance noise, so each
            # seed's episode does not depend on which seeds ran before it.
            gen = torch.Generator(device=device).manual_seed(int(seed + step * 1000 + 7919))
            a = torch.randn((1, PRED_HORIZON, ACT_DIM), generator=gen, device=device)
            sched.set_timesteps(N_DDPM)
            with torch.no_grad():
                for k in sched.timesteps:
                    eps = model(sample=a, timestep=k, global_cond=cond)
                    a = sched.step(model_output=eps, timestep=k, sample=a, generator=gen).prev_sample
            # Map actions from [-1, 1] back to [act_min, act_max].
            chunk = (a.cpu().numpy()[0] + 1) / 2 * (act_max - act_min) + act_min
        act = chunk[step % ACT_HORIZON + OBS_HORIZON - 1]
        obs, r, term, trunc, _ = env.step(act)
        hist.append(obs.copy())
        states.append(obs.copy())
        rewards.append(float(r))
        if term or trunc or r >= SUCCESS_TH:
            break
    success = (max(rewards) if rewards else 0.0) >= SUCCESS_TH
    return init, np.asarray(states), success


def _save_mp4(path, init, states):
    """Replay *states* in a fresh 512-px PushTEnv, write the frames to *path* at 10 fps, and return the frame count."""
    import imageio.v2 as iio
    env = PushTEnv(reset_to_state=init.copy(), render_size=512)
    env.seed(0)
    env.reset()
    frames = []
    for s in states:
        env._set_state(s)
        frames.append(env.render("rgb_array").copy())
    iio.mimsave(path, frames, fps=10)
    return len(frames)


def main():
    """Roll out the checkpoint on ``args.n_seeds`` consecutive seeds and report the results.

    Prints each seed's outcome and ``geom_mode`` label, then the success rate
    (max reward >= SUCCESS_TH) and the mode counts. If ``--save-mp4`` is set,
    the first successful episode is replayed and written as a video.
    """
    device = args.device
    obs_min, obs_max, act_min, act_max = _load_stats(args.stats)

    print(f"Loading ckpt: {args.ckpt}", flush=True)
    model = _load_model(args.ckpt, device)

    sched = DDPMScheduler(num_train_timesteps=N_DDPM, beta_schedule="squaredcos_cap_v2",
                          clip_sample=True, prediction_type="epsilon")

    def nz(x):
        # Map observations from [obs_min, obs_max] to [-1, 1].
        return (x - obs_min) / (obs_max - obs_min) * 2 - 1

    seeds = list(range(args.seed_base, args.seed_base + args.n_seeds))
    n_run = len(seeds)
    mode_names = {0: "LEFT", 1: "RIGHT", -1: "mixed"}
    t0 = time.time()
    n_succ = 0
    modes = {0: 0, 1: 0, -1: 0}
    first_succ_states = None
    first_succ_init = None
    for seed in seeds:
        init, states_arr, success = _rollout(seed, model, sched, nz, act_min, act_max, device)
        mode = geom_mode(states_arr)
        modes[mode] += 1
        if success:
            n_succ += 1
            if first_succ_states is None and args.save_mp4:
                first_succ_states, first_succ_init = states_arr, init
        print(f" seed {seed}: {'OK' if success else 'FAIL'} mode={mode_names[mode]} steps={len(states_arr)-1}", flush=True)

    # Guard the division: an empty seed range has no success rate.
    rate = 100 * n_succ / n_run if n_run else 0.0
    print(f"\n=== {n_run} seeds ({time.time()-t0:.1f}s) ===")
    print(f"SR @ coverage>={SUCCESS_TH}: {n_succ}/{n_run} = {rate:.1f}%")
    print(f"mode dist: LEFT={modes[0]} RIGHT={modes[1]} mixed={modes[-1]}")

    if args.save_mp4:
        if first_succ_states is not None:
            n_frames = _save_mp4(args.save_mp4, first_succ_init, first_succ_states)
            print(f"saved {args.save_mp4} ({n_frames} frames)")
        else:
            print(f"no successful rollout; {args.save_mp4} not written")
|
|
|
|
# Script entry point: run the evaluation when this file is executed directly.
if __name__ == "__main__":
    main()
|
|