sosonsong commited on
Commit
17782d8
ยท
verified ยท
1 Parent(s): 0214fd4

Upload meta_agent.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. meta_agent.py +91 -0
meta_agent.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ meta_agent.py โ€” Reptile Meta-Learning (eval_speedup ์ œ๊ฑฐ ๋ฒ„์ „)
3
+ """
4
+ import os, copy, argparse
5
+ import numpy as np
6
+ import torch
7
+ from stable_baselines3 import PPO
8
+ from stable_baselines3.common.env_util import make_vec_env
9
+ from compiler_env import LoopUnrollEnv
10
+
11
+ PROJECT_ROOT = os.path.expanduser("~/projects/machineai")
12
+ MODELS_DIR = os.path.join(PROJECT_ROOT, "models")
13
+ BENCH_DIR = os.path.join(PROJECT_ROOT, "benchmarks")
14
+
15
+ def get_params(model):
16
+ return copy.deepcopy(model.policy.state_dict())
17
+
18
+ def set_params(model, params):
19
+ model.policy.load_state_dict(copy.deepcopy(params))
20
+
21
+ def make_model(arch, source_files, base_params=None):
22
+ env = make_vec_env(lambda: LoopUnrollEnv(
23
+ arch=arch, source_files=source_files), n_envs=1)
24
+ model = PPO("MlpPolicy", env, verbose=0,
25
+ learning_rate=3e-4, n_steps=64, batch_size=32)
26
+ if base_params:
27
+ set_params(model, base_params)
28
+ return model
29
+
30
+ def main():
31
+ ap = argparse.ArgumentParser()
32
+ ap.add_argument("--meta-base", default="models/x86v2_base.zip")
33
+ ap.add_argument("--arch", default="aarch64-linux-gnu")
34
+ ap.add_argument("--outer-iters", type=int, default=3)
35
+ ap.add_argument("--inner-steps", type=int, default=256)
36
+ ap.add_argument("--adapt-steps", type=int, default=100)
37
+ ap.add_argument("--out-path", default="models/meta_init.zip")
38
+ args = ap.parse_args()
39
+
40
+ bench_files = [
41
+ os.path.join(BENCH_DIR, f) for f in [
42
+ "loop_heavy_arm64v2.c", "nested_arm64v2.c", "matmul_arm64.c"
43
+ ] if os.path.exists(os.path.join(BENCH_DIR, f))
44
+ ]
45
+ tasks = [[f] for f in bench_files]
46
+ print(f"[Meta] ํƒœ์Šคํฌ ์ˆ˜: {len(tasks)}")
47
+
48
+ # ๋ฉ”ํƒ€ ์ดˆ๊ธฐํ™” ๋กœ๋“œ
49
+ init_env = make_vec_env(lambda: LoopUnrollEnv(arch="x86_64"), n_envs=1)
50
+ base_model = PPO.load(args.meta_base, env=init_env)
51
+ meta_params = get_params(base_model)
52
+ print(f"[Meta] ๋กœ๋“œ ์™„๋ฃŒ: {args.meta_base}")
53
+
54
+ # Reptile outer loop
55
+ print(f"\n=== Meta-Train ({args.outer_iters} outer iters x {args.inner_steps} inner steps) ===")
56
+ for outer_i in range(args.outer_iters):
57
+ adapted_list = []
58
+ for task_files in tasks:
59
+ m = make_model(args.arch, task_files, base_params=meta_params)
60
+ m.learn(total_timesteps=args.inner_steps, reset_num_timesteps=True)
61
+ rew = np.mean([ep["r"] for ep in m.ep_info_buffer]) if m.ep_info_buffer else 0.0
62
+ adapted_list.append((get_params(m), rew))
63
+ print(f" ํƒœ์Šคํฌ {os.path.basename(task_files[0])}: reward={rew:.1f}")
64
+
65
+ # Reptile ์—…๋ฐ์ดํŠธ: meta = meta + 0.3 * (ํ‰๊ท ์ ์‘ - meta)
66
+ keys = meta_params.keys()
67
+ avg = {k: torch.stack([p[k].float() for p,_ in adapted_list]).mean(0) for k in keys}
68
+ meta_lr = 0.3
69
+ for k in keys:
70
+ meta_params[k] = meta_params[k].float() + meta_lr * (avg[k] - meta_params[k].float())
71
+
72
+ avg_rew = np.mean([r for _,r in adapted_list])
73
+ print(f"[Outer {outer_i+1}/{args.outer_iters}] ํ‰๊ท  reward: {avg_rew:.1f}")
74
+
75
+ # ๋ฉ”ํƒ€ ๋ชจ๋ธ ์ €์žฅ
76
+ set_params(base_model, meta_params)
77
+ base_model.save(args.out_path)
78
+ print(f"\n[Meta] ์ €์žฅ: {args.out_path}")
79
+
80
+ # ๋น ๋ฅธ ์ ์‘ ๊ฒ€์ฆ (benchmark.py ํ™œ์šฉ)
81
+ print(f"\n=== Fast Adapt ({args.adapt_steps}์Šคํ…) ===")
82
+ m = make_model(args.arch, bench_files, base_params=meta_params)
83
+ m.learn(total_timesteps=args.adapt_steps, reset_num_timesteps=True)
84
+ m.save("models/meta_adapted.zip")
85
+ print("์ ์‘ ๋ชจ๋ธ ์ €์žฅ: models/meta_adapted.zip")
86
+ print("\n์ด์ œ benchmark.py๋กœ ์„ฑ๋Šฅ ํ™•์ธ ์ค‘...")
87
+ os.system(f"python3 benchmark.py --arch {args.arch} --model models/meta_adapted.zip "
88
+ f"--source-files {' '.join(bench_files)}")
89
+
90
+ if __name__ == "__main__":
91
+ main()