SecureCodeEnv / inference.py
vishaldhakad's picture
Change the reward system to keep scores strictly between 0 and 1
791664b
Raw
History Blame Contribute Delete
5.06 kB
"""
SecureCodeEnv - Baseline Inference Script
Required by hackathon. Runs an LLM agent through the environment.
Outputs clamped [START]/[STEP]/[END] blocks to pass range validation.
"""
import os
import json
import time
import sys
import requests
from openai import OpenAI
from typing import Dict, List, Any
# ── Configuration ──────────────────────────────────────────────────────────
# Every setting below may be overridden through an environment variable.
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
HF_TOKEN = os.getenv("HF_TOKEN", "")
# Trailing slash(es) are stripped so f"{ENV_URL}/reset" never yields "//reset".
ENV_URL = os.getenv("ENV_URL", "http://localhost:7860").rstrip("/")
# When HF_TOKEN is unset or empty, a dummy key is passed to the client instead.
client = OpenAI(
    base_url=API_BASE_URL,
    api_key=HF_TOKEN or "sk-placeholder",
)
def clamp_score(score: float) -> float:
    """Map *score* into the open interval (0, 1), i.e. onto [0.001, 0.999].

    Anything that cannot be read as a finite-or-infinite float -- ``None``,
    non-numeric strings, integers too large for a float, or NaN -- maps to
    the neutral value 0.5. Infinities are clamped like any other
    out-of-range value.

    Args:
        score: Raw reward; usually a float, but any value ``float()`` accepts.

    Returns:
        A float ``v`` with ``0.001 <= v <= 0.999``.
    """
    epsilon = 0.001
    try:
        v = float(score)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: float() of an int beyond the double range. JSON allows
        # arbitrarily large integers, so a server reply could contain one.
        return 0.5
    if v != v:  # NaN is the only float that is not equal to itself
        return 0.5
    return max(epsilon, min(1.0 - epsilon, v))
def clean_code(raw: str) -> str:
    """Drop markdown fence lines (e.g. a line that is just ``` or ```python).

    Any line whose stripped text begins with three backticks is removed; all
    other lines are kept verbatim. The joined result is then stripped of
    leading/trailing whitespace.
    """
    kept = (ln for ln in raw.splitlines() if not ln.strip().startswith("```"))
    return "\n".join(kept).strip()
# System prompt sent with every generation request. The em dashes were
# previously mis-encoded ("β€”"), which fed garbage characters to the model.
SYSTEM_PROMPT = """You are a senior Python security engineer.
Output ONLY raw Python code — no markdown, no explanations.
Your code must:
1. Solve the problem correctly
2. Resist SQL injection, path traversal, and auth bypass attacks
3. Use parameterized queries — never f-string SQL
4. Use secrets module (not random) for tokens
5. Use bcrypt (not hashlib) for passwords
6. Use hmac.compare_digest for secret comparison
7. Have type hints and docstrings on every function"""
def _reset_episode(difficulty: str) -> Dict[str, Any]:
    """Start a new episode at *difficulty* and return the environment's reply.

    Raises:
        requests.RequestException: transport failure or non-2xx status.
        ValueError: the body is not JSON (requests' decode error subclasses
            ValueError), is not a JSON object, or lacks session_id/task_id.
    """
    r = requests.post(
        f"{ENV_URL}/reset",
        json={"difficulty": difficulty},
        timeout=30,
    )
    r.raise_for_status()
    data = r.json()
    if not isinstance(data, dict):
        raise ValueError(f"reset returned {type(data).__name__}, expected an object")
    # Validate here so a malformed reply is reported like any other reset
    # failure instead of raising KeyError out of run_episode.
    missing = [key for key in ("session_id", "task_id") if key not in data]
    if missing:
        raise ValueError(f"reset response missing {missing}")
    return data


def _build_user_message(data: Dict[str, Any]) -> str:
    """Compose the per-step user prompt from the latest observation in *data*.

    Raises KeyError if ``problem_statement`` is missing; the caller treats
    that as a failed step.
    """
    # Cap the serialized codegraph at 2000 characters to bound prompt size.
    context_str = json.dumps(data.get("codegraph", {}))[:2000]
    prev_fb = data.get("last_feedback", "")
    msg = (
        f"Task: {data['problem_statement']}\n\n"
        f"Security targets: {data.get('cwe_targets', [])}\n\n"
        f"Codebase context:\n{context_str}"
    )
    if prev_fb:
        msg += f"\n\nPrevious feedback:\n{prev_fb}"
    return msg + "\n\nWrite the complete Python implementation now:"


def _generate_code(data: Dict[str, Any]) -> str:
    """Ask the LLM for an implementation and return it with fences stripped.

    Falls back to a trivial stub when the model returns no code, so the step
    is still submitted and scored rather than skipped.
    """
    resp = client.chat.completions.create(
        model=MODEL_NAME,
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": _build_user_message(data)},
        ],
        max_tokens=1500,
        temperature=0.1,
    )
    code = clean_code(resp.choices[0].message.content or "")
    # clean_code() already strips whitespace, so empty means "no code at all".
    return code or "def placeholder(): pass"


def _submit_step(sid: Any, tid: Any, code: str, step: int) -> Dict[str, Any]:
    """POST *code* to /step for session *sid* and return the decoded reply."""
    step_r = requests.post(
        f"{ENV_URL}/step",
        json={
            "session_id": sid,
            "code": code,
            "filename": f"step_{step}.py",
            "task_id": tid,
        },
        timeout=65,
    )
    step_r.raise_for_status()
    return step_r.json()


def _feedback_summary(res: Dict[str, Any]) -> str:
    """Extract a feedback string from a /step reply, tolerating odd shapes.

    Accepts ``{"feedback": {"summary": "..."}}`` or a bare
    ``{"feedback": "..."}``. A null/missing feedback or a non-string summary
    yields "" instead of raising.
    """
    fb = res.get("feedback")
    if isinstance(fb, dict):
        fb = fb.get("summary")
    return fb if isinstance(fb, str) else ""


def run_episode(difficulty: str) -> None:
    """Run one episode at *difficulty*, printing [START]/[STEP]/[END] lines.

    Up to five generate-and-submit steps are attempted; the loop stops early
    when the environment reports ``done``. A failing step is logged to stderr
    and the remaining steps are still tried. If the reset fails (including a
    reply without ``session_id``/``task_id``), the error goes to stderr and
    nothing is printed to stdout for this difficulty.

    The [END] score is the clamped reward of the last step that returned one,
    or ``clamp_score(0.0)`` if none did; ``steps`` counts attempted steps.
    """
    try:
        data = _reset_episode(difficulty)
    except Exception as e:
        print(f"Failed to reset {difficulty}: {e}", file=sys.stderr)
        return
    sid = data["session_id"]
    tid = data["task_id"]
    print(f"[START] task={tid} difficulty={difficulty}", flush=True)
    final_score = clamp_score(0.0)  # starts at epsilon, not 0.0
    total_steps = 0
    for i in range(1, 6):
        total_steps = i
        try:
            code = _generate_code(data)
            res = _submit_step(sid, tid, code, i)
            clamped = clamp_score(res.get("total_reward", 0.0))
            final_score = clamped
            print(f"[STEP] step={i} reward={clamped:.4f}", flush=True)
            if res.get("done"):
                break
            # Feed updated context back for the next step.
            data["codegraph"] = res.get("codegraph", {})
            data["last_feedback"] = _feedback_summary(res)
        except Exception as e:
            # Deliberately best-effort: don't break — try the remaining steps.
            print(f"Error in step {i}: {e}", file=sys.stderr)
        time.sleep(1)
    print(f"[END] task={tid} score={final_score:.4f} steps={total_steps}", flush=True)
def main():
    """Health-check the environment, then play one episode per difficulty.

    Exits with status 1 if ``GET {ENV_URL}/health`` fails or returns an
    HTTP error status.
    """
    try:
        health = requests.get(f"{ENV_URL}/health", timeout=10)
        health.raise_for_status()
        print(f"Environment healthy: {ENV_URL}", file=sys.stderr)
    except Exception as err:
        print(f"Health check failed: {err}", file=sys.stderr)
        sys.exit(1)
    for level in ("easy", "medium", "hard"):
        run_episode(level)
        # Brief pause between consecutive episodes.
        time.sleep(2)
if __name__ == "__main__":
    # Run the episodes only when executed as a script, not when imported.
    main()