File size: 5,333 Bytes
8541221
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
#!/usr/bin/env python3
"""
Minimal script to check if tasks in solved_tasks.jsonl were fully completed and verified.
Uses an LLM to assess completion status and adds the result to each row.
"""

import argparse
import json
import sys
from concurrent.futures import ThreadPoolExecutor, as_completed

import litellm
from dotenv import load_dotenv
from pydantic import BaseModel

# Load environment variables from a local .env file — presumably the provider
# API key(s) litellm needs; verify against deployment setup.
load_dotenv()


class CompletionCheck(BaseModel):
    """Structured verdict returned by the judge model for one task row."""
    # Brief explanation of the verdict; declared first — presumably so the
    # model produces its reasoning before committing to the booleans.
    reasoning: str
    # Did the agent actually perform the task (not just plan/explain it)?
    completed: bool
    # Did the agent confirm the task succeeded (checked output/results)?
    verified: bool


# Judge prompt template. Placeholders {question}, {solution}, {trace} are
# filled via str.format in check_row. Fix: "succesfully" -> "successfully".
PROMPT = """You are evaluating whether an AI agent fully completed a task AND verified its completion.

Task: {question}

Agent's final answer: {solution}

Agent's trace (tool calls and responses):
{trace}

Evaluate:
1. **completed**: Did the agent actually complete the task? (not just explain what could be done, but actually do it)
2. **verified**: Did the agent verify/confirm that the task was completed correctly? (e.g., checked output, validated results, confirmed success)

Be strict:
- If the agent asked for more information or said "please provide...", it's NOT completed.
- If the agent only explained how to do something but didn't do it, it's NOT completed.
- If the agent just made a plan of how to complete it but didn't do it, it's NOT completed.
- If there's an error in the trace and no recovery, it's NOT completed.
- If the agent didn't check/confirm the code/command completed successfully or the result is correct somehow, it's NOT verified.

Return JSON with: completed (bool), verified (bool), reasoning (brief explanation)."""


def format_trace(messages: list) -> str:
    """Render a message trace as plain text for the judge prompt.

    System messages are skipped. Each tool call becomes a "[TOOL CALL] name"
    line and each non-empty message body becomes a "[ROLE] content" line.
    Very long content keeps only its head and tail.
    """
    if not messages:
        return "(No trace)"

    lines: list[str] = []
    for message in messages:
        role = message.get("role", "unknown")
        if role == "system":
            continue

        # Tool-call lines come before the message body, matching trace order.
        for call in message.get("tool_calls") or []:
            if isinstance(call, dict) and "function" in call:
                lines.append(f"[TOOL CALL] {call['function'].get('name', '?')}")

        text = message.get("content", "")
        if not text:
            continue
        if len(text) > 5000:
            # Drop the middle of oversized content, keeping head and tail.
            text = text[:4000] + "..." + text[-1000:]
        lines.append(f"[{role.upper()}] {text}")

    return "\n".join(lines) if lines else "(Empty trace)"


def check_row(row: dict, model: str) -> CompletionCheck | None:
    """Ask the judge model whether one task row was completed and verified.

    Returns the parsed CompletionCheck, or None when the API call or the
    response parsing fails (the error is reported on stderr).
    """
    filled_prompt = PROMPT.format(
        question=row["question"],
        solution=row.get("solution", "(No solution)"),
        trace=format_trace(row.get("messages", [])),
    )

    # Both the call and the parse stay inside the try: any failure along the
    # way is treated as "no verdict" for this row.
    try:
        raw = litellm.completion(
            model=model,
            messages=[{"role": "user", "content": filled_prompt}],
            response_format=CompletionCheck,
            timeout=60,
        )
        payload = raw.choices[0].message.content
        return CompletionCheck.model_validate_json(payload)
    except Exception as e:
        print(f"Error: {e}", file=sys.stderr)
        return None


def _load_jsonl(path: str) -> list[dict]:
    """Read a JSONL file into a list of dicts, skipping blank lines.

    Blank/whitespace-only lines (e.g. a trailing newline) previously crashed
    json.loads; they are now ignored.
    """
    rows = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            if line.strip():
                rows.append(json.loads(line))
    return rows


def _write_jsonl(path: str, rows: list[dict]) -> None:
    """Write rows to a JSONL file, one JSON object per line."""
    with open(path, "w", encoding="utf-8") as f:
        for row in rows:
            # default=str keeps serialization from crashing on non-JSON-native
            # values (e.g. datetimes) that may appear in the rows.
            f.write(json.dumps(row, default=str) + "\n")


def main():
    """Judge every row of the input JSONL with an LLM, write the annotated
    rows to the output file, and print a completed/verified summary."""
    parser = argparse.ArgumentParser(description="Check task completion status")
    parser.add_argument("--infile", type=str, default="eval/solved_tasks.jsonl")
    parser.add_argument(
        "--outfile", type=str, default="eval/solved_tasks_checked.jsonl"
    )
    parser.add_argument(
        "--model", type=str, default="anthropic/claude-sonnet-4-5-20250929"
    )
    parser.add_argument("--max-concurrent", type=int, default=30)
    args = parser.parse_args()

    # Load data
    print(f"Loading {args.infile}...")
    rows = _load_jsonl(args.infile)
    print(f"Loaded {len(rows)} rows")

    # Fan the per-row LLM calls out across threads (I/O-bound work).
    print(f"Checking completion with {args.model}...")
    results: list[CompletionCheck | None] = [None] * len(rows)
    with ThreadPoolExecutor(max_workers=args.max_concurrent) as executor:
        futures = {
            executor.submit(check_row, row, args.model): i for i, row in enumerate(rows)
        }

        # Fix: count completed futures directly. Counting non-None results
        # undercounted progress, because check_row returns None on error.
        done = 0
        for future in as_completed(futures):
            results[futures[future]] = future.result()
            done += 1
            print(f"Done: {done}/{len(rows)}", end="\r")

    print()

    # Annotate each row in place with the verdict; None marks a failed check.
    for row, result in zip(rows, results):
        row["task_completed"] = result.completed if result else None
        row["task_verified"] = result.verified if result else None
        row["completion_reasoning"] = (
            result.reasoning if result else "Error during check"
        )

    # Write output
    print(f"Writing to {args.outfile}...")
    _write_jsonl(args.outfile, rows)

    # Summary
    completed = sum(1 for r in results if r and r.completed)
    verified = sum(1 for r in results if r and r.verified)
    print("\nSummary:")
    print(f"  Completed: {completed}/{len(rows)}")
    print(f"  Verified: {verified}/{len(rows)}")


if __name__ == "__main__":
    main()