"""Print a per-problem comparison table across several ``logistics.pt`` runs.

Each run is keyed by the integer that appears in the ``steps{N}`` component of
its output directory.  For every ``data_idx`` found in any run, one line is
printed with the ``init_reward`` recorded by the lowest-step run, the
``best_reward_step`` recorded by the highest-step run, and ``is_correct`` for
every run.
"""
import torch

# Only the ``steps{N}`` component differs between the runs being compared.
_PATH_TEMPLATE = (
    "outputs_v2/Llama-3.2-3B-Instruct-MATH-500-tokens10-lr0.03-sigma0.1"
    "-sigdecay0.99-steps{steps}-all-rewardall_tokens-conf/logistics.pt"
)

# Step count -> checkpoint file.  This is the single place to add/remove runs;
# everything below derives the step list from these keys.
paths = {steps: _PATH_TEMPLATE.format(steps=steps) for steps in (0, 1, 2, 4, 10)}


def _load_checkpoint(path):
    """Load one checkpoint file, mapping any stored tensors onto the CPU."""
    # NOTE(review): torch.load unpickles the file -- only use on trusted outputs.
    # map_location keeps the script usable on machines without a GPU.
    return torch.load(path, map_location="cpu")


def summarize_entries(entries):
    """Index checkpoint entries by ``data_idx``.

    Args:
        entries: iterable of dicts, each with ``data_idx``, ``is_correct``,
            ``answer``, ``best_reward``, ``best_reward_step`` and optionally
            ``init_reward``.

    Returns:
        Dict mapping ``data_idx`` to the selected fields; ``init_reward`` is
        ``None`` when the entry does not carry it.  If a ``data_idx`` repeats,
        the last entry wins.

    Raises:
        KeyError: if a required field is missing from an entry.
    """
    return {
        entry["data_idx"]: {
            "is_correct": entry["is_correct"],
            "answer": entry["answer"],
            "init_reward": entry.get("init_reward"),
            "best_reward": entry["best_reward"],
            "best_reward_step": entry["best_reward_step"],
        }
        for entry in entries
    }


def load_results(step_paths, loader=_load_checkpoint):
    """Load every run and summarise its ``"entries"`` list.

    Args:
        step_paths: mapping of step count -> checkpoint path.
        loader: callable taking a path and returning a mapping with an
            ``"entries"`` key; injectable so the logic can run without files.

    Returns:
        Dict mapping step count -> ``summarize_entries`` result for that run.
    """
    return {
        step: summarize_entries(loader(path)["entries"])
        for step, path in step_paths.items()
    }


def build_rows(results):
    """Format one ``" | "``-joined line per ``data_idx`` seen in any run.

    The lowest step present supplies ``init_r`` and the highest step supplies
    the ``s{N}_best_step`` column (0 and 10 for the default ``paths``).  Values
    missing for a given run/problem are rendered as ``None``.

    Args:
        results: dict as returned by ``load_results``.

    Returns:
        List of formatted lines, ordered by ``data_idx``; empty if ``results``
        is empty.
    """
    if not results:
        return []

    steps = sorted(results)
    baseline_step, final_step = steps[0], steps[-1]
    # Union over dicts iterates their keys, i.e. every data_idx in any run.
    all_ids = sorted(set().union(*results.values()))

    rows = []
    for idx in all_ids:
        init_r = results[baseline_step].get(idx, {}).get("init_reward")
        best_step = results[final_step].get(idx, {}).get("best_reward_step")
        cells = [
            f"idx={idx}",
            f"init_r={init_r}",
            f"s{final_step}_best_step={best_step}",
        ]
        cells.extend(
            f"s{step}={results[step].get(idx, {}).get('is_correct')}"
            for step in steps
        )
        rows.append(" | ".join(cells))
    return rows


def main():
    """Load the runs listed in ``paths`` and print the comparison table."""
    for line in build_rows(load_results(paths)):
        print(line)


if __name__ == "__main__":
    main()