File size: 7,258 Bytes
c338ce7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
"""Tests for the inference script's parsing, prompt building, and log format."""

import pytest
import sys
import os

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from inference import parse_llm_response, parse_fix_response, build_user_prompt, log_start, log_step, log_end


class TestParseLLMResponse:
    """Behaviour of ``parse_llm_response`` across the output shapes an LLM produces."""

    def test_standard_format(self):
        lines = [
            "row:1,col:name,issue:missing_value",
            "row:2,col:salary,issue:wrong_type",
        ]
        parsed = parse_llm_response("\n".join(lines))
        assert len(parsed) == 2
        assert lines[0] in parsed

    def test_numbered_list(self):
        text = (
            "1. row:1,col:name,issue:missing_value\n"
            "2. row:2,col:salary,issue:wrong_type"
        )
        parsed = parse_llm_response(text)
        assert len(parsed) == 2

    def test_bullet_list(self):
        # "-" and "*" bullets are mixed on purpose.
        text = (
            "- row:1,col:name,issue:missing_value\n"
            "* row:2,col:salary,issue:wrong_type"
        )
        parsed = parse_llm_response(text)
        assert len(parsed) == 2

    def test_equals_delimiter(self):
        parsed = parse_llm_response("row=1,col=name,issue=missing_value")
        assert len(parsed) == 1
        # "=" separators are expected to come back normalised to ":".
        assert parsed[0] == "row:1,col:name,issue:missing_value"

    def test_mixed_case(self):
        parsed = parse_llm_response("Row:1,Col:Name,Issue:Missing_Value")
        assert len(parsed) == 1
        # Expected output is fully lower-cased.
        assert parsed[0] == "row:1,col:name,issue:missing_value"

    def test_empty_response(self):
        for blank in ("", "   "):
            assert parse_llm_response(blank) == []

    def test_garbage_lines_skipped(self):
        text = "\n".join([
            "Here are the issues:",
            "row:1,col:name,issue:missing_value",
            "No more issues.",
        ])
        parsed = parse_llm_response(text)
        assert len(parsed) == 1

    def test_deduplication_not_applied(self):
        line = "row:1,col:name,issue:missing_value"
        parsed = parse_llm_response("\n".join([line, line]))
        # Repeated lines are kept; the parser does not de-duplicate.
        assert len(parsed) == 2

    def test_with_column_variant(self):
        parsed = parse_llm_response("row:1,column:name,issue:missing_value")
        assert len(parsed) == 1


class TestParseFixResponse:
    """Behaviour of ``parse_fix_response`` on ``row:..,col:..,fix:..`` lines."""

    def test_standard_format(self):
        lines = [
            "row:4,col:name,fix:David Kim",
            "row:7,col:salary,fix:75000",
        ]
        parsed = parse_fix_response("\n".join(lines))
        assert len(parsed) == 2
        assert lines[0] in parsed

    def test_numbered_list(self):
        text = (
            "1. row:4,col:name,fix:David Kim\n"
            "2. row:7,col:salary,fix:75000"
        )
        parsed = parse_fix_response(text)
        assert len(parsed) == 2

    def test_with_special_chars(self):
        email = "alice.chen@company.com"
        parsed = parse_fix_response("row:1,col:email,fix:" + email)
        assert len(parsed) == 1
        assert email in parsed[0]

    def test_empty_response(self):
        parsed = parse_fix_response("")
        assert parsed == []

    def test_date_fix(self):
        parsed = parse_fix_response("row:12,col:order_date,fix:2024-01-26")
        assert len(parsed) == 1

    def test_ignores_issue_lines(self):
        text = "\n".join([
            "row:4,col:name,issue:missing_value",
            "row:4,col:name,fix:David Kim",
        ])
        parsed = parse_fix_response(text)
        # Only the "fix:" line should survive; the "issue:" line is dropped.
        assert len(parsed) == 1


class TestBuildUserPrompt:
    """Checks on the prompt text that ``build_user_prompt`` assembles from an observation."""

    @staticmethod
    def _obs(**overrides):
        """Return an observation dict with blank/zero fields, updated with *overrides*."""
        observation = {
            "task_description": "",
            "schema_description": "",
            "validation_rules": "",
            "dataset_csv": "",
            "num_issues_hint": 0,
            "feedback": "",
        }
        observation.update(overrides)
        return observation

    def test_includes_all_fields(self):
        observation = self._obs(
            task_description="Find issues",
            schema_description="col: int",
            validation_rules="no nulls",
            dataset_csv="a,b\n1,2",
            num_issues_hint=3,
        )
        prompt = build_user_prompt(observation)
        # Each observation field, plus the rendered issue-count hint, must appear.
        for fragment in ("Find issues", "col: int", "no nulls", "a,b", "3 issues"):
            assert fragment in prompt

    def test_includes_feedback_on_retry(self):
        observation = self._obs(
            task_description="Find issues",
            dataset_csv="a\n1",
            feedback="Step 1/3: You missed 2 issues",
        )
        prompt = build_user_prompt(observation)
        assert "FEEDBACK" in prompt
        assert "missed 2" in prompt

    def test_excludes_reset_feedback(self):
        # The reset message is expected NOT to produce a FEEDBACK section.
        observation = self._obs(feedback="Environment reset. Start inspecting.")
        prompt = build_user_prompt(observation)
        assert "FEEDBACK" not in prompt

    def test_include_fixes_flag(self):
        observation = self._obs(task_description="Find issues", dataset_csv="a\n1")
        prompt = build_user_prompt(observation, include_fixes=True)
        assert "fix" in prompt.lower()


class TestLogFormat:
    """Pin the exact stdout line format expected by the hackathon evaluator."""

    @staticmethod
    def _captured_line(capsys):
        """Return stdout written since the last capture, with surrounding whitespace stripped."""
        return capsys.readouterr().out.strip()

    def test_log_start_format(self, capsys):
        log_start(task="easy", env="dataqa_env", model="test-model")
        expected = "[START] task=easy env=dataqa_env model=test-model"
        assert self._captured_line(capsys) == expected

    def test_log_step_format(self, capsys):
        log_step(
            step=1,
            action="row:1,col:name,issue:missing_value",
            reward=0.50,
            done=False,
            error=None,
        )
        expected = (
            "[STEP] step=1 action=row:1,col:name,issue:missing_value "
            "reward=0.50 done=false error=null"
        )
        assert self._captured_line(capsys) == expected

    def test_log_step_with_error(self, capsys):
        log_step(step=2, action="none", reward=0.00, done=True, error="timeout")
        line = self._captured_line(capsys)
        for token in ("error=timeout", "done=true"):
            assert token in line

    def test_log_end_format(self, capsys):
        log_end(success=True, steps=3, score=0.85, rewards=[0.25, 0.50, 0.85])
        expected = "[END] success=true steps=3 score=0.850 rewards=0.25,0.50,0.85"
        assert self._captured_line(capsys) == expected

    def test_log_end_failure(self, capsys):
        log_end(success=False, steps=1, score=0.0, rewards=[0.0])
        line = self._captured_line(capsys)
        for token in ("success=false", "score=0.000"):
            assert token in line

    def test_reward_format_2_decimal(self, capsys):
        log_step(step=1, action="test", reward=0.123456, done=False, error=None)
        # The reward is expected to be rendered with two decimals.
        assert "reward=0.12" in self._captured_line(capsys)

    def test_no_newlines_within_line(self, capsys):
        log_start(task="easy", env="dataqa_env", model="model")
        log_step(step=1, action="act", reward=0.0, done=False, error=None)
        log_end(success=False, steps=1, score=0.0, rewards=[0.0])
        # Use the raw (unstripped) stdout so each line's prefix is checked exactly.
        raw = capsys.readouterr().out
        non_blank = [text for text in raw.split("\n") if text.strip()]
        assert len(non_blank) == 3
        for text, tag in zip(non_blank, ("[START]", "[STEP]", "[END]")):
            assert text.startswith(tag)