# meta-rl-dsa-solver / scripts/test_training_config.py
# Dishaaa25 — "Prepare overnight training run" (a8426bb)
from __future__ import annotations

import math
import sys
from pathlib import Path
# Repository root: the parent of the directory that contains this script.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    # Prepend (rather than append) so the in-repo packages imported below are
    # found before any same-named installed distribution when this script is
    # run directly.
    sys.path.insert(0, str(ROOT))
from training.train_grpo import build_timing_summary, build_training_config, resolve_precision_policy
from verifier.metrics import compute_episode_reward
class FakeCuda:
    """Test double mimicking the ``torch.cuda`` capability probes.

    Both answers are fixed at construction time, so a test can simulate a GPU
    with bf16 support, a GPU without it, or no GPU at all. Presumably these are
    the probes queried by ``resolve_precision_policy``.
    """

    def __init__(self, available: bool, bf16_supported: bool) -> None:
        self._available, self._bf16_supported = available, bf16_supported

    def is_available(self) -> bool:
        """Return the CUDA-availability flag given at construction."""
        return self._available

    def is_bf16_supported(self) -> bool:
        """Return the bf16-support flag given at construction."""
        return self._bf16_supported
class FakeTorch:
    """Stand-in for the ``torch`` module passed to ``resolve_precision_policy``.

    Exposes string sentinels in place of torch's dtype objects, plus a
    ``cuda`` attribute whose capability probes return fixed answers.
    """

    # String sentinels standing in for torch.bfloat16 / float16 / float32.
    bfloat16, float16, float32 = "bfloat16", "float16", "float32"

    def __init__(self, available: bool, bf16_supported: bool) -> None:
        self.cuda = FakeCuda(available, bf16_supported)
def _check_presets(l4_config, smoke_config, overnight_config) -> None:
    """Pin the preset values returned by ``build_training_config``."""
    assert l4_config.model_name == "Qwen/Qwen2.5-3B-Instruct", l4_config.model_name
    assert l4_config.load_in_4bit is True
    assert l4_config.gradient_checkpointing is True
    assert l4_config.num_generations == 4, l4_config.num_generations

    assert smoke_config.load_in_4bit is False
    assert smoke_config.gradient_checkpointing is False
    assert smoke_config.upload_checkpoints_to_hub is False

    assert overnight_config.load_in_4bit is True
    assert overnight_config.gradient_checkpointing is True
    assert overnight_config.num_generations == 4, overnight_config.num_generations
    assert overnight_config.max_steps == 950, overnight_config.max_steps
    assert overnight_config.save_steps == 50, overnight_config.save_steps
    assert overnight_config.save_total_limit == 3, overnight_config.save_total_limit
    assert overnight_config.upload_checkpoints_to_hub is True


def _check_precision_policies(l4_config, smoke_config) -> None:
    """Check precision selection for bf16-capable, fp16-only and CPU-only backends."""
    # (config, cuda available, bf16 supported, expected mode, expected 4-bit)
    cases = (
        (l4_config, True, True, "bf16", True),
        (l4_config, True, False, "fp16", True),
        (smoke_config, False, False, "fp32", False),
    )
    for config, available, bf16_supported, expected_mode, expected_4bit in cases:
        policy = resolve_precision_policy(
            config, FakeTorch(available=available, bf16_supported=bf16_supported)
        )
        assert policy["precision_mode"] == expected_mode, policy
        assert policy["load_in_4bit"] is expected_4bit, policy


def _check_timing_summary(smoke_config) -> None:
    """Check ``build_timing_summary`` for 6 steps / 12 episodes in 180 seconds."""
    summary = build_timing_summary(
        config=smoke_config,
        wall_clock_seconds=180.0,
        completed_steps=6,
        train_episode_count=12,
    )
    # These values come out of float divisions. Depending on how the summary
    # orders them (e.g. 12 / (180 / 3600)), the result need not land exactly
    # on the round number, so compare with a tolerance.
    expected = {
        "wall_clock_minutes": 3.0,
        "avg_seconds_per_step": 30.0,
        "episodes_per_hour": 240.0,
    }
    for key, value in expected.items():
        assert math.isclose(summary[key], value), (key, summary[key])


def _check_episode_reward() -> None:
    """Check reward shaping for a first step that goes from 0% to 100% pass rate."""
    reward, components = compute_episode_reward(
        pass_rate=1.0,
        step_number=1,
        execution_status="completed",
        previous_pass_rate=0.0,
        done=False,
        efficiency_score=0.94,
        optimization_target_met=False,
    )
    assert math.isclose(reward, 0.94), reward
    assert math.isclose(components["progress_delta"], 1.0), components


def main() -> None:
    """Run the training-config smoke checks and report success.

    Exercises the ``build_training_config`` presets ("l4", "smoke",
    "overnight"), ``resolve_precision_policy`` against fake CUDA backends,
    ``build_timing_summary`` arithmetic, and ``compute_episode_reward`` for a
    first-step full pass.

    Raises:
        AssertionError: if any check fails.
        SystemExit: if run under ``python -O``. Every check here is an
            ``assert``, which ``-O`` strips, so the script would otherwise
            report success without verifying anything.
    """
    if not __debug__:
        raise SystemExit(
            "test_training_config.py must run without -O: its checks are assert statements"
        )

    l4_config = build_training_config("l4")
    smoke_config = build_training_config("smoke")
    overnight_config = build_training_config("overnight")

    _check_presets(l4_config, smoke_config, overnight_config)
    _check_precision_policies(l4_config, smoke_config)
    _check_timing_summary(smoke_config)
    _check_episode_reward()

    print("Training config smoke tests passed")
if __name__ == "__main__":
    # Run the checks only when executed as a script, never on import.
    main()