# machineai-compiler-optimizer / transfer_agent.py
# (uploaded by sosonsong via huggingface_hub, commit fb177fd)
"""
transfer_agent.py โ€” Architecture-aware Transfer Learning for LoopUnrollEnv
- arch๋ณ„๋กœ x86 / arm64 ๋“ฑ์„ ์„ ํƒ์ ์œผ๋กœ ์ง€์›
- Backbone: ๊ธฐ์กด {arch}_base ๋ชจ๋ธ์˜ ์ผ๋ถ€ ๋ ˆ์ด์–ด๋ฅผ ๋ฐฑ๋ณธ์œผ๋กœ ์‚ฌ์šฉ
- Adapter: ์ƒˆ ํ™˜๊ฒฝ(๋˜๋Š” ์ƒˆ CPU)์— ๋งž๊ฒŒ ์†Œํ˜• ๋ ˆ์ด์–ด๋งŒ ์žฌํ•™์Šต
"""
import os
import glob
import sys
import argparse
import numpy as np
import torch
import torch.nn as nn
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from compiler_env import LoopUnrollEnv
# ────────────────────────────────────────────────────────────
# Utilities: paths and default settings
# ────────────────────────────────────────────────────────────
# Project layout. The root defaults to ~/projects/machineai but may be
# overridden with the MACHINEAI_ROOT environment variable, so the script can
# run from another checkout/machine without editing the source.
PROJECT_ROOT = os.path.expanduser(
    os.environ.get("MACHINEAI_ROOT", "~/projects/machineai")
)
MODELS_DIR = os.path.join(PROJECT_ROOT, "models")  # base/transfer PPO checkpoints
BENCH_DIR = os.path.join(PROJECT_ROOT, "benchmarks")  # default *.c training sources
def get_model_paths(arch: str):
    """
    Build the checkpoint paths used for one target architecture.

    Both paths live under MODELS_DIR:
      - base:     models/model_{arch}_base.zip
      - transfer: models/model_{arch}_transfer.zip

    Returns:
        A 2-tuple ``(base_path, transfer_path)`` of strings.
    """
    base_path, transfer_path = (
        os.path.join(MODELS_DIR, file_name)
        for file_name in (f"model_{arch}_base.zip", f"model_{arch}_transfer.zip")
    )
    return base_path, transfer_path
# ────────────────────────────────────────────────────────────
# Backbone weight extraction
# ────────────────────────────────────────────────────────────
def extract_backbone_weights(model_path: str) -> dict:
    """
    Extract backbone layers from a saved PPO model's mlp_extractor.

    The backbone is currently modules 0 and 2 of the policy branch
    (``mlp_extractor.policy_net.0`` / ``mlp_extractor.policy_net.2``), i.e.
    presumably its first two Linear layers — verify if a custom net_arch is used.

    Args:
        model_path: Path to a PPO checkpoint loadable with ``PPO.load``.

    Returns:
        Mapping of state_dict key -> cloned tensor for every entry of those
        layers (e.g. ``...policy_net.0.weight``). Empty if none are present.
    """
    print(f"[Backbone] ๋กœ๋“œ: {model_path}")
    model = PPO.load(model_path)
    state_dict = model.policy.state_dict()
    # Match the full dotted layer prefix (note the trailing "."): a bare
    # substring test for "policy_net.2" would also capture layers 20-29 of a
    # deeper policy net.
    backbone_prefixes = (
        "mlp_extractor.policy_net.0.",
        "mlp_extractor.policy_net.2.",
    )
    backbone = {
        key: tensor.clone()  # clone so later edits to the source model don't leak in
        for key, tensor in state_dict.items()
        if key.startswith(backbone_prefixes)
    }
    print("[Backbone] ์ถ”์ถœ ๋ ˆ์ด์–ด:")
    for key in backbone:
        print(f" - {key}")
    return backbone
# ────────────────────────────────────────────────────────────
# Transfer PPO builder
# ────────────────────────────────────────────────────────────
def build_transfer_model(env, backbone_weights: dict | None, freeze_backbone: bool = True):
    """
    Build a PPO model, optionally seeded with (and freezing) a pre-trained backbone.

    A fresh PPO ``MlpPolicy`` is created with hidden layers [64, 64, 32]; the
    extra 32-unit layer acts as the adapter. If ``backbone_weights`` is given,
    each entry whose key exists in the new policy with the same shape is copied
    in; the rest are skipped and reported.

    Args:
        env: Environment (or VecEnv) passed straight to ``PPO``.
        backbone_weights: A partial state_dict (e.g. from
            ``extract_backbone_weights``), or None to start from scratch.
        freeze_backbone: If True, set ``requires_grad=False`` on exactly the
            parameters that were successfully injected.

    Returns:
        The constructed PPO model.
    """
    print("[Model] Transfer PPO ์ƒ์„ฑ ์ค‘...")
    model = PPO(
        policy="MlpPolicy",
        env=env,
        learning_rate=1e-4,  # lower learning rate for transfer learning
        n_steps=256,
        batch_size=64,
        n_epochs=10,
        gamma=0.99,
        verbose=1,
        policy_kwargs=dict(net_arch=[64, 64, 32]),  # +32 adapter layer
    )

    # Keys actually copied from the backbone; only these are frozen below.
    injected_keys: set[str] = set()
    if backbone_weights is not None:
        print("[Model] Backbone ๊ฐ€์ค‘์น˜ ์ฃผ์ž…...")
        state_dict = model.policy.state_dict()
        skipped = 0
        for k, v in backbone_weights.items():
            if k in state_dict and state_dict[k].shape == v.shape:
                state_dict[k] = v
                injected_keys.add(k)
                print(f" โœ” ์ฃผ์ž…: {k}")
            else:
                skipped += 1
                print(f" โœ— ์Šคํ‚ต: {k} (shape mismatch or not found)")
        model.policy.load_state_dict(state_dict)
        injected = len(injected_keys)
        print(f"[Model] ์ฃผ์ž… ์™„๋ฃŒ: {injected}๊ฐœ, ์Šคํ‚ต: {skipped}๊ฐœ")
    else:
        print("[Model] Backbone ์—†์ด ์ƒˆ ๋ชจ๋ธ๋กœ ์‹œ์ž‘")

    # Freeze only what was injected. Freezing by name pattern would also lock
    # layers whose pre-trained weights were skipped (e.g. an observation-size
    # mismatch on a new arch), leaving them stuck at random initialisation.
    if freeze_backbone and backbone_weights is not None:
        if injected_keys:
            print("[Model] Backbone ํŒŒ๋ผ๋ฏธํ„ฐ ๋™๊ฒฐ...")
            for name, param in model.policy.named_parameters():
                if name in injected_keys:
                    param.requires_grad = False
                    print(f" ๐Ÿ”’ ๋™๊ฒฐ: {name}")
        else:
            print("[Model] No backbone weights were injected; skipping backbone freeze")

    trainable = sum(p.numel() for p in model.policy.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.policy.parameters())
    print(f"\n[Model] ํŒŒ๋ผ๋ฏธํ„ฐ: {trainable}/{total} ํ•™์Šต๊ฐ€๋Šฅ ({trainable/total*100:.1f}%)")
    return model
# ────────────────────────────────────────────────────────────
# Main transfer-learning entry point
# ────────────────────────────────────────────────────────────
def main():
    """
    CLI entry point for transfer training.

    Builds a LoopUnrollEnv for ``--arch``, optionally seeds the policy with
    backbone weights from the arch's base model, trains for ``--timesteps``
    steps, and saves the result.

    Raises:
        FileNotFoundError: ``--load-base`` was given but the base model path
            does not exist.
        SystemExit: (via ``parser.error``) no training source files were found.
    """
    parser = argparse.ArgumentParser(description="Architecture-aware transfer learning for LoopUnrollEnv")
    parser.add_argument("--arch", type=str, default="x86", help="ํƒ€๊ฒŸ ์•„ํ‚คํ…์ฒ˜ (์˜ˆ: x86, arm64)")
    parser.add_argument("--timesteps", type=int, default=2000, help="์ „์ดํ•™์Šต ์Šคํ… ์ˆ˜")
    parser.add_argument("--load-base", action="store_true", help="๊ธฐ์กด base ๋ชจ๋ธ์—์„œ backbone์„ ๋กœ๋“œํ• ์ง€ ์—ฌ๋ถ€")
    parser.add_argument("--base-path", type=str, default="", help="์ง์ ‘ base ๋ชจ๋ธ ๊ฒฝ๋กœ ์ง€์ • (์˜ต์…˜)")
    parser.add_argument("--out-path", type=str, default="", help="์ „์ด ๊ฒฐ๊ณผ ์ €์žฅ ๊ฒฝ๋กœ ์ง์ ‘ ์ง€์ • (์˜ต์…˜)")
    parser.add_argument("--repeat-runs", type=int, default=3, help="์‹คํ–‰ ์‹œ๊ฐ„ ์ธก์ • ๋ฐ˜๋ณต ํšŸ์ˆ˜")
    parser.add_argument("--freeze-backbone", action="store_true", help="Backbone ๋ ˆ์ด์–ด๋ฅผ ๋™๊ฒฐํ• ์ง€ ์—ฌ๋ถ€")
    parser.add_argument("--clang-bin", type=str, default="", help="์‚ฌ์šฉํ•  clang ๋ฐ”์ด๋„ˆ๋ฆฌ (๋น„์šฐ๋ฉด ๊ธฐ๋ณธ๊ฐ’)")
    parser.add_argument("--opt-bin", type=str, default="", help="์‚ฌ์šฉํ•  opt ๋ฐ”์ด๋„ˆ๋ฆฌ (๋น„์šฐ๋ฉด ๊ธฐ๋ณธ๊ฐ’)")
    parser.add_argument("--source-files", type=str, nargs="+", default=[], help="ํ•™์Šต์— ์‚ฌ์šฉํ•  ์†Œ์Šค ํŒŒ์ผ ๋ชฉ๋ก")
    args = parser.parse_args()
    arch = args.arch
    print(f"[Config] arch={arch}")

    # Paths: explicit CLI overrides win over the per-arch defaults.
    os.makedirs(MODELS_DIR, exist_ok=True)
    default_base, default_transfer = get_model_paths(arch)
    base_model_path = args.base_path or default_base
    transfer_model_path = args.out_path or default_transfer
    print(f"[Config] base_model_path = {base_model_path}")
    print(f"[Config] transfer_model_path= {transfer_model_path}")

    # Training sources: the explicit list, else every *.c under BENCH_DIR.
    if args.source_files:
        source_files = [os.path.abspath(f) for f in args.source_files]
    else:
        source_files = sorted(glob.glob(os.path.join(BENCH_DIR, "*.c")))
    print(f"[Data] ํ•™์Šต ๋Œ€์ƒ: {source_files}")
    if not source_files:
        # Fail fast with a clear CLI error rather than handing LoopUnrollEnv an
        # empty list (its behaviour with no sources is not defined here).
        parser.error(
            f"no training source files found; pass --source-files or add *.c files to {BENCH_DIR}"
        )

    # Backbone (optional)
    backbone = None
    if args.load_base:
        if not os.path.exists(base_model_path):
            raise FileNotFoundError(f"Base ๋ชจ๋ธ์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค: {base_model_path}")
        backbone = extract_backbone_weights(base_model_path)
    else:
        print("[Backbone] base ๋ชจ๋ธ ๋กœ๋“œ ์ƒ๋žต (์ˆœ์ˆ˜ ์ƒˆ ๋ชจ๋ธ๋กœ ์‹œ์ž‘)")

    # Environment factory for make_vec_env; empty CLI strings mean "use the env default".
    def make_env():
        return LoopUnrollEnv(
            source_files=source_files,
            repeat_runs=args.repeat_runs,
            arch=arch,
            clang_bin=args.clang_bin or None,
            opt_bin=args.opt_bin or None,
        )

    vec_env = make_vec_env(make_env, n_envs=1)
    try:
        print("\n=== Transfer ๋ชจ๋ธ ๋นŒ๋“œ ===")
        model = build_transfer_model(vec_env, backbone, freeze_backbone=args.freeze_backbone)

        print(f"\n=== Adapter ํ•™์Šต ({args.timesteps} ์Šคํ…) ===")
        model.learn(total_timesteps=args.timesteps, progress_bar=True)

        # Pass the path unchanged: SB3 writes to a path that already has a
        # suffix as-is and only appends ".zip" when there is none. The former
        # `.replace(".zip", "")` removed every ".zip" in the path, and for names
        # containing a dot (e.g. --arch armv8.2) left a non-empty suffix, so no
        # ".zip" was re-added and the file did not match the printed path.
        model.save(transfer_model_path)
        print(f"\n์ €์žฅ ์™„๋ฃŒ: {transfer_model_path}")
    finally:
        vec_env.close()
if __name__ == "__main__":
    # Script entry point: parse CLI arguments and run transfer training.
    main()