# python_env/tests/test_env.py
# Uploaded by darshanajudiya7 via huggingface_hub ("Upload folder using huggingface_hub"),
# commit d25ab77 (verified).
from __future__ import annotations
from fastapi.testclient import TestClient
import pytest
from models import (
ActionType,
IssueType,
PythonReviewAction,
Severity,
)
from server.app import app
from server.grading import grade_review
from server.python_env_environment import PythonEnvironment
from server.task_bank import load_task_bank
def _snippet_by_id(task_id: str, snippet_id: str):
    """Return the task-bank snippet whose ``snippet_id`` matches, within ``task_id``.

    Args:
        task_id: Key into ``load_task_bank()`` (e.g. ``"task_easy"``).
        snippet_id: Identifier compared against each snippet's ``snippet_id``.

    Raises:
        KeyError: if ``task_id`` is not present in the task bank.
        LookupError: if no snippet in that task has the given ``snippet_id``.
            (Previously a bare ``StopIteration`` escaped from ``next()``,
            which gave no hint about which id was missing.)
    """
    snippets = load_task_bank()[task_id]
    for item in snippets:
        if item.snippet_id == snippet_id:
            return item
    raise LookupError(f"snippet {snippet_id!r} not found in task {task_id!r}")
def test_add_comment_requires_fields() -> None:
    """An ADD_COMMENT action built with only ``action_type`` must be rejected."""
    # Narrowed from ``pytest.raises(Exception)``: the broad form would also pass
    # on unrelated failures such as an AttributeError from a typo. Assuming
    # PythonReviewAction is a pydantic model (sibling models here expose
    # model_validate/model_copy) — TODO confirm — its ValidationError
    # subclasses ValueError; TypeError covers signature-level rejection.
    with pytest.raises((ValueError, TypeError)):
        PythonReviewAction(action_type=ActionType.ADD_COMMENT)
def test_approve_rejects_extra_fields() -> None:
    """An APPROVE action that also carries a ``comment`` must be rejected."""
    # Narrowed from ``pytest.raises(Exception)`` so that unrelated errors do not
    # make this test pass by accident. Assumes pydantic-style validation
    # (ValidationError subclasses ValueError) — TODO confirm; TypeError covers
    # a constructor that rejects the unexpected keyword outright.
    with pytest.raises((ValueError, TypeError)):
        PythonReviewAction(
            action_type=ActionType.APPROVE,
            comment="looks good",
        )
def test_easy_grader_rewards_required_issue_and_request_changes() -> None:
    """Grade a style comment on line 4 followed by REQUEST_CHANGES on the first easy snippet.

    The grader should score the review above 0.35 and count at least one
    required issue as found.
    """
    first_easy_snippet = load_task_bank()["task_easy"][0]
    actions = (
        PythonReviewAction(
            action_type=ActionType.ADD_COMMENT,
            line_number=4,
            issue_type=IssueType.STYLE,
            severity=Severity.LOW,
            comment="Ambiguous variable name l violates PEP8 E741.",
        ),
        PythonReviewAction(action_type=ActionType.REQUEST_CHANGES),
    )

    from models import ReviewComment

    # Replay each action as a review-history entry, numbering steps from 1.
    review_history = [
        ReviewComment.model_validate(
            {
                "step_index": position,
                "action_type": act.action_type,
                "line_number": act.line_number,
                "issue_type": act.issue_type,
                "severity": act.severity,
                "comment": act.comment,
            }
        )
        for position, act in enumerate(actions, start=1)
    ]

    outcome = grade_review(
        "task_easy",
        first_easy_snippet,
        review_history,
        duplicate_comments=0,
    )
    assert outcome.score > 0.35
    assert outcome.required_found >= 1
def test_hard_grader_rewards_security_metadata() -> None:
    """A CRITICAL security comment with an OWASP rationale and a fix scores on the first hard snippet.

    Expects a score above 0.30 and exactly one true positive.
    """
    hard_snippet = load_task_bank()["task_hard"][0]

    from models import ReviewComment

    comment_fields = dict(
        step_index=1,
        action_type=ActionType.ADD_COMMENT,
        line_number=2,
        issue_type=IssueType.SECURITY,
        severity=Severity.CRITICAL,
        comment="SQL injection risk. This is an OWASP injection issue because the query interpolates user input.",
        suggestion="Use a parameterized query with placeholders instead of string interpolation.",
    )
    security_review = ReviewComment(**comment_fields)

    grading = grade_review("task_hard", hard_snippet, [security_review], duplicate_comments=0)
    assert grading.score > 0.30
    assert grading.true_positives == 1
def test_environment_step_updates_metrics() -> None:
    """Commenting exactly on a required gold issue records a reward, a true positive and a match."""
    env = PythonEnvironment()
    first_obs = env.reset(task_id="task_easy").model_copy()

    # Look up the snippet the environment served and pick its first required gold issue.
    served_snippet = _snippet_by_id("task_easy", first_obs.snippet_id)
    required_issue = next(gold for gold in served_snippet.gold_issues if gold.required)

    comment_action = PythonReviewAction(
        action_type=ActionType.ADD_COMMENT,
        line_number=required_issue.line,
        issue_type=required_issue.issue_type,
        severity=required_issue.severity,
        comment=required_issue.description,
    )
    stepped = env.step(comment_action)

    assert stepped.reward is not None
    assert stepped.metrics.true_positives >= 1
    assert stepped.review_history[-1].matched_issue_ids
def test_environment_terminal_action_sets_done() -> None:
    """A REQUEST_CHANGES action right after reset ends the episode with a non-negative score."""
    env = PythonEnvironment()
    # The reset observation is not inspected here; the call only starts an episode.
    # (Previously it was bound to an unused local.)
    env.reset(task_id="task_easy")
    result = env.step(PythonReviewAction(action_type=ActionType.REQUEST_CHANGES))
    assert result.done is True
    assert result.metrics.current_score >= 0.0
def test_api_smoke_endpoints() -> None:
    """Smoke-test the HTTP API: /reset, /step, /tasks, /metrics, /health and /schema."""
    client = TestClient(app)

    def _ok_json(response):
        # Every smoke call must return 200 before its JSON body is inspected.
        assert response.status_code == 200
        return response.json()

    reset_body = _ok_json(client.post("/reset", json={"task_id": "task_easy"}))
    observation = reset_body["observation"]
    assert observation["task_id"] == "task_easy"

    # Submit a comment that exactly matches a required gold issue of the served snippet.
    served_snippet = _snippet_by_id("task_easy", observation["snippet_id"])
    target = next(gold for gold in served_snippet.gold_issues if gold.required)
    action_payload = {
        "action_type": "ADD_COMMENT",
        "line_number": target.line,
        "issue_type": target.issue_type.value,
        "severity": target.severity.value,
        "comment": target.description,
    }
    step_body = _ok_json(client.post("/step", json={"action": action_payload}))
    assert step_body["observation"]["metrics"]["true_positives"] >= 1

    assert len(_ok_json(client.get("/tasks"))["tasks"]) == 3
    assert "metrics" in _ok_json(client.get("/metrics"))
    assert _ok_json(client.get("/health"))["status"] == "ok"
    assert "action" in _ok_json(client.get("/schema"))