File size: 2,787 Bytes
8e72e1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""
Test-generation agent.

A minimal tool-calling loop (no LangGraph). The LLM is given two tools
(search_code, get_definition) and asked to generate pytest tests for a target
function. It decides which tools to call to gather the function's real source
and its dependencies, then writes the tests grounded in that actual code.

This is the "agent" capability: the model plans and acts via tools, rather than
answering in a single shot.
"""
import json
import os

from dotenv import load_dotenv
from openai import OpenAI

from src.agent.tools import TOOL_SCHEMAS

# Run python-dotenv at import time so that variables from a local .env file
# (if one exists) are available in the environment before TestAgent reads
# OPENAI_API_KEY via os.getenv().
load_dotenv()

# System message placed first in the conversation on every generate_tests()
# call. It names the two tools, prescribes the "read the real source first,
# then write tests" workflow, and constrains the final answer to a single
# Python code block.
SYSTEM_PROMPT = """You are a senior Python engineer that writes pytest unit tests.

You have tools to explore a real codebase:
- search_code(query): find relevant code
- get_definition(name): fetch the full source of a function/class

Workflow:
1. Use get_definition (and search_code if needed) to read the ACTUAL source of
   the target function and anything it depends on. Never guess at the code.
2. Then write focused pytest tests covering the main behavior and edge cases.

Return ONLY the final pytest code in a single Python code block. Base the tests
strictly on the real code you retrieved.
"""

# Upper bound on chat-completion round-trips per generate_tests() call. If the
# model is still requesting tools after this many rounds, generate_tests()
# returns a fixed fallback message instead of tests.
MAX_TOOL_ROUNDS = 5


class TestAgent:
    """Tool-calling agent that asks an LLM to write pytest tests for a target.

    The agent runs a bounded chat loop: each round the model either requests
    tool calls (which are executed through ``tools.dispatch`` and fed back as
    ``tool`` messages) or produces its final answer, which is returned as-is.
    """

    # The class name starts with "Test", so pytest would try to collect it
    # (and warn about its __init__) whenever it is imported into a test
    # module. Opt out explicitly.
    __test__ = False

    # Maximum number of characters of a serialized tool result that is sent
    # back to the model; keeps the tool payload bounded.
    MAX_TOOL_OUTPUT_CHARS = 6000

    def __init__(self, tools, model="gpt-4.1-mini"):
        """Create an agent.

        Args:
            tools: Object exposing ``dispatch(name, args)`` (a CodeTools
                instance) used to execute the model's tool calls.
            model: Name of the chat-completions model to use.
        """
        self.tools = tools  # a CodeTools instance
        self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        self.model = model

    def generate_tests(self, target):
        """Generate pytest tests for ``target``.

        Args:
            target: A function/class name, e.g. ``'create_access_token'``.

        Returns:
            The model's final message content as a string (an empty string if
            the model returned no content), or a fixed fallback message when
            the model is still requesting tools after ``MAX_TOOL_ROUNDS``
            rounds.
        """
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": f"Generate pytest tests for `{target}`."},
        ]

        for _ in range(MAX_TOOL_ROUNDS):
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                tools=TOOL_SCHEMAS,
                temperature=0,
            )
            msg = response.choices[0].message

            # No tool calls -> the model produced its final answer (the tests).
            # Content may be None; normalise so callers always get a str.
            if not msg.tool_calls:
                return msg.content or ""

            # Otherwise, run each requested tool and feed results back. Every
            # tool call must be answered with a matching tool message, even
            # when its arguments were unusable.
            messages.append(msg)
            for call in msg.tool_calls:
                messages.append({
                    "role": "tool",
                    "tool_call_id": call.id,
                    "content": self._run_tool_call(call),
                })

        # Safety net if it never stopped calling tools.
        return "Could not finish generating tests within the tool-call limit."

    def _run_tool_call(self, call):
        """Execute one model-requested tool call and return its text payload.

        Arguments are produced by the model and may be malformed. Instead of
        aborting the whole run, a parse problem is reported back to the model
        as the tool result so it can retry with valid arguments. Exceptions
        raised by ``tools.dispatch`` itself are not intercepted.
        """
        try:
            args = json.loads(call.function.arguments)
        except (TypeError, ValueError) as exc:
            # ValueError covers json.JSONDecodeError; TypeError covers None.
            result = {"error": f"Tool arguments were not valid JSON: {exc}"}
        else:
            if isinstance(args, dict):
                result = self.tools.dispatch(call.function.name, args)
            else:
                result = {"error": "Tool arguments must be a JSON object."}
        return self._encode_result(result)

    def _encode_result(self, result):
        """Serialize a tool result to a bounded string for the model."""
        try:
            # default=str is deliberate: the payload is read by the LLM, not
            # parsed back, so a readable stand-in beats a TypeError for values
            # JSON cannot represent. Serializable results are unaffected.
            payload = json.dumps(result, default=str)
        except (TypeError, ValueError):
            # e.g. unsupported dict key types or circular references.
            payload = str(result)
        return payload[: self.MAX_TOOL_OUTPUT_CHARS]