"""
transfer_agent.py — Architecture-aware Transfer Learning for LoopUnrollEnv

- Selectively supports per-architecture targets such as x86 / arm64
- Backbone: reuses some layers of the existing {arch}_base model as the backbone
- Adapter: retrains only a subset of layers to fit a new environment (or a new CPU)
"""
| |
|
| | import os |
| | import glob |
| | import sys |
| | import argparse |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | import gymnasium as gym |
| | from stable_baselines3 import PPO |
| | from stable_baselines3.common.env_util import make_vec_env |
| |
|
| | from compiler_env import LoopUnrollEnv |
| |
|
| |
|
| | |
| | |
| | |
| |
|
# Filesystem layout: the project lives at a fixed location under the user's
# home directory; checkpoints and benchmark sources sit in sibling folders.
PROJECT_ROOT = os.path.expanduser("~/projects/machineai")
MODELS_DIR, BENCH_DIR = (
    os.path.join(PROJECT_ROOT, subdir) for subdir in ("models", "benchmarks")
)
| |
|
| |
|
def get_model_paths(arch: str):
    """Return the default checkpoint paths for one architecture.

    Both files are located directly under ``MODELS_DIR``:

    - base:     ``models/model_{arch}_base.zip``
    - transfer: ``models/model_{arch}_transfer.zip``

    Args:
        arch: Architecture tag that is embedded in the file names.

    Returns:
        A ``(base, transfer)`` tuple of path strings.
    """
    file_names = (f"model_{arch}_base.zip", f"model_{arch}_transfer.zip")
    return tuple(os.path.join(MODELS_DIR, file_name) for file_name in file_names)
| |
|
| |
|
| | |
| | |
| | |
| |
|
def extract_backbone_weights(model_path: str) -> dict:
    """Extract the backbone layers from a saved PPO checkpoint.

    The backbone is the subset of the policy's ``state_dict`` whose keys
    contain ``mlp_extractor.policy_net.0`` or ``mlp_extractor.policy_net.2``
    (described by the original author as the first two layers of
    ``policy_net``).

    Args:
        model_path: Path of the PPO checkpoint to read.

    Returns:
        A dict mapping the matching ``state_dict`` keys to cloned tensors.
        The dict is empty when the checkpoint has no matching keys.
    """
    # State-dict key fragments that identify the backbone layers.
    backbone_markers = ("mlp_extractor.policy_net.0", "mlp_extractor.policy_net.2")

    print(f"[Backbone] 로드: {model_path}")
    # The model is only used as a source of weights, so load it on the CPU
    # instead of occupying an accelerator with a throwaway copy.
    model = PPO.load(model_path, device="cpu")

    # Clone so the returned tensors do not alias the loaded model's storage.
    backbone = {
        key: tensor.clone()
        for key, tensor in model.policy.state_dict().items()
        if any(marker in key for marker in backbone_markers)
    }

    print("[Backbone] 추출 레이어:")
    for key in backbone:
        print(f" - {key}")
    return backbone
| |
|
| |
|
| | |
| | |
| | |
| |
|
def build_transfer_model(env, backbone_weights: dict | None, freeze_backbone: bool = True) -> "PPO":
    """Build a PPO model, optionally seeded with (and frozen on) backbone weights.

    A fresh ``MlpPolicy`` PPO with ``net_arch=[64, 64, 32]`` is created. When
    ``backbone_weights`` is given, every entry whose key exists in the new
    policy with an identical shape is copied in; all other entries are
    reported and skipped.

    Args:
        env: Environment handed to ``PPO``.
        backbone_weights: Mapping of policy ``state_dict`` keys to tensors, or
            ``None`` to start from a completely new model.
        freeze_backbone: When true (and ``backbone_weights`` is not ``None``),
            set ``requires_grad = False`` on the backbone parameters that were
            actually injected.

    Returns:
        The configured PPO model.
    """
    # Parameter-name fragments that identify the backbone layers.
    backbone_markers = ("mlp_extractor.policy_net.0", "mlp_extractor.policy_net.2")

    print("[Model] Transfer PPO 생성 중...")
    model = PPO(
        policy="MlpPolicy",
        env=env,
        learning_rate=1e-4,
        n_steps=256,
        batch_size=64,
        n_epochs=10,
        gamma=0.99,
        verbose=1,
        policy_kwargs=dict(net_arch=[64, 64, 32]),
    )

    # Keys that were really copied into the policy; used by the freeze step.
    injected_keys = set()
    if backbone_weights is not None:
        print("[Model] Backbone 가중치 주입...")
        state_dict = model.policy.state_dict()
        skipped = 0
        for k, v in backbone_weights.items():
            if k in state_dict and state_dict[k].shape == v.shape:
                state_dict[k] = v
                injected_keys.add(k)
                print(f" ✓ 주입: {k}")
            else:
                skipped += 1
                print(f" ✗ 스킵: {k} (shape mismatch or not found)")
        model.policy.load_state_dict(state_dict)
        print(f"[Model] 주입 완료: {len(injected_keys)}개, 스킵: {skipped}개")
    else:
        print("[Model] Backbone 없이 새 모델로 시작")

    if freeze_backbone and backbone_weights is not None:
        print("[Model] Backbone 파라미터 동결...")
        for name, param in model.policy.named_parameters():
            is_backbone = any(marker in name for marker in backbone_markers)
            # Freeze only layers that received pretrained weights: freezing a
            # layer whose injection was skipped would lock in its random
            # initialisation for the whole training run.
            if is_backbone and name in injected_keys:
                param.requires_grad = False
                print(f" 🔒 동결: {name}")

    trainable = sum(p.numel() for p in model.policy.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.policy.parameters())
    print(f"\n[Model] 파라미터: {trainable}/{total} 학습가능 ({trainable/total*100:.1f}%)")
    return model
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | def main(): |
| | parser = argparse.ArgumentParser(description="Architecture-aware transfer learning for LoopUnrollEnv") |
| | parser.add_argument("--arch", type=str, default="x86", help="ํ๊ฒ ์ํคํ
์ฒ (์: x86, arm64)") |
| | parser.add_argument("--timesteps", type=int, default=2000, help="์ ์ดํ์ต ์คํ
์") |
| | parser.add_argument("--load-base", action="store_true", help="๊ธฐ์กด base ๋ชจ๋ธ์์ backbone์ ๋ก๋ํ ์ง ์ฌ๋ถ") |
| | parser.add_argument("--base-path", type=str, default="", help="์ง์ base ๋ชจ๋ธ ๊ฒฝ๋ก ์ง์ (์ต์
)") |
| | parser.add_argument("--out-path", type=str, default="", help="์ ์ด ๊ฒฐ๊ณผ ์ ์ฅ ๊ฒฝ๋ก ์ง์ ์ง์ (์ต์
)") |
| | parser.add_argument("--repeat-runs", type=int, default=3, help="์คํ ์๊ฐ ์ธก์ ๋ฐ๋ณต ํ์") |
| | parser.add_argument("--freeze-backbone", action="store_true", help="Backbone ๋ ์ด์ด๋ฅผ ๋๊ฒฐํ ์ง ์ฌ๋ถ") |
| | parser.add_argument("--clang-bin", type=str, default="", help="์ฌ์ฉํ clang ๋ฐ์ด๋๋ฆฌ (๋น์ฐ๋ฉด ๊ธฐ๋ณธ๊ฐ)") |
| | parser.add_argument("--opt-bin", type=str, default="", help="์ฌ์ฉํ opt ๋ฐ์ด๋๋ฆฌ (๋น์ฐ๋ฉด ๊ธฐ๋ณธ๊ฐ)") |
| | parser.add_argument("--source-files", type=str, nargs="+", default=[], help="ํ์ต์ ์ฌ์ฉํ ์์ค ํ์ผ ๋ชฉ๋ก") |
| | args = parser.parse_args() |
| |
|
| | arch = args.arch |
| | print(f"[Config] arch={arch}") |
| |
|
| | |
| | os.makedirs(MODELS_DIR, exist_ok=True) |
| | default_base, default_transfer = get_model_paths(arch) |
| |
|
| | base_model_path = args.base_path or default_base |
| | transfer_model_path = args.out_path or default_transfer |
| |
|
| | print(f"[Config] base_model_path = {base_model_path}") |
| | print(f"[Config] transfer_model_path= {transfer_model_path}") |
| |
|
| | |
| | if args.source_files: |
| | source_files = [os.path.abspath(f) for f in args.source_files] |
| | else: |
| | source_files = sorted(glob.glob(os.path.join(BENCH_DIR, "*.c"))) |
| | print(f"[Data] ํ์ต ๋์: {source_files}") |
| |
|
| | |
| | backbone = None |
| | if args.load_base: |
| | if not os.path.exists(base_model_path): |
| | raise FileNotFoundError(f"Base ๋ชจ๋ธ์ ์ฐพ์ ์ ์์ต๋๋ค: {base_model_path}") |
| | backbone = extract_backbone_weights(base_model_path) |
| | else: |
| | print("[Backbone] base ๋ชจ๋ธ ๋ก๋ ์๋ต (์์ ์ ๋ชจ๋ธ๋ก ์์)") |
| |
|
| | |
| | def make_env(): |
| | return LoopUnrollEnv( |
| | source_files=source_files, |
| | repeat_runs=args.repeat_runs, |
| | arch=arch, |
| | clang_bin=args.clang_bin or None, |
| | opt_bin=args.opt_bin or None, |
| | ) |
| |
|
| | vec_env = make_vec_env(make_env, n_envs=1) |
| |
|
| | |
| | print("\n=== Transfer ๋ชจ๋ธ ๋น๋ ===") |
| | model = build_transfer_model(vec_env, backbone, freeze_backbone=args.freeze_backbone) |
| |
|
| | |
| | print(f"\n=== Adapter ํ์ต ({args.timesteps} ์คํ
) ===") |
| | model.learn(total_timesteps=args.timesteps, progress_bar=True) |
| |
|
| | |
| | model.save(transfer_model_path.replace(".zip", "")) |
| | print(f"\n์ ์ฅ ์๋ฃ: {transfer_model_path}") |
| |
|
| |
|
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
| |
|