# Neural-Tuner/scripts/run_training_eval.py
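"""Run reproducible NeuralTuner GRPO training plus a held-out evaluation.

Example invocation (all flags optional; defaults match the argparse setup below):

    python scripts/run_training_eval.py --max-steps 20 --train-fraction 0.8 --seed 42
"""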
from __future__ import annotations
import argparse
import os
import sys
from pathlib import Path
import torch
import trl.extras.profiling as trl_profiling
from datasets import Dataset
from peft import LoraConfig
from transformers import AutoTokenizer
from trl import GRPOConfig, GRPOTrainer
from trl.chat_template_utils import qwen3_chat_template, qwen3_schema
import wandb
ROOT_DIR = Path(__file__).resolve().parents[1]
if str(ROOT_DIR) not in sys.path:
sys.path.insert(0, str(ROOT_DIR))
from rollout_eval import run_baseline_episode, run_heuristic_episode, run_random_episode
from scripts.neural_tuner import NeuralTunerOpenEnv
from server.neural_tuner_env_environment import NeuralTunerEnvironment
from training_utils import load_jsonl, mean, split_scenarios, write_json
def build_prompt_row(system_prompt: str, scenario: dict) -> list[dict]:
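    """Render a single-turn chat prompt combining the system prompt with scenario metadata."""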
scenario_text = (
f"Scenario ID: {scenario['id']}\n"
f"Model: {scenario['model_id']}\n"
f"Difficulty: {scenario['difficulty']}\n"
f"Hint: {scenario['scenario_hint']}\n"
f"Target behavior: {scenario['target_behavior']}\n"
f"Notes: {scenario['notes']}\n"
)
return [{"role": "user", "content": system_prompt + "\n\n" + scenario_text}]
def reward_func(environments, **kwargs) -> list[float]:
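    """Terminal reward for each rollout environment; episodes that never submitted are submitted first."""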
rewards = []
for env in environments:
if not env.done:
env.submit()
rewards.append(float(env.reward))
return rewards
def evaluate_split(eval_rows: list[dict], n_random: int = 10) -> dict:
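    """Score the held-out split with non-learned reference policies: n_random random
    rollouts per scenario, plus one baseline and one heuristic rollout each."""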
env = NeuralTunerEnvironment()
random_rewards = []
baseline_rewards = []
heuristic_rewards = []
for idx, row in enumerate(eval_rows):
model_id = row["model_id"]
difficulty = row["difficulty"]
for seed in range(n_random):
random_rewards.append(run_random_episode(env, model_id, difficulty, seed=idx * 100 + seed).final_reward)
baseline_rewards.append(run_baseline_episode(env, model_id, difficulty).final_reward)
heuristic_rewards.append(run_heuristic_episode(env, model_id, difficulty).final_reward)
    # Guard against an empty eval split so mean() is never called on an empty list.
    random_mean = mean(random_rewards) if random_rewards else 0.0
    baseline_mean = mean(baseline_rewards) if baseline_rewards else 0.0
    oracle_ceiling = mean(heuristic_rewards) if heuristic_rewards else 0.0
    return {
        "eval_random_mean": random_mean,
        "eval_baseline_mean": baseline_mean,
        "eval_heuristic_mean": oracle_ceiling,
        "eval_oracle_ceiling": oracle_ceiling,
    }
def main() -> int:
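    """Train a GRPO+LoRA policy on the train split (unless --skip-train), evaluate
    reference policies on the held-out split, and write merged metrics to disk."""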
parser = argparse.ArgumentParser(description="Run reproducible NeuralTuner training + held-out eval.")
parser.add_argument("--model-path", default=".hf_cache/Qwen--Qwen2.5-1.5B-Instruct")
parser.add_argument("--scenarios", default="data/mock_data/neural_tuner_train_scenarios.jsonl")
parser.add_argument("--output-dir", default="outputs/neural_tuner_grpo_tiny_lora")
parser.add_argument("--max-steps", type=int, default=20)
parser.add_argument("--train-fraction", type=float, default=0.8)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--skip-train", action="store_true")
args = parser.parse_args()
wandb_key = os.getenv("WANDB_API_KEY")
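    # Inject this wandb module into TRL's profiling helpers so their timing logs go to this run.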
trl_profiling.wandb = wandb
if wandb_key:
wandb.login(key=wandb_key)
wandb.init(
project=os.getenv("WANDB_PROJECT", "round_2"),
name=f"neural_tuner_grpo_script_{args.seed}",
resume="never",
reinit=True,
)
model_name = str(Path(args.model_path).resolve())
scenarios = load_jsonl(Path(args.scenarios))
train_rows, eval_rows = split_scenarios(scenarios, train_fraction=args.train_fraction, seed=args.seed)
system_prompt = (
"You are a hardware optimization agent for NeuralTuner.\n"
"Use tools to profile layers, apply quantization/pruning, benchmark, and submit."
)
train_dataset = Dataset.from_dict(
{
"prompt": [build_prompt_row(system_prompt, s) for s in train_rows],
"model_id": [s["model_id"] for s in train_rows],
"difficulty": [s["difficulty"] for s in train_rows],
"scenario_id": [s["id"] for s in train_rows],
}
)
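    # Override the tokenizer's chat template and response schema with the Qwen3 variants imported above.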
processing_class = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, local_files_only=True)
processing_class.chat_template = qwen3_chat_template
processing_class.response_schema = qwen3_schema
peft_config = LoraConfig(
r=8,
lora_alpha=16,
lora_dropout=0.05,
target_modules="all-linear",
task_type="CAUSAL_LM",
)
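    # Expose the train split to the environment factory as a class-level schedule (index reset to 0).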
NeuralTunerOpenEnv.scenario_schedule = train_rows
NeuralTunerOpenEnv.schedule_idx = 0
grpo_config = GRPOConfig(
output_dir=args.output_dir,
run_name=f"neural_tuner_grpo_script_{args.seed}",
max_steps=args.max_steps,
learning_rate=8e-7,
per_device_train_batch_size=1,
gradient_accumulation_steps=1,
warmup_steps=2,
max_completion_length=256,
num_generations=8,
generation_batch_size=8,
temperature=0.9,
top_p=0.95,
generation_kwargs={"do_sample": True, "renormalize_logits": True, "remove_invalid_values": True},
mask_truncated_completions=False,
max_tool_calling_iterations=20,
log_completions=False,
gradient_checkpointing=True,
torch_empty_cache_steps=1,
logging_steps=1,
report_to="wandb" if wandb_key else "none",
# Force CPU for scripted sweeps to avoid MPS/offload meta-device backward errors.
use_cpu=True,
bf16=False,
fp16=False,
use_vllm=False,
)
train_metrics = {"train_skipped": bool(args.skip_train)}
if not args.skip_train:
trainer = GRPOTrainer(
model=model_name,
reward_funcs=reward_func,
train_dataset=train_dataset,
args=grpo_config,
processing_class=processing_class,
peft_config=peft_config,
environment_factory=NeuralTunerOpenEnv,
)
trainer_stats = trainer.train()
trainer.save_model(args.output_dir)
train_metrics.update(trainer_stats.metrics)
reward_logs = [x["reward"] for x in trainer.state.log_history if "reward" in x]
if reward_logs:
train_metrics["training_reward_mean"] = mean(reward_logs)
train_metrics["training_reward_last5_mean"] = mean(reward_logs[-5:])
eval_metrics = evaluate_split(eval_rows, n_random=5)
merged = {
**train_metrics,
**eval_metrics,
"train_size": len(train_rows),
"eval_size": len(eval_rows),
}
# Keep lift metrics explicit to avoid mixing train/eval interpretations.
merged["lift_eval_baseline_vs_random"] = merged["eval_baseline_mean"] - merged["eval_random_mean"]
merged["lift_eval_heuristic_vs_random"] = merged["eval_heuristic_mean"] - merged["eval_random_mean"]
if "training_reward_mean" in merged:
merged["lift_train_reward_vs_eval_random"] = merged["training_reward_mean"] - merged["eval_random_mean"]
write_json(Path("artifacts/training/train_eval_script_metrics.json"), merged)
if wandb_key:
wandb.log(merged)
wandb.finish()
return 0
if __name__ == "__main__":
raise SystemExit(main())