Upload meta_agent.py with huggingface_hub
Browse files- meta_agent.py +91 -0
meta_agent.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
meta_agent.py — Reptile Meta-Learning (version with eval_speedup removed)
"""
import argparse
import copy
import os
import subprocess

import numpy as np
import torch
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env

from compiler_env import LoopUnrollEnv
# All project artifacts are resolved relative to the user's home checkout.
PROJECT_ROOT = os.path.expanduser("~/projects/machineai")
# Directory holding saved PPO checkpoints.
MODELS_DIR = os.path.join(PROJECT_ROOT, "models")
# Directory holding the C benchmark sources used as meta-learning tasks.
BENCH_DIR = os.path.join(PROJECT_ROOT, "benchmarks")
def get_params(model):
    """Return a detached deep copy of the model's policy weights.

    The copy shares no storage with the live policy, so the caller may
    mutate it (e.g. for Reptile averaging) without touching the model.
    """
    snapshot = model.policy.state_dict()
    return copy.deepcopy(snapshot)
def set_params(model, params):
    """Load *params* into the model's policy.

    A deep copy is loaded so later mutation of the caller's dict cannot
    alias into the policy's tensors.
    """
    safe_copy = copy.deepcopy(params)
    model.policy.load_state_dict(safe_copy)
def make_model(arch, source_files, base_params=None):
    """Build a fresh single-env PPO agent for one task.

    Parameters
    ----------
    arch : str
        Target architecture string forwarded to ``LoopUnrollEnv``.
    source_files : list[str]
        Benchmark source files defining this task.
    base_params : dict | None
        Optional policy ``state_dict`` used to initialize the new agent
        (the Reptile meta-initialization). ``None`` keeps PPO's random init.

    Returns
    -------
    PPO
        A newly constructed (untrained) agent.
    """
    env = make_vec_env(
        lambda: LoopUnrollEnv(arch=arch, source_files=source_files),
        n_envs=1,
    )
    model = PPO(
        "MlpPolicy", env, verbose=0,
        learning_rate=3e-4, n_steps=64, batch_size=32,
    )
    # Fix: compare against None instead of truthiness — an explicitly passed
    # (possibly empty) state_dict must still be loaded, not silently skipped.
    if base_params is not None:
        set_params(model, base_params)
    return model
def main():
    """Reptile meta-training entry point.

    Loads a PPO base model, runs the Reptile outer loop over per-benchmark
    tasks, saves the resulting meta-initialization, then performs a short
    adaptation run and invokes benchmark.py for evaluation.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--meta-base", default="models/x86v2_base.zip")
    ap.add_argument("--arch", default="aarch64-linux-gnu")
    ap.add_argument("--outer-iters", type=int, default=3)
    ap.add_argument("--inner-steps", type=int, default=256)
    ap.add_argument("--adapt-steps", type=int, default=100)
    ap.add_argument("--out-path", default="models/meta_init.zip")
    args = ap.parse_args()

    # Each task is a single benchmark source; silently skip missing files.
    bench_files = [
        os.path.join(BENCH_DIR, f) for f in [
            "loop_heavy_arm64v2.c", "nested_arm64v2.c", "matmul_arm64.c"
        ] if os.path.exists(os.path.join(BENCH_DIR, f))
    ]
    tasks = [[f] for f in bench_files]
    print(f"[Meta] 태스크 수: {len(tasks)}")
    if not tasks:
        # Without tasks the Reptile update below would crash on empty stacks;
        # fail early with an actionable message instead.
        raise SystemExit(f"[Meta] no benchmark files found under {BENCH_DIR}")

    # Load the meta-initialization (pretrained x86 base policy).
    init_env = make_vec_env(lambda: LoopUnrollEnv(arch="x86_64"), n_envs=1)
    base_model = PPO.load(args.meta_base, env=init_env)
    meta_params = get_params(base_model)
    print(f"[Meta] 로드 완료: {args.meta_base}")

    # Reptile outer loop: adapt per task, then move meta toward the mean.
    print(f"\n=== Meta-Train ({args.outer_iters} outer iters x {args.inner_steps} inner steps) ===")
    for outer_i in range(args.outer_iters):
        adapted_list = []
        for task_files in tasks:
            task_model = make_model(args.arch, task_files, base_params=meta_params)
            task_model.learn(total_timesteps=args.inner_steps, reset_num_timesteps=True)
            buf = task_model.ep_info_buffer
            rew = np.mean([ep["r"] for ep in buf]) if buf else 0.0
            adapted_list.append((get_params(task_model), rew))
            print(f" 태스크 {os.path.basename(task_files[0])}: reward={rew:.1f}")

        # Reptile update: meta = meta + 0.3 * (mean adapted params - meta)
        keys = list(meta_params.keys())
        avg = {
            k: torch.stack([p[k].float() for p, _ in adapted_list]).mean(0)
            for k in keys
        }
        meta_lr = 0.3
        for k in keys:
            meta_params[k] = meta_params[k].float() + meta_lr * (avg[k] - meta_params[k].float())

        avg_rew = np.mean([r for _, r in adapted_list])
        print(f"[Outer {outer_i+1}/{args.outer_iters}] 평균 reward: {avg_rew:.1f}")

    # Save the meta-initialized model.
    set_params(base_model, meta_params)
    base_model.save(args.out_path)
    print(f"\n[Meta] 저장: {args.out_path}")

    # Quick adaptation sanity check (evaluated via benchmark.py).
    print(f"\n=== Fast Adapt ({args.adapt_steps}스텝) ===")
    adapted = make_model(args.arch, bench_files, base_params=meta_params)
    adapted.learn(total_timesteps=args.adapt_steps, reset_num_timesteps=True)
    adapted.save("models/meta_adapted.zip")
    print("적응 모델 저장: models/meta_adapted.zip")
    print("\n이제 benchmark.py로 성능 확인 중...")
    # Fix: argv-list subprocess call instead of os.system with an f-string —
    # robust to spaces in paths and not vulnerable to shell injection.
    subprocess.run(
        ["python3", "benchmark.py", "--arch", args.arch,
         "--model", "models/meta_adapted.zip",
         "--source-files", *bench_files],
        check=False,
    )
# Script entry point: run meta-training only when executed directly.
if __name__ == "__main__":
    main()