"""Baseline inference script for the Python code-review environment."""

from __future__ import annotations

import asyncio
import json
import os
import re
from pathlib import Path
from typing import Any, Dict, List

from openai import OpenAI

from client import PythonEnv
from models import ActionType, PythonReviewAction


# Read all runtime configuration from environment variables so the script can
# be reused unchanged across local runs, CI, and HF Spaces validation.
API_BASE_URL = os.environ["API_BASE_URL"]
MODEL_NAME = os.environ["MODEL_NAME"]
API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY")
ENV_BASE_URL = os.getenv("ENV_BASE_URL")
DOCKER_IMAGE = os.getenv("PYTHON_ENV_IMAGE", "python_env-env:latest")
MAX_STEPS = int(os.getenv("MAX_STEPS", "25"))
REPORT_PATH = Path(os.getenv("INFERENCE_REPORT_PATH", "inference_results.json"))
TEMPERATURE = float(os.getenv("TEMPERATURE", "0"))
MAX_TOKENS = int(os.getenv("MAX_TOKENS", "900"))
TASK_IDS = ["task_easy", "task_medium", "task_hard"]
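
# Example invocation (every value below is illustrative, not a default):
#   API_BASE_URL=https://example-inference-host/v1 \
#   MODEL_NAME=my-org/my-model \
#   HF_TOKEN=hf_xxx \
#   python inference.py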

SYSTEM_PROMPT = """You are a precise senior Python code reviewer.
Return strict JSON using this schema:
{
  "action_type": "ADD_COMMENT|APPROVE|REQUEST_CHANGES|ASK_CONTEXT|SKIP_LINE",
  "line_number": 1,
  "issue_type": "STYLE|LOGIC|SECURITY|PERFORMANCE|DOCS",
  "severity": "LOW|MEDIUM|HIGH|CRITICAL",
  "comment": "why this matters",
  "suggestion": "optional fix suggestion",
  "question": "optional context question"
}

Rules:
- Output JSON only. No markdown fences.
- Only report issues supported by the visible code.
- Use one action per step.
- Prefer high precision over quantity.
- Use REQUEST_CHANGES once you believe the code should be rejected.
- Use APPROVE only when the snippet is genuinely clean.
"""


def _build_prompt(observation, step: int, history: List[str]) -> str:
    """Build the task prompt sent to the model for one step."""

    numbered_lines = "\n".join(
        f"{index + 1:>3}: {line}" for index, line in enumerate(observation.lines)
    )
    history_text = "\n".join(history[-4:]) if history else "No previous attempts."
    return (
        f"Task ID: {observation.task_id}\n"
        f"Step: {step}\n"
        f"Current score: {observation.metrics.current_score:.2f}\n"
        f"Last reward: {observation.reward_summary.step_reward:.2f}\n"
        f"Cumulative reward: {observation.reward_summary.cumulative_reward:.2f}\n"
        f"Latest feedback: {observation.feedback or 'None'}\n"
        f"Attempt history:\n{history_text}\n\n"
        f"Filename: {observation.filename}\n"
        f"Context: {observation.context or 'None'}\n"
        "Code to review:\n"
        f"{numbered_lines}"
    )
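
# For a two-line snippet, the "Code to review" block at the end of the prompt
# renders like this (the snippet contents here are illustrative):
#   1: def add(a, b):
#   2:     return a + b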


def _extract_text_content(message_content: Any) -> str:
    """Normalize OpenAI response content into one text string."""

    if isinstance(message_content, str):
        return message_content
    if isinstance(message_content, list):
        parts: List[str] = []
        for item in message_content:
            if isinstance(item, dict):
                text = item.get("text")
                if isinstance(text, str):
                    parts.append(text)
        return "\n".join(parts)
    return ""


def _extract_json_blob(content: str) -> str:
    """Extract a JSON object from plain or fenced model output."""

    fenced_match = re.search(r"```(?:json)?\s*(\{.*\})\s*```", content, re.DOTALL)
    if fenced_match:
        return fenced_match.group(1)

    start = content.find("{")
    end = content.rfind("}")
    if start != -1 and end != -1 and end > start:
        return content[start : end + 1]
    return content
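
# Example (illustrative): both the fenced output
#   ```json
#   {"action_type": "APPROVE"}
#   ```
# and the chatty output
#   Sure! {"action_type": "APPROVE"}
# reduce to the substring {"action_type": "APPROVE"}.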


def _parse_response(content: str) -> Dict[str, Any]:
    """Parse the model response into a normalized payload dict."""

    raw = _extract_json_blob(content)
    try:
        data = json.loads(raw)
    except json.JSONDecodeError:
        return {"_parse_error": raw}
    if not isinstance(data, dict):
        # Valid JSON that is not an object (e.g. a bare list or a quoted
        # string) is still unusable as an action payload downstream.
        return {"_parse_error": raw}
    return data


def _completion(client: OpenAI, prompt: str) -> Dict[str, Any]:
    """Send one completion request to the configured model endpoint."""

    response = client.chat.completions.create(
        model=MODEL_NAME,
        temperature=TEMPERATURE,
        max_tokens=MAX_TOKENS,
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": prompt},
        ],
    )
    content = _extract_text_content(response.choices[0].message.content) or "{}"
    return _parse_response(content)


def _build_fallback_action(observation, note: str) -> PythonReviewAction:
    """Create a safe fallback action when model output is unusable."""

    # On the final step, asking for context would burn the last action on a
    # question, so request changes instead.
    if observation.current_step + 1 >= observation.max_steps:
        return PythonReviewAction(action_type=ActionType.REQUEST_CHANGES)
    return PythonReviewAction(action_type=ActionType.ASK_CONTEXT, question=note)


def _to_action(
    payload: Dict[str, Any],
    observation,
) -> PythonReviewAction:
    """Convert a parsed model payload into a valid environment action."""

    try:
        return PythonReviewAction.model_validate(payload)
    except Exception:
        note = "Model returned no valid action."
        if payload.get("_parse_error"):
            note = f"{note} Raw response could not be parsed as JSON."
        return _build_fallback_action(observation, note)


def _make_env():
    """Connect to a live environment or launch the Docker image."""

    if ENV_BASE_URL:
        return PythonEnv(base_url=ENV_BASE_URL).sync()
    return asyncio.run(PythonEnv.from_docker_image(DOCKER_IMAGE)).sync()
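
# Example (illustrative): point the script at an already-running environment
# instead of launching the Docker image locally:
#   ENV_BASE_URL=http://localhost:8000 python inference.py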


def _task_result_dict(observation, step_logs: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Build the report payload for one completed task run."""

    return {
        "task_id": observation.task_id,
        "snippet_id": observation.snippet_id,
        "score": observation.metrics.current_score,
        "precision": observation.metrics.precision,
        "recall": observation.metrics.recall,
        "f1": observation.metrics.f1,
        "true_positives": observation.metrics.true_positives,
        "false_positives": observation.metrics.false_positives,
        "missed_issues": observation.metrics.missed_issues,
        "cumulative_reward": observation.metrics.cumulative_reward,
        "steps": step_logs,
    }


def main() -> None:
    """Run the configured model against the benchmark task set."""

    if not API_KEY:
        raise RuntimeError("Set HF_TOKEN or OPENAI_API_KEY before running inference.py")

    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    env = _make_env()
    episode_results: List[Dict[str, Any]] = []

    try:
        for index, task_id in enumerate(TASK_IDS, start=1):
            result = env.reset(task_id=task_id)
            observation = result.observation
            history: List[str] = []
            step_logs: List[Dict[str, Any]] = []

            print(f"Task {index}: {task_id} ({observation.snippet_id})")

            for step in range(1, MAX_STEPS + 1):
                prompt = _build_prompt(observation, step, history)
                try:
                    payload = _completion(client, prompt)
                except Exception as exc:
                    payload = {"_error": str(exc)}

                action = _to_action(payload=payload, observation=observation)

                result = env.step(action)
                observation = result.observation

                step_log = {
                    "step": step,
                    "action_type": action.action_type.value,
                    "line_number": action.line_number,
                    "reward": result.reward or 0.0,
                    "score": observation.metrics.current_score,
                    "done": result.done,
                    "feedback": observation.feedback,
                }
                if payload.get("_error"):
                    step_log["model_error"] = payload["_error"]
                if payload.get("_parse_error"):
                    step_log["parse_error"] = True
                step_logs.append(step_log)

                history.append(
                    f"step={step} action={action.action_type.value} "
                    f"line={action.line_number} score={observation.metrics.current_score:.2f} "
                    f"reward={(result.reward or 0.0):.2f} feedback={observation.feedback}"
                )

                print(
                    f"  step={step} action={action.action_type.value} "
                    f"score={observation.metrics.current_score:.2f} reward={(result.reward or 0.0):.2f} "
                    f"done={result.done}"
                )

                if result.done:
                    break

            episode_results.append(_task_result_dict(observation, step_logs))
    finally:
        env.close()

    if episode_results:
        mean_score = sum(item["score"] for item in episode_results) / len(episode_results)
    else:
        mean_score = 0.0
    summary = {
        "model_name": MODEL_NAME,
        "api_base_url": API_BASE_URL,
        "task_count": len(episode_results),
        "mean_score": mean_score,
        "results": episode_results,
    }

    REPORT_PATH.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    print(json.dumps(summary, indent=2))
    print(f"\nSaved report to {REPORT_PATH}")


if __name__ == "__main__":
    main()