File size: 3,970 Bytes
6e7ed91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from __future__ import annotations

import json
import math
import sys
from pathlib import Path
from tempfile import TemporaryDirectory

# Repository root: the parent of the directory that contains this script.
# It is prepended to ``sys.path`` (only if not already present) so that the
# project-local ``training`` package imported below resolves when this file
# is executed directly as a script.
ROOT = Path(__file__).resolve().parent.parent
if all(entry != str(ROOT) for entry in sys.path):
    sys.path.insert(0, str(ROOT))

from training.trace_logging import TraceArtifactLogger


def _expect(condition: bool, message: str) -> None:
    """Raise ``AssertionError`` with *message* when *condition* is false.

    Unlike a bare ``assert`` statement, this check is not stripped when Python
    runs with ``-O``, so the smoke test cannot silently report success.
    """
    if not condition:
        raise AssertionError(message)


def main() -> None:
    """Smoke-test ``TraceArtifactLogger`` end to end inside a temporary directory.

    The test expects that:

    * constructing the logger writes ``logs/run_manifest.json`` containing the
      run id and training config;
    * ``log_event`` persists the event to the events file returned by
      ``artifact_paths()`` without embedding the training config;
    * ``record_progress`` at ``completed_steps == checkpoint_interval_steps``
      leaves a latest-checkpoint file with the step and rolling reward metrics,
      again without the training config;
    * ``finalize`` returns a path-like value pointing at a summary JSON file
      that contains the supplied ``final_metrics``.

    Raises:
        AssertionError: if any expectation is not met.
    """
    with TemporaryDirectory() as tmpdir:
        output_dir = Path(tmpdir)
        logger = TraceArtifactLogger(
            run_id="run-123",
            output_dir=output_dir,
            training_config={"max_steps": 6, "model_name": "demo-model"},
            model_identifiers={"model_name": "demo-model", "generator_mode": "reward_aware"},
            system_prompt="You are the Solver Agent.",
            checkpoint_interval_steps=2,
        )

        # The manifest is expected to exist right after construction, before
        # any events are logged.
        manifest_path = output_dir / "logs" / "run_manifest.json"
        _expect(manifest_path.exists(), f"missing run manifest: {manifest_path}")
        manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
        _expect(
            manifest["run_id"] == "run-123",
            f"unexpected manifest run_id: {manifest['run_id']!r}",
        )
        _expect(
            manifest["training_config"]["max_steps"] == 6,
            "manifest training_config.max_steps should be 6",
        )

        logger.log_event(
            {
                "phase": "train",
                "step": 0,
                "train_episode_index": 1,
                "problem_id": "sum_even_numbers_1",
                "problem_family": "sum_even_numbers",
                "difficulty": "easy",
                "teacher_prompt": "Problem: sum the even numbers",
                "solver_completion": "print(sum(x for x in nums if x % 2 == 0))",
                "extracted_code": "print(sum(x for x in nums if x % 2 == 0))",
                "reward": 0.94,
                "pass_rate": 1.0,
                "visible_pass_rate": 1.0,
                "execution_status": "completed",
                "efficiency_score": 0.94,
                "optimization_hints": ["Avoid materializing temporary containers."],
                "feedback": "All hidden tests passed, but the solution can still be optimized further.",
            }
        )
        # completed_steps equals checkpoint_interval_steps (2), so a checkpoint
        # is expected to be written by this call.
        logger.record_progress(
            {
                "phase": "train",
                "completed_steps": 2,
                "total_steps": 6,
                "remaining_steps": 4,
                "progress_ratio": 0.3333,
                "current_epoch": 2.0,
                "current_difficulty": "easy",
                "curriculum_level": 1,
                "train_episode_index": 1,
                "last_problem_id": "sum_even_numbers_1",
                "last_problem_family": "sum_even_numbers",
                "last_execution_status": "completed",
            }
        )
        artifact_paths = logger.artifact_paths()

        events_path = Path(artifact_paths["events_path"])
        latest_checkpoint_path = Path(artifact_paths["latest_checkpoint_path"])
        _expect(events_path.exists(), f"missing events file: {events_path}")
        _expect(
            latest_checkpoint_path.exists(),
            f"missing latest checkpoint: {latest_checkpoint_path}",
        )

        # Guard before indexing so an empty file fails with a clear message
        # instead of an IndexError.
        event_lines = events_path.read_text(encoding="utf-8").strip().splitlines()
        _expect(bool(event_lines), f"events file is empty: {events_path}")
        event = json.loads(event_lines[0])
        _expect(
            event["problem_id"] == "sum_even_numbers_1",
            f"unexpected event problem_id: {event['problem_id']!r}",
        )
        _expect(
            event["teacher_prompt"] == "Problem: sum the even numbers",
            f"unexpected event teacher_prompt: {event['teacher_prompt']!r}",
        )
        # Run-level config lives in the manifest; events must not repeat it.
        _expect("training_config" not in event, "event record should not embed training_config")

        checkpoint = json.loads(latest_checkpoint_path.read_text(encoding="utf-8"))
        _expect(checkpoint["step"] == 2, f"unexpected checkpoint step: {checkpoint['step']!r}")
        avg_reward = checkpoint["rolling_metrics"]["avg_reward"]
        # The rolling average is a computed float; compare with a tolerance
        # rather than exact equality.
        _expect(
            math.isclose(avg_reward, 0.94),
            f"unexpected rolling avg_reward: {avg_reward!r}",
        )
        _expect(
            "training_config" not in checkpoint,
            "checkpoint should not embed training_config",
        )

        reward_curve = output_dir / "reward_curve.csv"
        reward_curve.write_text("step,episode_reward\n0,0.94\n", encoding="utf-8")
        finalize_result = logger.finalize(
            reward_curve_csv=reward_curve,
            final_metrics={"completed_steps": 6},
        )
        # NOTE(review): finalize()'s result is consumed as a single path-like
        # value here; confirm its return type against TraceArtifactLogger.
        summary_path = Path(finalize_result)
        _expect(summary_path.exists(), f"missing run summary: {summary_path}")
        summary = json.loads(summary_path.read_text(encoding="utf-8"))
        _expect(
            summary["final_metrics"]["completed_steps"] == 6,
            "summary final_metrics.completed_steps should be 6",
        )

    print("Trace logging smoke tests passed")


if __name__ == "__main__":
    # Run the smoke test only when executed as a script; importing this
    # module has no side effects beyond the sys.path setup above.
    main()