File size: 1,638 Bytes
2fdf3c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch

# Compare per-problem outcomes across runs that differ only in their step budget.
# Each run directory holds a logistics.pt whose "entries" list has one record per problem.
_RUN_TEMPLATE = (
    "outputs_v2/Llama-3.2-3B-Instruct-MATH-500-tokens10-lr0.03-sigma0.1-sigdecay0.99"
    "-steps{step}-all-rewardall_tokens-conf/logistics.pt"
)

# step budget -> path of that run's logistics.pt
paths = {step: _RUN_TEMPLATE.format(step=step) for step in (0, 1, 2, 4, 10)}


def summarize_entries(entries):
    """Index checkpoint entries by their ``data_idx``.

    Args:
        entries: iterable of dicts, each with keys ``data_idx``, ``is_correct``,
            ``answer``, ``best_reward`` and ``best_reward_step``; ``init_reward``
            is optional.

    Returns:
        dict mapping ``data_idx`` to a dict of the five reported fields
        (``init_reward`` is None when absent). A repeated ``data_idx`` keeps
        the last entry seen.
    """
    return {
        entry["data_idx"]: {
            "is_correct": entry["is_correct"],
            "answer": entry["answer"],
            "init_reward": entry.get("init_reward"),
            "best_reward": entry["best_reward"],
            "best_reward_step": entry["best_reward_step"],
        }
        for entry in entries
    }


def load_step_results(path):
    """Load one run's logistics.pt and return ``summarize_entries`` of its entries."""
    # map_location="cpu": checkpoints containing CUDA tensors otherwise fail
    # to load on machines without a GPU.
    ckpt = torch.load(path, map_location="cpu")
    return summarize_entries(ckpt["entries"])


def collect_ids(results):
    """Return the sorted union of ``data_idx`` values seen in any run."""
    return sorted({idx for step_dict in results.values() for idx in step_dict})


def format_row(idx, results):
    """Format one report line for problem ``idx``.

    The init reward is taken from the smallest step budget and the best-reward
    step from the largest. Then one ``s<step>=<is_correct>`` cell is added per
    run, in ascending step order. Fields missing for a run print as ``None``.
    """
    steps = sorted(results)
    first, last = steps[0], steps[-1]

    def field(step, key):
        return results[step].get(idx, {}).get(key)

    cells = [
        f"idx={idx}",
        f"init_r={field(first, 'init_reward')}",
        f"s{last}_best_step={field(last, 'best_reward_step')}",
    ]
    cells.extend(f"s{step}={field(step, 'is_correct')}" for step in steps)
    return " | ".join(cells)


def main():
    """Load every run in ``paths`` and print one comparison line per problem."""
    results = {step: load_step_results(path) for step, path in paths.items()}
    for idx in collect_ids(results):
        print(format_row(idx, results))


if __name__ == "__main__":
    main()