File size: 8,576 Bytes
d586ce5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2916eb9
 
d586ce5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2916eb9
d586ce5
 
 
 
 
 
 
2916eb9
d586ce5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2916eb9
d586ce5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
#!/usr/bin/env python3
from __future__ import annotations

import asyncio
import json
import os
from typing import Any, Dict, List, Optional

from openai import OpenAI

try:
    from code_security_auditor_env import CodeSecurityAction, CodeSecurityAuditorEnv
except ImportError:
    from client import CodeSecurityAuditorEnv
    from models import CodeSecurityAction

# --- Model endpoint -------------------------------------------------------
# OpenAI-compatible base URL and model identifier; both overridable via env.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
# HF_TOKEN wins; API_KEY is consulted only when HF_TOKEN is unset or empty.
API_KEY = os.environ.get("HF_TOKEN") or os.environ.get("API_KEY")

# --- Environment location -------------------------------------------------
# Explicit settings (None when unset) followed by the defaults used as fallback.
LOCAL_IMAGE_NAME = os.environ.get("LOCAL_IMAGE_NAME")
ENV_BASE_URL = os.environ.get("ENV_BASE_URL")
DEFAULT_ENV_BASE_URL = os.environ.get("DEFAULT_ENV_BASE_URL", "http://127.0.0.1:8000")
DEFAULT_LOCAL_IMAGE_NAME = os.environ.get(
    "DEFAULT_LOCAL_IMAGE_NAME", "code-security-auditor-env:latest"
)

# --- Episode settings -----------------------------------------------------
# Comma-separated task ids; surrounding whitespace and empty entries are dropped.
TASK_IDS = [
    name
    for name in (part.strip() for part in os.environ.get("TASK_IDS", "easy,medium,hard").split(","))
    if name
]
MAX_STEPS = int(os.environ.get("MAX_STEPS", "12"))
TEMPERATURE = 0.0
MAX_TOKENS = 260
BENCHMARK = "code_security_auditor_env"
# Lower and upper bounds applied to every reported task score.
MIN_STRICT_SCORE = 0.001
MAX_STRICT_SCORE = 0.999

# System message sent with every model request.
SYSTEM_PROMPT = (
    "You are a senior application security reviewer. "
    "Produce strictly valid JSON for the next action. "
    "Allowed action_type values: inspect_file, submit_finding, submit_final_report. "
    "Do not include markdown fences. "
    "Keep fields concise and accurate."
)


def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` record that opens a task's log section.

    Args:
        task: Identifier of the task about to run.
        env: Name of the benchmark/environment.
        model: Name of the model used for inference.
    """
    line = f"[START] task={task} env={env} model={model}"
    # Flush immediately so the record is visible even if the process dies later.
    print(line, flush=True)


def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one ``[STEP]`` record describing a single environment step.

    Args:
        step: 1-based step index.
        action: Serialized action that was sent to the environment.
        reward: Reward for the step; rendered with two decimals.
        done: Whether the episode ended; rendered lower-case.
        error: Error text for the step; a falsy value is rendered as ``null``.
    """
    done_field = str(done).lower()
    error_field = error or "null"
    line = (
        f"[STEP] step={step} action={action} reward={reward:.2f} "
        f"done={done_field} error={error_field}"
    )
    print(line, flush=True)


def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the ``[END]`` record that closes a task's log section.

    Args:
        success: Whether the task counted as a success; rendered lower-case.
        steps: Number of steps reported for the task.
        score: Final task score; rendered with three decimals.
        rewards: Per-step rewards; rendered comma-separated with two decimals.
    """
    formatted_rewards = [format(value, ".2f") for value in rewards]
    rewards_field = ",".join(formatted_rewards)
    success_field = str(success).lower()
    line = f"[END] success={success_field} steps={steps} score={score:.3f} rewards={rewards_field}"
    print(line, flush=True)


def _compact_action_str(action: Dict[str, Any]) -> str:
    """Serialize an action dict to single-line JSON for log output.

    Separators carry no padding and non-ASCII characters are escaped, so the
    result contains no spaces added by the encoder and stays pure ASCII.
    """
    encoder_options = {"separators": (",", ":"), "ensure_ascii": True}
    return json.dumps(action, **encoder_options)


def _default_action() -> Dict[str, Any]:
    """Return a fresh fallback action that submits the final report.

    Used whenever no usable action could be obtained from the model.
    """
    placeholder = "fallback-finalize"
    return dict(
        action_type="submit_final_report",
        confidence=0.5,
        summary=placeholder,
        evidence=placeholder,
    )


def _safe_error(exc: Exception) -> str:
    """Render an exception as a short single-line string for log output.

    The stripped message is used when non-empty, otherwise the exception class
    name. Newlines become spaces and the result is cut to 240 characters.
    """
    text = str(exc).strip() or exc.__class__.__name__
    single_line = text.replace("\n", " ")
    return single_line[:240]


def _parse_action(raw: str, available_files: List[str]) -> Dict[str, Any]:
    """Convert the model's raw reply into a sanitized action dictionary.

    The reply is expected to be a JSON object. If it does not parse as-is
    (for example because it is wrapped in a Markdown fence or surrounded by
    prose), the text between the first ``{`` and the last ``}`` is tried.

    Args:
        raw: Text returned by the model.
        available_files: Filenames the action may reference; a ``filename``
            outside this list is dropped from the action.

    Returns:
        A dict with ``action_type``, ``confidence`` (clamped to [0.0, 1.0]),
        ``summary`` (at most 400 chars) and ``evidence`` (at most 700 chars),
        plus ``filename``, ``line_start``, ``line_end``, ``vuln_type`` and
        ``severity`` when supplied and usable. The fallback action from
        ``_default_action()`` is returned when the reply is not a JSON object
        or names an unknown ``action_type``.
    """
    allowed_types = {"inspect_file", "submit_finding", "submit_final_report"}

    try:
        parsed = json.loads(raw)
    except Exception:
        # Tolerate fenced or chatty replies by extracting the outermost braces.
        try:
            start = raw.index("{")
            end = raw.rindex("}")
            parsed = json.loads(raw[start:end + 1])
        except Exception:
            return _default_action()

    if not isinstance(parsed, dict):
        return _default_action()

    action_type = parsed.get("action_type")
    if action_type not in allowed_types:
        return _default_action()

    # A malformed or null confidence must not discard an otherwise valid action.
    try:
        confidence = float(parsed.get("confidence", 0.5))
    except (TypeError, ValueError, OverflowError):
        confidence = 0.5
    confidence = min(1.0, max(0.0, confidence))

    # JSON null would otherwise be rendered as the literal string "None".
    summary = parsed.get("summary")
    evidence = parsed.get("evidence")

    action: Dict[str, Any] = {
        "action_type": action_type,
        "confidence": confidence,
        "summary": ("" if summary is None else str(summary))[:400],
        "evidence": ("" if evidence is None else str(evidence))[:700],
    }

    if parsed.get("filename"):
        filename = str(parsed["filename"])
        if filename in available_files:
            action["filename"] = filename

    # Line numbers are 1-based; values that cannot be converted are omitted.
    for key in ("line_start", "line_end"):
        value = parsed.get(key)
        if value is None:
            continue
        try:
            action[key] = max(1, int(value))
        except (TypeError, ValueError, OverflowError):
            continue

    for key in ("vuln_type", "severity"):
        value = parsed.get(key)
        if value is not None:
            action[key] = str(value)

    return action


def _build_prompt(obs: Any, step: int) -> str:
    """Build the user prompt describing the current observation.

    Each field is placed on its own line. Only the last four findings and the
    first 1800 characters of the file excerpt are included.

    Args:
        obs: Observation exposing ``task_id``, ``difficulty``, ``objective``,
            ``steps_remaining``, ``available_files``, ``last_feedback``,
            ``focused_file``, ``findings_so_far`` and ``file_excerpt``.
        step: 1-based index of the step being requested.

    Returns:
        The newline-separated prompt text.
    """
    findings = obs.findings_so_far[-4:] if obs.findings_so_far else []
    snippet = obs.file_excerpt[:1800] if obs.file_excerpt else ""
    files = ", ".join(obs.available_files or [])

    sections = [
        f"Task: {obs.task_id} ({obs.difficulty})",
        f"Objective: {obs.objective}",
        f"Step: {step}",
        f"Steps remaining: {obs.steps_remaining}",
        f"Files: {files}",
        f"Last feedback: {obs.last_feedback}",
        f"Focused file: {obs.focused_file}",
        f"Recent findings: {json.dumps(findings)}",
        f"Visible snippet:\n{snippet}",
        "Return one JSON object with action_type and required fields.",
    ]
    # Join with real line breaks; an escaped "\\n" would reach the model as a
    # literal backslash followed by "n" and collapse the prompt into one line.
    return "\n".join(sections)


def _query_model(client: OpenAI, obs: Any, step: int) -> Dict[str, Any]:
    """Ask the model for the next action and return it as a sanitized dict.

    Best-effort by design: any failure while calling the API or parsing the
    reply yields the fallback action instead of raising.

    Args:
        client: OpenAI-compatible client used for the chat completion.
        obs: Current observation; also supplies ``available_files``.
        step: 1-based index of the step being requested.
    """
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": _build_prompt(obs, step)},
    ]
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=messages,
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            stream=False,
        )
        reply = completion.choices[0].message.content
        # The message content can be empty/None; treat that as an empty reply.
        text = reply.strip() if reply else ""
        return _parse_action(text, obs.available_files)
    except Exception:
        return _default_action()


async def _create_env() -> CodeSecurityAuditorEnv:
    """Create the environment client, preferring explicit configuration.

    Resolution order:
      1. ``ENV_BASE_URL`` - connect to that server.
      2. ``LOCAL_IMAGE_NAME`` - start the environment from that Docker image.
      3. ``DEFAULT_ENV_BASE_URL`` - falling back to ``DEFAULT_LOCAL_IMAGE_NAME``
         if constructing the client raises.
    """
    explicit_url = ENV_BASE_URL
    if explicit_url:
        return CodeSecurityAuditorEnv(base_url=explicit_url)

    explicit_image = LOCAL_IMAGE_NAME
    if explicit_image:
        return await CodeSecurityAuditorEnv.from_docker_image(explicit_image)

    # NOTE(review): the Docker fallback only triggers if the constructor itself
    # raises; if it does not connect eagerly this branch may never run - confirm.
    try:
        env = CodeSecurityAuditorEnv(base_url=DEFAULT_ENV_BASE_URL)
    except Exception:
        env = await CodeSecurityAuditorEnv.from_docker_image(DEFAULT_LOCAL_IMAGE_NAME)
    return env


async def run_task(env: CodeSecurityAuditorEnv, client: OpenAI, task_id: str) -> float:
    """Run one task episode and emit its ``[START]``/``[STEP]``/``[END]`` logs.

    The environment is reset for ``task_id`` and then driven for at most
    ``MAX_STEPS`` steps, each action coming from ``_query_model``. The final
    score is read from the last observation's ``reward`` and clamped to
    ``[MIN_STRICT_SCORE, MAX_STRICT_SCORE]``; the task counts as a success at
    a score of 0.6 or higher.

    Failures never propagate: the failing step is logged as a terminal step
    with zero reward and the task is scored ``MIN_STRICT_SCORE``.

    Args:
        env: Environment client to run the episode against.
        client: OpenAI-compatible client passed to ``_query_model``.
        task_id: Identifier of the task to reset the environment to.

    Returns:
        The clamped task score.
    """
    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)

    rewards: List[float] = []
    steps_taken = 0
    score = 0.0
    success = False

    try:
        result = await env.reset(task_id=task_id)
        obs = result.observation

        for step in range(1, MAX_STEPS + 1):
            if result.done:
                break

            action_dict = _query_model(client, obs, step)
            action_str = _compact_action_str(action_dict)

            action = CodeSecurityAction(**action_dict)
            result = await env.step(action)
            obs = result.observation

            reward = float(result.reward or 0.0)
            done = bool(result.done)
            # Metadata may be absent or None; treat that as "no error reported".
            metadata = getattr(obs, "metadata", None) or {}
            error = metadata.get("last_action_error")

            rewards.append(reward)
            steps_taken = step
            log_step(step=step, action=action_str, reward=reward, done=done, error=error)

            if done:
                break

        score = float(obs.reward or 0.0)
        score = min(max(score, MIN_STRICT_SCORE), MAX_STRICT_SCORE)
        success = score >= 0.6
    except Exception as exc:
        # Keep evaluator contract: do not crash inference.py on transient/runtime errors.
        # Record the failure as its own step so the step count and the rewards
        # list in [END] match the [STEP] lines that were actually emitted.
        steps_taken += 1
        rewards.append(0.0)
        log_step(step=steps_taken, action="{}", reward=0.0, done=True, error=_safe_error(exc))
        score = MIN_STRICT_SCORE
        success = False
    finally:
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards)

    return score


async def main() -> None:
    """Run every configured task against one shared environment instance.

    If the environment cannot be created, a complete START/STEP/END record is
    still emitted for each task (with the creation error and the minimum
    score) and the function returns normally. Otherwise the tasks run in
    order and the environment is closed afterwards, even if a task raises.
    """
    # A placeholder key keeps client construction from failing when no key is set.
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY or "missing")

    try:
        env = await _create_env()
    except Exception as exc:
        reason = _safe_error(exc)
        for task_id in TASK_IDS:
            log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
            log_step(step=1, action="{}", reward=0.0, done=True, error=reason)
            log_end(success=False, steps=1, score=MIN_STRICT_SCORE, rewards=[MIN_STRICT_SCORE])
        return

    try:
        # Output must contain only START/STEP/END records, so the per-task
        # scores are not aggregated or printed here.
        for task_id in TASK_IDS:
            await run_task(env, client, task_id)
    finally:
        await env.close()


# Script entry point: start the async driver only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    asyncio.run(main())