sosonsong committed on
Commit
fb177fd
·
verified ·
1 Parent(s): 17782d8

Upload transfer_agent.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. transfer_agent.py +195 -0
transfer_agent.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ transfer_agent.py — Architecture-aware Transfer Learning for LoopUnrollEnv
3
+
4
+ - Selectively supports x86 / arm64 etc., chosen per arch
5
+ - Backbone: reuses some layers of the existing {arch}_base model as the backbone
6
+ - Adapter: retrains only small layers to fit a new environment (or a new CPU)
7
+ """
8
+
9
+ import os
10
+ import glob
11
+ import sys
12
+ import argparse
13
+
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn as nn
17
+ import gymnasium as gym
18
+ from stable_baselines3 import PPO
19
+ from stable_baselines3.common.env_util import make_vec_env
20
+
21
+ from compiler_env import LoopUnrollEnv
22
+
23
+
24
+ # ─────────────────────────────────────────────────────────────
25
+ # Utilities: paths and default settings
26
+ # ─────────────────────────────────────────────────────────────
27
+
28
# Root of the project checkout; "~" is expanded to the current user's home.
PROJECT_ROOT = os.path.expanduser("~/projects/machineai")

# Sub-directories of PROJECT_ROOT holding saved models and benchmark sources.
MODELS_DIR, BENCH_DIR = (
    os.path.join(PROJECT_ROOT, subdir) for subdir in ("models", "benchmarks")
)
31
+
32
+
33
def get_model_paths(arch: str):
    """Return the default base and transfer model paths for *arch*.

    Both files live directly under ``MODELS_DIR``:

    - base:     ``model_{arch}_base.zip``
    - transfer: ``model_{arch}_transfer.zip``

    Args:
        arch: Architecture label interpolated into the file names.

    Returns:
        A ``(base, transfer)`` tuple of path strings.
    """
    base_name = f"model_{arch}_base.zip"
    transfer_name = f"model_{arch}_transfer.zip"
    return (
        os.path.join(MODELS_DIR, base_name),
        os.path.join(MODELS_DIR, transfer_name),
    )
42
+
43
+
44
+ # ─────────────────────────────────────────────────────────────
45
+ # Backbone weight extraction
46
+ # ─────────────────────────────────────────────────────────────
47
+
48
def extract_backbone_weights(model_path: str) -> dict:
    """Load a saved PPO model and copy out the layers used as the backbone.

    The backbone is every entry of the policy's ``state_dict`` whose key
    contains ``mlp_extractor.policy_net.0`` or ``mlp_extractor.policy_net.2``
    (currently the first two layers of the policy network).

    Args:
        model_path: Path handed to ``PPO.load``.

    Returns:
        Dict mapping each matching state-dict key to a cloned tensor.
    """
    backbone_markers = (
        "mlp_extractor.policy_net.0",
        "mlp_extractor.policy_net.2",
    )

    print(f"[Backbone] ๋กœ๋“œ: {model_path}")
    policy_state = PPO.load(model_path).policy.state_dict()

    # clone() so the returned tensors are copies rather than references
    # into the loaded model.
    backbone = {
        key: tensor.clone()
        for key, tensor in policy_state.items()
        if any(marker in key for marker in backbone_markers)
    }

    print(f"[Backbone] ์ถ”์ถœ ๋ ˆ์ด์–ด:")
    for key in backbone:
        print(f" - {key}")
    return backbone
66
+
67
+
68
+ # ─────────────────────────────────────────────────────────────
69
+ # Transfer PPO builder
70
+ # ─────────────────────────────────────────────────────────────
71
+
72
def build_transfer_model(env, backbone_weights: dict | None, freeze_backbone: bool = True):
    """Build a PPO model, optionally seeded with (and frozen on) backbone weights.

    A fresh ``MlpPolicy`` PPO is created with ``net_arch=[64, 64, 32]``; the
    trailing 32-unit layer is the adapter that is trained on the new target.

    Args:
        env: Environment passed straight through to ``PPO``.
        backbone_weights: State-dict entries to copy into the new policy, or
            ``None`` to start from a freshly initialised model. An entry is
            copied only when its key exists in the new policy and the tensor
            shapes match; everything else is reported and skipped.
        freeze_backbone: When true (and ``backbone_weights`` is not ``None``),
            set ``requires_grad = False`` on the backbone parameters.

    Returns:
        The constructed ``PPO`` model.
    """
    backbone_markers = (
        "mlp_extractor.policy_net.0",
        "mlp_extractor.policy_net.2",
    )

    print("[Model] Transfer PPO ์ƒ์„ฑ ์ค‘...")
    model = PPO(
        policy="MlpPolicy",
        env=env,
        learning_rate=1e-4,  # low learning rate for transfer learning
        n_steps=256,
        batch_size=64,
        n_epochs=10,
        gamma=0.99,
        verbose=1,
        policy_kwargs=dict(net_arch=[64, 64, 32]),  # +32 adapter layer
    )

    # Keys that really received pretrained values; the freeze pass below
    # must not touch anything else.
    injected_keys = set()

    # Inject backbone weights.
    if backbone_weights is not None:
        print("[Model] Backbone ๊ฐ€์ค‘์น˜ ์ฃผ์ž…...")
        state_dict = model.policy.state_dict()
        skipped = 0
        for k, v in backbone_weights.items():
            if k in state_dict and state_dict[k].shape == v.shape:
                state_dict[k] = v
                injected_keys.add(k)
                print(f" โœ” ์ฃผ์ž…: {k}")
            else:
                skipped += 1
                print(f" โœ— ์Šคํ‚ต: {k} (shape mismatch or not found)")
        model.policy.load_state_dict(state_dict)
        injected = len(injected_keys)
        print(f"[Model] ์ฃผ์ž… ์™„๋ฃŒ: {injected}๊ฐœ, ์Šคํ‚ต: {skipped}๊ฐœ")
    else:
        print("[Model] Backbone ์—†์ด ์ƒˆ ๋ชจ๋ธ๋กœ ์‹œ์ž‘")

    # Freeze the backbone.
    if freeze_backbone and backbone_weights is not None:
        print("[Model] Backbone ํŒŒ๋ผ๋ฏธํ„ฐ ๋™๊ฒฐ...")
        for name, param in model.policy.named_parameters():
            is_backbone = any(marker in name for marker in backbone_markers)
            # Only freeze layers that were actually injected: a layer skipped
            # above still holds its random initial values, and freezing it
            # would leave it permanently untrained.
            if is_backbone and name in injected_keys:
                param.requires_grad = False
                print(f" ๐Ÿ”’ ๋™๊ฒฐ: {name}")

    trainable = sum(p.numel() for p in model.policy.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.policy.parameters())
    print(f"\n[Model] ํŒŒ๋ผ๋ฏธํ„ฐ: {trainable}/{total} ํ•™์Šต๊ฐ€๋Šฅ ({trainable/total*100:.1f}%)")
    return model
120
+
121
+
122
+ # ─────────────────────────────────────────────────────────────
123
+ # Main transfer-learning run
124
+ # ─────────────────────────────────────────────────────────────
125
+
126
def main():
    """Command-line entry point: build, train and save a transfer PPO model.

    Steps, in order:

    1. Parse the CLI options (target arch, timesteps, paths, toolchain).
    2. Resolve the base/transfer model paths; ``--base-path`` and
       ``--out-path`` override the per-arch defaults.
    3. Pick the training sources: ``--source-files`` if given, otherwise every
       ``*.c`` file directly under ``BENCH_DIR``.
    4. With ``--load-base``, extract backbone weights from the base model.
    5. Build the model on a single-env vectorized ``LoopUnrollEnv``, train it
       for ``--timesteps`` steps and save it.

    Raises:
        FileNotFoundError: ``--load-base`` was given but the base model file
            does not exist.
    """
    parser = argparse.ArgumentParser(description="Architecture-aware transfer learning for LoopUnrollEnv")
    parser.add_argument("--arch", type=str, default="x86", help="ํƒ€๊ฒŸ ์•„ํ‚คํ…์ฒ˜ (์˜ˆ: x86, arm64)")
    parser.add_argument("--timesteps", type=int, default=2000, help="์ „์ดํ•™์Šต ์Šคํ… ์ˆ˜")
    parser.add_argument("--load-base", action="store_true", help="๊ธฐ์กด base ๋ชจ๋ธ์—์„œ backbone์„ ๋กœ๋“œํ• ์ง€ ์—ฌ๋ถ€")
    parser.add_argument("--base-path", type=str, default="", help="์ง์ ‘ base ๋ชจ๋ธ ๊ฒฝ๋กœ ์ง€์ • (์˜ต์…˜)")
    parser.add_argument("--out-path", type=str, default="", help="์ „์ด ๊ฒฐ๊ณผ ์ €์žฅ ๊ฒฝ๋กœ ์ง์ ‘ ์ง€์ • (์˜ต์…˜)")
    parser.add_argument("--repeat-runs", type=int, default=3, help="์‹คํ–‰ ์‹œ๊ฐ„ ์ธก์ • ๋ฐ˜๋ณต ํšŸ์ˆ˜")
    parser.add_argument("--freeze-backbone", action="store_true", help="Backbone ๋ ˆ์ด์–ด๋ฅผ ๋™๊ฒฐํ• ์ง€ ์—ฌ๋ถ€")
    parser.add_argument("--clang-bin", type=str, default="", help="์‚ฌ์šฉํ•  clang ๋ฐ”์ด๋„ˆ๋ฆฌ (๋น„์šฐ๋ฉด ๊ธฐ๋ณธ๊ฐ’)")
    parser.add_argument("--opt-bin", type=str, default="", help="์‚ฌ์šฉํ•  opt ๋ฐ”์ด๋„ˆ๋ฆฌ (๋น„์šฐ๋ฉด ๊ธฐ๋ณธ๊ฐ’)")
    parser.add_argument("--source-files", type=str, nargs="+", default=[], help="ํ•™์Šต์— ์‚ฌ์šฉํ•  ์†Œ์Šค ํŒŒ์ผ ๋ชฉ๋ก")
    args = parser.parse_args()

    arch = args.arch
    print(f"[Config] arch={arch}")

    # Path setup: empty CLI overrides fall back to the per-arch defaults.
    os.makedirs(MODELS_DIR, exist_ok=True)
    default_base, default_transfer = get_model_paths(arch)

    base_model_path = args.base_path or default_base
    transfer_model_path = args.out_path or default_transfer

    print(f"[Config] base_model_path = {base_model_path}")
    print(f"[Config] transfer_model_path= {transfer_model_path}")

    # Source files to train on.
    if args.source_files:
        source_files = [os.path.abspath(f) for f in args.source_files]
    else:
        source_files = sorted(glob.glob(os.path.join(BENCH_DIR, "*.c")))
    print(f"[Data] ํ•™์Šต ๋Œ€์ƒ: {source_files}")

    # Optional backbone load.
    backbone = None
    if args.load_base:
        if not os.path.exists(base_model_path):
            raise FileNotFoundError(f"Base ๋ชจ๋ธ์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค: {base_model_path}")
        backbone = extract_backbone_weights(base_model_path)
    else:
        print("[Backbone] base ๋ชจ๋ธ ๋กœ๋“œ ์ƒ๋žต (์ˆœ์ˆ˜ ์ƒˆ ๋ชจ๋ธ๋กœ ์‹œ์ž‘)")

    # Environment factory; empty toolchain options are passed on as None.
    def make_env():
        return LoopUnrollEnv(
            source_files=source_files,
            repeat_runs=args.repeat_runs,
            arch=arch,
            clang_bin=args.clang_bin or None,
            opt_bin=args.opt_bin or None,
        )

    vec_env = make_vec_env(make_env, n_envs=1)
    try:
        # Build the transfer model.
        print("\n=== Transfer ๋ชจ๋ธ ๋นŒ๋“œ ===")
        model = build_transfer_model(vec_env, backbone, freeze_backbone=args.freeze_backbone)

        # Train.
        print(f"\n=== Adapter ํ•™์Šต ({args.timesteps} ์Šคํ…) ===")
        model.learn(total_timesteps=args.timesteps, progress_bar=True)

        # Save. Strip only a trailing ".zip": str.replace would also remove
        # ".zip" occurring elsewhere in the path (e.g. in a directory name).
        # NOTE(review): relies on model.save() appending ".zip" itself when
        # the path has no extension - confirm against the installed SB3.
        model.save(transfer_model_path.removesuffix(".zip"))
        print(f"\n์ €์žฅ ์™„๋ฃŒ: {transfer_model_path}")
    finally:
        # Release the environment even if building, training or saving fails.
        vec_env.close()
192
+
193
+
194
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()