File size: 9,902 Bytes
ac224ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af29724
ac224ce
 
 
af29724
 
ac224ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
"""
tests/test_stdout_format.py
---------------------------
Tests for the mandatory stdout logging format defined in inference.py.

The evaluation harness parses [START], [STEP], and [END] lines from stdout.
Any deviation in field names, ordering, or formatting causes incorrect scoring.

These tests parse the actual output of the logging functions and verify:
- Field names match exactly.
- Fields appear in the required order.
- Numeric formatting is correct (reward: 2dp, score: 3dp).
- Boolean values are lowercase.
- No newlines appear within a log line.
"""

import re
import sys
import io
import pytest

# inference.py lives in the project root; it is loaded below by file path
# via importlib, so no sys.path manipulation is needed.
import importlib.util
from pathlib import Path

# Import the logging functions directly from inference.py by file path.
# Follows the importlib "importing a source file directly" recipe: the module
# is registered in sys.modules *before* it is executed, so code inside
# inference.py that looks itself up by module name (e.g. dataclasses with
# string annotations, pickle) behaves as it would under a normal import.
_INFERENCE_PATH = Path(__file__).parent.parent / "inference.py"
_spec = importlib.util.spec_from_file_location("inference", _INFERENCE_PATH)
if _spec is None or _spec.loader is None:
    # Fail with a clear message instead of an AttributeError on None.
    raise ImportError(f"Cannot build an import spec for {_INFERENCE_PATH}")
_inf = importlib.util.module_from_spec(_spec)
sys.modules[_spec.name] = _inf
_spec.loader.exec_module(_inf)

log_start = _inf.log_start
log_step = _inf.log_step
log_end = _inf.log_end


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _capture_stdout(fn, *args, **kwargs) -> str:
    """Capture what fn() prints to stdout and return it as a string."""
    buf = io.StringIO()
    old_stdout = sys.stdout
    sys.stdout = buf
    try:
        fn(*args, **kwargs)
    finally:
        sys.stdout = old_stdout
    return buf.getvalue().rstrip("\n")


def _parse_fields(line: str) -> dict:
    """
    Parse a log line like '[TAG] key=val key2=val2 ...' into a dict.
    Values can contain letters, digits, underscores, dots, commas, hyphens,
    slashes, colons, and curly braces — everything up to the next space=word pair.
    """
    # Strip tag
    line = re.sub(r"^\[.*?\]\s*", "", line)
    # Tokenise: split on whitespace-separated key=value pairs
    pattern = re.compile(r"(\w+)=(\S+)")
    return {m.group(1): m.group(2) for m in pattern.finditer(line)}


# ---------------------------------------------------------------------------
# [START] line
# ---------------------------------------------------------------------------

class TestStartLine:
    """Format checks for the [START] line produced by ``log_start``."""

    def _start_line(self, task="task_1", env="rag_debug_env", model="Qwen/Qwen2.5-72B"):
        """Capture one [START] line; mirrors TestStepLine/TestEndLine helpers."""
        return _capture_stdout(log_start, task=task, env=env, model=model)

    def test_prefix(self):
        line = self._start_line()
        assert line.startswith("[START]"), f"Expected [START] prefix, got: {line!r}"

    def test_field_names(self):
        fields = _parse_fields(self._start_line())
        assert "task" in fields, "Missing 'task' field"
        assert "env" in fields, "Missing 'env' field"
        assert "model" in fields, "Missing 'model' field"

    def test_field_order(self):
        line = self._start_line()
        # Verify positional ordering via index search
        assert line.index("task=") < line.index("env=") < line.index("model="), (
            "Fields must appear in order: task, env, model"
        )

    def test_field_values(self):
        fields = _parse_fields(self._start_line())
        assert fields["task"] == "task_1"
        assert fields["env"] == "rag_debug_env"
        # Previously unchecked: the model name must be echoed verbatim.
        assert fields["model"] == "Qwen/Qwen2.5-72B"

    def test_single_line(self):
        line = self._start_line()
        assert "\n" not in line, "Log line must not contain internal newlines"


# ---------------------------------------------------------------------------
# [STEP] line
# ---------------------------------------------------------------------------

class TestStepLine:
    """Format checks for the [STEP] line produced by ``log_step``."""

    def _step_line(self, step=1, action="submit()", reward=0.0, done=False, error=None):
        """Capture one [STEP] line; every argument has a harmless default."""
        payload = {
            "step": step,
            "action": action,
            "reward": reward,
            "done": done,
            "error": error,
        }
        return _capture_stdout(log_step, **payload)

    def test_prefix(self):
        line = self._step_line()
        assert line.startswith("[STEP]")

    def test_field_names(self):
        parsed = _parse_fields(self._step_line())
        for name in ("step", "action", "reward", "done", "error"):
            assert name in parsed, f"Missing '{name}' field in [STEP] line"

    def test_field_order(self):
        line = self._step_line()
        expected = ["step", "action", "reward", "done", "error"]
        # str.index raises ValueError when a field is absent, failing the test.
        offsets = {name: line.index(f"{name}=") for name in expected}
        ordered = sorted(offsets, key=offsets.get)
        assert ordered == expected, (
            f"Field order wrong: {ordered}"
        )

    def test_reward_two_decimal_places(self):
        line = self._step_line(reward=0.123456)
        match = re.search(r"reward=(\d+\.\d+)", line)
        assert match is not None, "reward field not found"
        rendered = match.group(1)
        # The regex guarantees exactly one '.', so partition isolates decimals.
        _, _, decimals = rendered.partition(".")
        assert len(decimals) == 2, (
            f"reward should have 2 decimal places, got: {rendered!r}"
        )

    def test_reward_exact_format(self):
        line = self._step_line(reward=0.5)
        assert "reward=0.50" in line, f"Expected reward=0.50, got: {line!r}"

    def test_done_false_lowercase(self):
        assert "done=false" in self._step_line(done=False), (
            "done=False must be serialized as 'false'"
        )

    def test_done_true_lowercase(self):
        assert "done=true" in self._step_line(done=True), (
            "done=True must be serialized as 'true'"
        )

    def test_error_null_when_none(self):
        assert "error=null" in self._step_line(error=None), (
            "null error must be serialized as 'null'"
        )

    def test_error_string_when_present(self):
        line = self._step_line(error="Invalid value")
        assert "error=" in line
        # First whitespace-delimited token after the first "error=".
        first_token = line.split("error=")[1].split()[0]
        assert "null" not in first_token, (
            "error field should contain the message, not 'null'"
        )

    def test_single_line(self):
        line = self._step_line(action="adjust_threshold(value=0.15)")
        assert "\n" not in line

    @pytest.mark.parametrize("reward", [0.0, 0.5, 1.0, 0.123])
    def test_various_rewards_in_range(self, reward):
        match = re.search(r"reward=(\d+\.\d{2})", self._step_line(reward=reward))
        assert match is not None, f"reward field missing or wrong format for reward={reward}"
        assert 0.0 <= float(match.group(1)) <= 1.0


# ---------------------------------------------------------------------------
# [END] line
# ---------------------------------------------------------------------------

class TestEndLine:
    """Format checks for the [END] line produced by ``log_end``."""

    def _end_line(self, success=True, steps=5, score=0.85, rewards=None):
        """Capture one [END] line.

        ``rewards=None`` selects a default five-step trajectory; any list,
        including an empty one, is passed through unchanged.
        """
        # `is None`, not `or`: an explicit empty list is falsy and must not be
        # swapped for the default, or the empty-rewards edge case is never hit.
        if rewards is None:
            rewards = [0.3, 0.5, 0.7, 0.85, 0.90]
        return _capture_stdout(log_end, success=success, steps=steps,
                               score=score, rewards=rewards)

    def test_prefix(self):
        assert self._end_line().startswith("[END]")

    def test_field_names(self):
        line = self._end_line()
        fields = _parse_fields(line)
        for field in ("success", "steps", "score", "rewards"):
            assert field in fields, f"Missing '{field}' field in [END] line"

    def test_field_order(self):
        line = self._end_line()
        positions = {
            "success": line.index("success="),
            "steps": line.index("steps="),
            "score": line.index("score="),
            "rewards": line.index("rewards="),
        }
        ordered = sorted(positions, key=positions.get)
        assert ordered == ["success", "steps", "score", "rewards"], (
            f"Field order wrong: {ordered}"
        )

    def test_success_true_lowercase(self):
        line = self._end_line(success=True)
        assert "success=true" in line

    def test_success_false_lowercase(self):
        line = self._end_line(success=False)
        assert "success=false" in line

    def test_score_two_decimal_places(self):
        # NOTE(review): the module docstring says score is formatted to 3
        # decimal places, but this test requires 2. Confirm against
        # inference.log_end and the harness spec, then reconcile one of them.
        line = self._end_line(score=0.85)
        m = re.search(r"score=(\d+\.\d+)", line)
        assert m is not None, "score field not found"
        assert len(m.group(1).split(".")[1]) == 2, (
            f"score should have 2 decimal places, got: {m.group(1)!r}"
        )

    def test_rewards_comma_separated(self):
        line = self._end_line(rewards=[0.3, 0.5, 0.7])
        m = re.search(r"rewards=(\S+)", line)
        assert m is not None, "rewards field not found"
        parts = m.group(1).split(",")
        assert len(parts) == 3

    def test_rewards_two_decimal_places(self):
        line = self._end_line(rewards=[0.123, 0.456])
        m = re.search(r"rewards=(\S+)", line)
        assert m is not None
        for part in m.group(1).split(","):
            decimal_part = part.split(".")[1] if "." in part else ""
            assert len(decimal_part) == 2, (
                f"Each reward in rewards should have 2 decimal places, got: {part!r}"
            )

    def test_score_in_unit_interval(self):
        for score in [0.0, 0.5, 1.0]:
            line = self._end_line(score=score)
            m = re.search(r"score=(\d+\.\d+)", line)
            assert m is not None
            assert 0.0 <= float(m.group(1)) <= 1.0

    def test_single_line(self):
        line = self._end_line()
        assert "\n" not in line

    def test_empty_rewards_list(self):
        """Edge case: no steps taken should produce empty rewards."""
        line = self._end_line(steps=0, score=0.0, rewards=[])
        # \S* (not \S+) so an empty value still matches.
        m = re.search(r"rewards=(\S*)", line)
        assert m is not None, "rewards field not found"
        assert m.group(1) == "", (
            f"Empty rewards list should serialize as 'rewards=', got: {line!r}"
        )