File size: 3,164 Bytes
7f2d9e7
 
 
 
 
 
 
 
 
6c58fca
7f2d9e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a8426bb
7f2d9e7
 
 
 
 
 
 
 
a8426bb
 
 
 
 
 
 
 
 
7f2d9e7
 
 
 
 
 
 
 
 
 
 
 
 
6c58fca
 
 
 
 
 
 
 
 
 
7f2d9e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from __future__ import annotations

import sys
from pathlib import Path

# Put the parent of this file's directory (presumably the repository root)
# first on sys.path so the project-local `training` and `verifier` imports
# below resolve when this file is run directly as a script.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:  # avoid inserting a duplicate entry
    sys.path.insert(0, str(ROOT))

from training.train_grpo import build_timing_summary, build_training_config, resolve_precision_policy
from verifier.metrics import compute_episode_reward


class FakeCuda:
    """Fake ``cuda`` namespace that answers capability queries from fixed flags.

    Lets tests simulate GPU/no-GPU and bf16/no-bf16 hardware without a GPU.
    """

    def __init__(self, available: bool, bf16_supported: bool) -> None:
        # Canned answers returned by the two query methods below.
        self._available, self._bf16_supported = available, bf16_supported

    def is_bf16_supported(self) -> bool:
        """Report whether the fake device claims bfloat16 support."""
        supported = self._bf16_supported
        return supported

    def is_available(self) -> bool:
        """Report whether the fake CUDA runtime claims to be available."""
        available = self._available
        return available


class FakeTorch:
    """Minimal fake ``torch`` module for precision-policy tests.

    Dtypes are represented by their names as plain strings, and ``cuda``
    is a ``FakeCuda`` built from the constructor arguments.
    """

    # Dtype sentinels: string names stand in for real torch dtype objects.
    float32 = "float32"
    float16 = "float16"
    bfloat16 = "bfloat16"

    def __init__(self, available: bool, bf16_supported: bool) -> None:
        fake_cuda = FakeCuda(available=available, bf16_supported=bf16_supported)
        self.cuda = fake_cuda


def main() -> None:
    """Run smoke checks over training presets, precision policy, timing and reward.

    Exercises ``build_training_config`` for the ``"l4"``, ``"smoke"`` and
    ``"overnight"`` presets, ``resolve_precision_policy`` against fake CUDA
    backends, ``build_timing_summary`` and ``compute_episode_reward``, then
    prints a success line.

    Raises:
        AssertionError: if any check fails.
        RuntimeError: if Python runs with assertions disabled (``-O``), since
            every check below is an ``assert`` and would silently pass.
    """
    import math

    # Under ``python -O`` all asserts are stripped, so without this guard the
    # script would report success without having checked anything.
    if not __debug__:
        raise RuntimeError(
            "assertions are disabled (python -O); smoke tests cannot run"
        )

    l4_config = build_training_config("l4")
    smoke_config = build_training_config("smoke")
    overnight_config = build_training_config("overnight")

    # --- Preset contents -------------------------------------------------
    assert l4_config.model_name == "Qwen/Qwen2.5-3B-Instruct"
    assert l4_config.load_in_4bit is True
    assert l4_config.gradient_checkpointing is True
    assert l4_config.num_generations == 4

    assert smoke_config.load_in_4bit is False
    assert smoke_config.gradient_checkpointing is False
    assert smoke_config.upload_checkpoints_to_hub is False

    assert overnight_config.load_in_4bit is True
    assert overnight_config.gradient_checkpointing is True
    assert overnight_config.num_generations == 4
    assert overnight_config.max_steps == 950
    assert overnight_config.save_steps == 50
    assert overnight_config.save_total_limit == 3
    assert overnight_config.upload_checkpoints_to_hub is True

    # --- Precision policy: bf16 if supported, fp16 on other CUDA, fp32 on CPU
    bf16_policy = resolve_precision_policy(l4_config, FakeTorch(available=True, bf16_supported=True))
    assert bf16_policy["precision_mode"] == "bf16"
    assert bf16_policy["load_in_4bit"] is True

    fp16_policy = resolve_precision_policy(l4_config, FakeTorch(available=True, bf16_supported=False))
    assert fp16_policy["precision_mode"] == "fp16"
    assert fp16_policy["load_in_4bit"] is True

    cpu_policy = resolve_precision_policy(smoke_config, FakeTorch(available=False, bf16_supported=False))
    assert cpu_policy["precision_mode"] == "fp32"
    assert cpu_policy["load_in_4bit"] is False

    # --- Timing summary: 180 s over 6 steps and 12 episodes ------------------
    # Computed floats are compared with a tolerance rather than ``==`` so the
    # checks do not hinge on the exact order of floating-point operations.
    timing_summary = build_timing_summary(
        config=smoke_config,
        wall_clock_seconds=180.0,
        completed_steps=6,
        train_episode_count=12,
    )
    assert math.isclose(timing_summary["wall_clock_minutes"], 3.0), timing_summary
    assert math.isclose(timing_summary["avg_seconds_per_step"], 30.0), timing_summary
    assert math.isclose(timing_summary["episodes_per_hour"], 240.0), timing_summary

    # --- Episode reward for a first-step jump from 0.0 to 1.0 pass rate -------
    reward, components = compute_episode_reward(
        pass_rate=1.0,
        step_number=1,
        execution_status="completed",
        previous_pass_rate=0.0,
        done=False,
        efficiency_score=0.94,
        optimization_target_met=False,
    )
    assert math.isclose(reward, 0.94), (reward, components)
    assert math.isclose(components["progress_delta"], 1.0), components
    print("Training config smoke tests passed")


# Run the smoke checks only when executed as a script, not when imported.
if __name__ == "__main__":
    main()