File size: 3,794 Bytes
17782d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
"""
meta_agent.py โ€” Reptile Meta-Learning (eval_speedup ์ œ๊ฑฐ ๋ฒ„์ „)
"""
import os, copy, argparse
import numpy as np
import torch
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from compiler_env import LoopUnrollEnv

# Project root, resolved under the current user's home directory.
PROJECT_ROOT = os.path.expanduser("~/projects/machineai")
# "models" directory under the project root (not referenced elsewhere in the
# visible part of this file).
MODELS_DIR   = os.path.join(PROJECT_ROOT, "models")
# Directory holding the benchmark C sources that main() turns into tasks.
BENCH_DIR    = os.path.join(PROJECT_ROOT, "benchmarks")

def get_params(model):
    """Return an independent snapshot of the model's policy parameters.

    The mapping returned by ``model.policy.state_dict()`` is deep-copied, so
    later changes to the policy do not alter the snapshot.
    """
    live_state = model.policy.state_dict()
    snapshot = copy.deepcopy(live_state)
    return snapshot

def set_params(model, params):
    """Load a private deep copy of ``params`` into the model's policy.

    The caller's mapping is copied before being handed to
    ``model.policy.load_state_dict``, so the policy never receives the
    caller's own objects.
    """
    # Resolve the loader first so attribute lookup happens before the copy.
    load_into_policy = model.policy.load_state_dict
    private_copy = copy.deepcopy(params)
    load_into_policy(private_copy)

def make_model(arch, source_files, base_params=None):
    """Build a fresh PPO agent on a single-env ``LoopUnrollEnv`` wrapper.

    Args:
        arch: forwarded to ``LoopUnrollEnv`` as ``arch``.
        source_files: forwarded to ``LoopUnrollEnv`` as ``source_files``.
        base_params: optional policy state to load into the new agent via
            ``set_params``; skipped when falsy (``None`` or empty).

    Returns:
        The newly constructed ``PPO`` model.
    """
    def _build_env():
        return LoopUnrollEnv(arch=arch, source_files=source_files)

    vec_env = make_vec_env(_build_env, n_envs=1)
    hyperparams = dict(verbose=0, learning_rate=3e-4, n_steps=64, batch_size=32)
    agent = PPO("MlpPolicy", vec_env, **hyperparams)
    if not base_params:
        return agent
    set_params(agent, base_params)
    return agent

def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--meta-base",   default="models/x86v2_base.zip")
    ap.add_argument("--arch",        default="aarch64-linux-gnu")
    ap.add_argument("--outer-iters", type=int, default=3)
    ap.add_argument("--inner-steps", type=int, default=256)
    ap.add_argument("--adapt-steps", type=int, default=100)
    ap.add_argument("--out-path",    default="models/meta_init.zip")
    args = ap.parse_args()

    bench_files = [
        os.path.join(BENCH_DIR, f) for f in [
            "loop_heavy_arm64v2.c", "nested_arm64v2.c", "matmul_arm64.c"
        ] if os.path.exists(os.path.join(BENCH_DIR, f))
    ]
    tasks = [[f] for f in bench_files]
    print(f"[Meta] ํƒœ์Šคํฌ ์ˆ˜: {len(tasks)}")

    # ๋ฉ”ํƒ€ ์ดˆ๊ธฐํ™” ๋กœ๋“œ
    init_env = make_vec_env(lambda: LoopUnrollEnv(arch="x86_64"), n_envs=1)
    base_model = PPO.load(args.meta_base, env=init_env)
    meta_params = get_params(base_model)
    print(f"[Meta] ๋กœ๋“œ ์™„๋ฃŒ: {args.meta_base}")

    # Reptile outer loop
    print(f"\n=== Meta-Train ({args.outer_iters} outer iters x {args.inner_steps} inner steps) ===")
    for outer_i in range(args.outer_iters):
        adapted_list = []
        for task_files in tasks:
            m = make_model(args.arch, task_files, base_params=meta_params)
            m.learn(total_timesteps=args.inner_steps, reset_num_timesteps=True)
            rew = np.mean([ep["r"] for ep in m.ep_info_buffer]) if m.ep_info_buffer else 0.0
            adapted_list.append((get_params(m), rew))
            print(f"  ํƒœ์Šคํฌ {os.path.basename(task_files[0])}: reward={rew:.1f}")

        # Reptile ์—…๋ฐ์ดํŠธ: meta = meta + 0.3 * (ํ‰๊ท ์ ์‘ - meta)
        keys = meta_params.keys()
        avg = {k: torch.stack([p[k].float() for p,_ in adapted_list]).mean(0) for k in keys}
        meta_lr = 0.3
        for k in keys:
            meta_params[k] = meta_params[k].float() + meta_lr * (avg[k] - meta_params[k].float())

        avg_rew = np.mean([r for _,r in adapted_list])
        print(f"[Outer {outer_i+1}/{args.outer_iters}] ํ‰๊ท  reward: {avg_rew:.1f}")

    # ๋ฉ”ํƒ€ ๋ชจ๋ธ ์ €์žฅ
    set_params(base_model, meta_params)
    base_model.save(args.out_path)
    print(f"\n[Meta] ์ €์žฅ: {args.out_path}")

    # ๋น ๋ฅธ ์ ์‘ ๊ฒ€์ฆ (benchmark.py ํ™œ์šฉ)
    print(f"\n=== Fast Adapt ({args.adapt_steps}์Šคํ…) ===")
    m = make_model(args.arch, bench_files, base_params=meta_params)
    m.learn(total_timesteps=args.adapt_steps, reset_num_timesteps=True)
    m.save("models/meta_adapted.zip")
    print("์ ์‘ ๋ชจ๋ธ ์ €์žฅ: models/meta_adapted.zip")
    print("\n์ด์ œ benchmark.py๋กœ ์„ฑ๋Šฅ ํ™•์ธ ์ค‘...")
    os.system(f"python3 benchmark.py --arch {args.arch} --model models/meta_adapted.zip "
              f"--source-files {' '.join(bench_files)}")

if __name__ == "__main__":
    main()