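"""GRPO training + held-out evaluation driver for the NeuralTuner environment.

Fine-tunes a local HF model (default: Qwen2.5-1.5B-Instruct) with a LoRA adapter
via TRL's GRPOTrainer against the NeuralTuner tool-use environment, then scores
random/baseline/heuristic reference policies on the held-out scenario split and
writes the merged metrics to artifacts/training/train_eval_script_metrics.json.
"""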
from __future__ import annotations

import argparse
import os
import sys
from pathlib import Path

import torch
import trl.extras.profiling as trl_profiling
from datasets import Dataset
from peft import LoraConfig
from transformers import AutoTokenizer
from trl import GRPOConfig, GRPOTrainer
from trl.chat_template_utils import qwen3_chat_template, qwen3_schema

import wandb

ROOT_DIR = Path(__file__).resolve().parents[1]
if str(ROOT_DIR) not in sys.path:
    sys.path.insert(0, str(ROOT_DIR))

from rollout_eval import run_baseline_episode, run_heuristic_episode, run_random_episode
from scripts.neural_tuner import NeuralTunerOpenEnv
from server.neural_tuner_env_environment import NeuralTunerEnvironment
from training_utils import load_jsonl, mean, split_scenarios, write_json


def build_prompt_row(system_prompt: str, scenario: dict) -> list[dict]:
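    """Render one scenario as a single-turn chat prompt: system text plus scenario fields."""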
    scenario_text = (
        f"Scenario ID: {scenario['id']}\n"
        f"Model: {scenario['model_id']}\n"
        f"Difficulty: {scenario['difficulty']}\n"
        f"Hint: {scenario['scenario_hint']}\n"
        f"Target behavior: {scenario['target_behavior']}\n"
        f"Notes: {scenario['notes']}\n"
    )
    return [{"role": "user", "content": system_prompt + "\n\n" + scenario_text}]


def reward_func(environments, **kwargs) -> list[float]:
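    """Return the final reward of each rollout env, forcing a submit on unfinished episodes."""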
    rewards = []
    for env in environments:
        if not env.done:
            env.submit()
        rewards.append(float(env.reward))
    return rewards


def evaluate_split(eval_rows: list[dict], n_random: int = 10) -> dict:
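    """Score random, baseline, and heuristic reference policies on the held-out rows.

    The heuristic mean doubles as the reported oracle ceiling for the environment.
    """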
    env = NeuralTunerEnvironment()
    random_rewards = []
    baseline_rewards = []
    heuristic_rewards = []

    for idx, row in enumerate(eval_rows):
        model_id = row["model_id"]
        difficulty = row["difficulty"]
        for seed in range(n_random):
            random_rewards.append(run_random_episode(env, model_id, difficulty, seed=idx * 100 + seed).final_reward)
        baseline_rewards.append(run_baseline_episode(env, model_id, difficulty).final_reward)
        heuristic_rewards.append(run_heuristic_episode(env, model_id, difficulty).final_reward)

    # Guard all three means so an empty eval split yields zeros instead of crashing.
    random_mean = mean(random_rewards) if random_rewards else 0.0
    oracle_ceiling = mean(heuristic_rewards) if heuristic_rewards else 0.0
    return {
        "eval_random_mean": random_mean,
        "eval_baseline_mean": mean(baseline_rewards) if baseline_rewards else 0.0,
        "eval_heuristic_mean": oracle_ceiling,
        "eval_oracle_ceiling": oracle_ceiling,
    }


def main() -> int:
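    """Parse args, optionally train, evaluate the held-out split, and persist merged metrics."""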
    parser = argparse.ArgumentParser(description="Run reproducible NeuralTuner training + held-out eval.")
    parser.add_argument("--model-path", default=".hf_cache/Qwen--Qwen2.5-1.5B-Instruct")
    parser.add_argument("--scenarios", default="data/mock_data/neural_tuner_train_scenarios.jsonl")
    parser.add_argument("--output-dir", default="outputs/neural_tuner_grpo_tiny_lora")
    parser.add_argument("--max-steps", type=int, default=20)
    parser.add_argument("--train-fraction", type=float, default=0.8)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--skip-train", action="store_true")
    args = parser.parse_args()

    wandb_key = os.getenv("WANDB_API_KEY")
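    # Monkey-patch trl's profiling module to use our wandb import, so any timing
    # metrics it emits land on the same run as the trainer logs.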
    trl_profiling.wandb = wandb
    if wandb_key:
        wandb.login(key=wandb_key)
        wandb.init(
            project=os.getenv("WANDB_PROJECT", "round_2"),
            name=f"neural_tuner_grpo_script_{args.seed}",
            resume="never",
            reinit=True,
        )

    model_name = str(Path(args.model_path).resolve())
    scenarios = load_jsonl(Path(args.scenarios))
    train_rows, eval_rows = split_scenarios(scenarios, train_fraction=args.train_fraction, seed=args.seed)

    system_prompt = (
        "You are a hardware optimization agent for NeuralTuner.\n"
        "Use tools to profile layers, apply quantization/pruning, benchmark, and submit."
    )
    train_dataset = Dataset.from_dict(
        {
            "prompt": [build_prompt_row(system_prompt, s) for s in train_rows],
            "model_id": [s["model_id"] for s in train_rows],
            "difficulty": [s["difficulty"] for s in train_rows],
            "scenario_id": [s["id"] for s in train_rows],
        }
    )

    processing_class = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, local_files_only=True)
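    # Override the tokenizer's default chat template with the Qwen3 tool-calling
    # template and attach its response schema for parsing tool calls in rollouts.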
    processing_class.chat_template = qwen3_chat_template
    processing_class.response_schema = qwen3_schema

    peft_config = LoraConfig(
        r=8,
        lora_alpha=16,
        lora_dropout=0.05,
        target_modules="all-linear",
        task_type="CAUSAL_LM",
    )
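    # Feed the train-split scenarios to the environment factory via class attributes;
    # each freshly constructed environment is expected to advance schedule_idx.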
    NeuralTunerOpenEnv.scenario_schedule = train_rows
    NeuralTunerOpenEnv.schedule_idx = 0

    grpo_config = GRPOConfig(
        output_dir=args.output_dir,
        run_name=f"neural_tuner_grpo_script_{args.seed}",
        max_steps=args.max_steps,
        learning_rate=8e-7,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=1,
        warmup_steps=2,
        max_completion_length=256,
        num_generations=8,
        generation_batch_size=8,
        temperature=0.9,
        top_p=0.95,
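        # Sampling guards: renormalize logits after processors and strip NaN/inf
        # values so multinomial sampling cannot crash mid-generation.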
        generation_kwargs={"do_sample": True, "renormalize_logits": True, "remove_invalid_values": True},
        mask_truncated_completions=False,
        max_tool_calling_iterations=20,
        log_completions=False,
        gradient_checkpointing=True,
        torch_empty_cache_steps=1,
        logging_steps=1,
        report_to="wandb" if wandb_key else "none",
        # Force CPU for scripted sweeps to avoid MPS/offload meta-device backward errors.
        use_cpu=True,
        bf16=False,
        fp16=False,
        use_vllm=False,
    )

    train_metrics = {"train_skipped": bool(args.skip_train)}
    if not args.skip_train:
        trainer = GRPOTrainer(
            model=model_name,
            reward_funcs=reward_func,
            train_dataset=train_dataset,
            args=grpo_config,
            processing_class=processing_class,
            peft_config=peft_config,
            environment_factory=NeuralTunerOpenEnv,
        )
        trainer_stats = trainer.train()
        trainer.save_model(args.output_dir)
        train_metrics.update(trainer_stats.metrics)
        reward_logs = [x["reward"] for x in trainer.state.log_history if "reward" in x]
        if reward_logs:
            train_metrics["training_reward_mean"] = mean(reward_logs)
            train_metrics["training_reward_last5_mean"] = mean(reward_logs[-5:])

    eval_metrics = evaluate_split(eval_rows, n_random=5)
    merged = {
        **train_metrics,
        **eval_metrics,
        "train_size": len(train_rows),
        "eval_size": len(eval_rows),
    }
    # Keep lift metrics explicit to avoid mixing train/eval interpretations.
    merged["lift_eval_baseline_vs_random"] = merged["eval_baseline_mean"] - merged["eval_random_mean"]
    merged["lift_eval_heuristic_vs_random"] = merged["eval_heuristic_mean"] - merged["eval_random_mean"]
    if "training_reward_mean" in merged:
        merged["lift_train_reward_vs_eval_random"] = merged["training_reward_mean"] - merged["eval_random_mean"]
    write_json(Path("artifacts/training/train_eval_script_metrics.json"), merged)

    if wandb_key:
        wandb.log(merged)
        wandb.finish()
    return 0


if __name__ == "__main__":
    raise SystemExit(main())