pusht-dp-checkpoints / inference.py
haohw's picture
add inference.py
bd16683 verified
Raw
History Blame Contribute Delete
6.42 kB
"""Inference script for the Push-T diffusion policy checkpoints in this repo.
Usage:
pip install torch diffusers pygame pymunk shapely scikit-image imageio gym numpy
git clone https://github.com/columbia-ai-robotics/streaming_flow_policy.git
# original
python inference.py --ckpt original.ckpt --stats stats_sfp.npz \\
--sfp-repo streaming_flow_policy/ --n-seeds 50
# edited
python inference.py --ckpt edited.pt --stats stats_sfp.npz \\
--sfp-repo streaming_flow_policy/ --n-seeds 50 --save-mp4 out.mp4
"""
import argparse, sys, time
import numpy as np
import torch
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
# Command-line interface. Arguments are parsed at module level because
# --sfp-repo must be known before the project imports further down can run.
p = argparse.ArgumentParser()
# Checkpoint file handed to torch.load() in main().
p.add_argument("--ckpt", required=True)
# .npz file with obs_min/obs_max and act_min/act_max (or action_min/action_max).
p.add_argument("--stats", required=True)
p.add_argument("--sfp-repo", required=True,
help="path to cloned columbia-ai-robotics/streaming_flow_policy repo")
# Rollouts are run for seeds seed_base, seed_base + 1, ..., seed_base + n_seeds - 1.
p.add_argument("--n-seeds", type=int, default=50)
p.add_argument("--seed-base", type=int, default=20000)
p.add_argument("--device", default="cuda:0" if torch.cuda.is_available() else "cpu")
p.add_argument("--save-mp4", default="", help="optional: mp4 of first successful rollout")
args = p.parse_args()
# Make the cloned repo importable; must happen before the two imports below.
sys.path.insert(0, args.sfp_repo)
from streaming_flow_policy.pusht.dp_state_notebook.env import PushTEnv
from streaming_flow_policy.pusht.dp_state_notebook.network import ConditionalUnet1D
# OBS_HORIZON: number of most recent observations stacked as conditioning.
# OBS_DIM:     size of one observation (OBS_DIM * OBS_HORIZON is the network's global_cond_dim).
# ACT_DIM:     size of one action (the network's input_dim).
# PRED_HORIZON: length of the action sequence sampled per denoising run.
# ACT_HORIZON:  number of environment steps executed before re-sampling.
OBS_HORIZON, OBS_DIM, ACT_DIM, PRED_HORIZON, ACT_HORIZON = 2, 5, 2, 16, 8
# N_DDPM:     diffusion timesteps (used for both num_train_timesteps and set_timesteps).
# MAX_STEPS:  cap on environment steps per rollout.
# SUCCESS_TH: reward threshold at which a rollout counts as a success.
N_DDPM, MAX_STEPS, SUCCESS_TH = 100, 300, 0.95
# Reference point around which sample_init() places the block and the agent.
GOAL = np.array([256.0, 256.0])
def sample_init(seed):
    """Draw a reproducible initial state for one rollout.

    A distance and an angle are sampled from a generator seeded with
    ``seed``.  The block is placed that distance from ``GOAL`` along the
    direction ``(-sin(angle), cos(angle))`` and the agent the same distance
    on the opposite side; both positions then receive independent uniform
    jitter (+/-3 for the block, +/-15 for the agent).

    Returns:
        float64 array ``[agent_x, agent_y, block_x, block_y, angle]`` with
        the angle in radians.
    """
    gen = np.random.default_rng(seed)

    # Draw order matters for reproducibility: distance, angle, block jitter,
    # agent jitter.
    dist = gen.uniform(27.0, 63.0)
    theta = np.deg2rad(gen.uniform(42.0, 48.0))

    direction = np.array([-np.sin(theta), np.cos(theta)])
    offset = dist * direction

    block_pos = GOAL + offset + gen.uniform(-3.0, 3.0, 2)
    agent_pos = GOAL - offset + gen.uniform(-15.0, 15.0, 2)

    return np.array([*agent_pos, *block_pos, theta], dtype=np.float64)
def cog_of_T(bx, by, bth):
    """Return the point 45 units from ``(bx, by)`` along ``(-sin(bth), cos(bth))``.

    NOTE(review): going by the name, this is presumably the T block's centre
    of gravity expressed in world coordinates - confirm against the env.

    Args:
        bx, by: reference position of the block.
        bth: block angle in radians.

    Returns:
        Array ``[x, y]`` of the offset point.
    """
    arm = 45
    sin_th = np.sin(bth)
    cos_th = np.cos(bth)
    cog_x = bx - sin_th * arm
    cog_y = by + cos_th * arm
    return np.array([cog_x, cog_y])
def geom_mode(states_array, cap=50):
    """Classify which way the agent circles the block early in a rollout.

    Only the first ``cap`` states are inspected.  For each one the agent's
    bearing relative to ``cog_of_T(block_x, block_y, block_angle)`` is
    computed; the bearings are unwrapped and expressed in degrees relative
    to the first state.

    Args:
        states_array: array of states whose columns 0-1 are the agent
            position, 2-3 the block position and 4 the block angle.
        cap: maximum number of leading states to look at.

    Returns:
        ``1`` if the largest negative swing exceeds the largest positive
        swing by more than 30 degrees, ``0`` for the opposite case, and
        ``-1`` when neither dominates or fewer than 5 states are available.
    """
    cap = min(cap, len(states_array))
    if cap < 5:
        return -1

    agent_xy = states_array[:cap, 0:2]
    block_xy = states_array[:cap, 2:4]
    block_th = states_array[:cap, 4]

    cog_xy = np.stack([
        cog_of_T(x, y, th)
        for x, y, th in zip(block_xy[:, 0], block_xy[:, 1], block_th)
    ])

    delta = agent_xy - cog_xy
    bearing = np.unwrap(np.arctan2(delta[:, 1], delta[:, 0]))
    swing_deg = np.degrees(bearing - bearing[0])

    # swing_deg[0] is 0, so min() <= 0 <= max(); compare their magnitudes.
    neg_mag = abs(swing_deg.min())
    pos_mag = abs(swing_deg.max())
    if neg_mag > pos_mag + 30.0:
        return 1
    if pos_mag > neg_mag + 30.0:
        return 0
    return -1
def _load_policy(ckpt_path, device):
    """Build the ConditionalUnet1D and load frozen weights from ``ckpt_path``.

    The checkpoint may be a bare state dict, or a dict wrapping one under
    ``"ema_state_dict"`` (preferred when present) or ``"state_dict"``.
    """
    model = ConditionalUnet1D(input_dim=ACT_DIM, global_cond_dim=OBS_DIM * OBS_HORIZON)
    # SECURITY: weights_only=False uses the full pickle loader, which can
    # execute arbitrary code stored in the file. Only load checkpoints you
    # trust.
    state = torch.load(ckpt_path, map_location=device, weights_only=False)
    if isinstance(state, dict):
        if "ema_state_dict" in state:
            state = state["ema_state_dict"]
        elif "state_dict" in state:
            state = state["state_dict"]
    model.load_state_dict(state)
    model.to(device).eval()
    for param in model.parameters():
        param.requires_grad_(False)
    return model


def _sample_action_chunk(model, sched, obs_cond, noise_seed, device):
    """Run one full DDPM reverse process and return the normalized action chunk.

    A single generator seeded with ``noise_seed`` drives both the initial
    noise and the per-step variance noise inside ``sched.step``, so the
    sampled chunk depends only on the seed, the weights and ``obs_cond``.
    """
    gen = torch.Generator(device=device).manual_seed(noise_seed)
    sample = torch.randn((1, PRED_HORIZON, ACT_DIM), generator=gen, device=device)
    sched.set_timesteps(N_DDPM)
    with torch.no_grad():
        for k in sched.timesteps:
            eps = model(sample=sample, timestep=k, global_cond=obs_cond)
            # Without generator=..., the scheduler would draw its noise from
            # the unseeded global RNG and rollouts would not be reproducible.
            sample = sched.step(
                model_output=eps, timestep=k, sample=sample, generator=gen
            ).prev_sample
    return sample.cpu().numpy()[0]


def _run_episode(model, sched, seed, normalize_obs, unnormalize_act, device):
    """Roll the policy out once from the initial state derived from ``seed``.

    Returns:
        ``(init_state, states, success)`` where ``states`` stacks the reset
        observation and every subsequent observation, and ``success`` is
        True when any step's reward reached ``SUCCESS_TH``.
    """
    init = sample_init(seed)
    env = PushTEnv(reset_to_state=init.copy())
    env.seed(seed)
    obs, _ = env.reset()

    hist = [obs.copy()] * OBS_HORIZON
    states = [obs.copy()]
    rewards = []
    chunk = None
    for step in range(MAX_STEPS):
        offset = step % ACT_HORIZON
        if offset == 0:
            # Re-plan: condition on the latest OBS_HORIZON observations.
            cond_np = np.stack(hist[-OBS_HORIZON:])
            cond = (
                torch.from_numpy(normalize_obs(cond_np).astype(np.float32))
                .to(device)
                .flatten()
                .unsqueeze(0)
            )
            noise_seed = int(seed + step * 1000 + 7919)
            chunk = unnormalize_act(
                _sample_action_chunk(model, sched, cond, noise_seed, device)
            )
        # Executed actions start at index OBS_HORIZON - 1 of the predicted chunk.
        act = chunk[offset + OBS_HORIZON - 1]
        obs, r, term, trunc, _ = env.step(act)
        hist.append(obs.copy())
        states.append(obs.copy())
        rewards.append(float(r))
        if term or trunc or r >= SUCCESS_TH:
            break

    success = (max(rewards) if rewards else 0.0) >= SUCCESS_TH
    return init, np.asarray(states), success


def _save_rollout_video(path, init_state, states):
    """Replay ``states`` in a fresh env and write the frames to ``path``.

    Returns the number of frames written.
    """
    import imageio.v2 as iio

    env = PushTEnv(reset_to_state=init_state.copy(), render_size=512)
    env.seed(0)
    env.reset()
    frames = []
    for s in states:
        env._set_state(s)
        frames.append(env.render("rgb_array").copy())
    iio.mimsave(path, frames, fps=10)
    return len(frames)


def main():
    """Evaluate the checkpoint over the configured seeds and print a summary.

    Reads the module-level ``args``.  For every seed a rollout is run, its
    success and ``geom_mode`` class are printed, and at the end the success
    rate and mode distribution are reported.  With ``--save-mp4`` the first
    successful rollout is rendered to a video.
    """
    device = args.device

    # NpzFile keeps the archive open until closed; read what we need and release it.
    with np.load(args.stats) as stats:
        obs_min, obs_max = stats["obs_min"], stats["obs_max"]
        act_min = stats["act_min"] if "act_min" in stats.files else stats["action_min"]
        act_max = stats["act_max"] if "act_max" in stats.files else stats["action_max"]

    print(f"Loading ckpt: {args.ckpt}", flush=True)
    model = _load_policy(args.ckpt, device)
    sched = DDPMScheduler(num_train_timesteps=N_DDPM, beta_schedule="squaredcos_cap_v2",
                          clip_sample=True, prediction_type="epsilon")

    def normalize_obs(x):
        # Map observations from [obs_min, obs_max] to [-1, 1].
        return (x - obs_min) / (obs_max - obs_min) * 2 - 1

    def unnormalize_act(a):
        # Map actions from [-1, 1] back to [act_min, act_max].
        return (a + 1) / 2 * (act_max - act_min) + act_min

    seeds = range(args.seed_base, args.seed_base + args.n_seeds)
    t0 = time.time()
    n_succ = 0
    modes = {0: 0, 1: 0, -1: 0}
    first_succ_states = None
    first_succ_init = None

    for seed in seeds:
        init, states_arr, success = _run_episode(
            model, sched, seed, normalize_obs, unnormalize_act, device
        )
        mode = geom_mode(states_arr)
        modes[mode] += 1
        if success:
            n_succ += 1
        if success and first_succ_states is None and args.save_mp4:
            first_succ_states = states_arr
            first_succ_init = init
        mode_str = {0: "LEFT", 1: "RIGHT", -1: "mixed"}[mode]
        print(f" seed {seed}: {'OK' if success else 'FAIL'} mode={mode_str} steps={len(states_arr)-1}", flush=True)

    # len(seeds) equals --n-seeds for non-negative values; guarding on it
    # avoids a ZeroDivisionError when no seeds were requested.
    n_total = len(seeds)
    rate = 100 * n_succ / n_total if n_total else 0.0
    print(f"\n=== {n_total} seeds ({time.time()-t0:.1f}s) ===")
    print(f"SR @ coverage>={SUCCESS_TH}: {n_succ}/{n_total} = {rate:.1f}%")
    print(f"mode dist: LEFT={modes[0]} RIGHT={modes[1]} mixed={modes[-1]}")

    if args.save_mp4:
        if first_succ_states is not None:
            n_frames = _save_rollout_video(args.save_mp4, first_succ_init, first_succ_states)
            print(f"saved {args.save_mp4} ({n_frames} frames)")
        else:
            print(f"no successful rollout; {args.save_mp4} not written")


if __name__ == "__main__":
    main()