Openenv / inference.py
vishaldhakad's picture
intial push
eda351c
Raw
History Blame Contribute Delete
8.82 kB
"""
inference.py — Baseline inference script (REQUIRED by hackathon).
CRITICAL requirements:
- Must use OpenAI client (hackathon rule — Groq/Gemini both support it)
- Must complete in < 20 minutes on 2 vCPU / 8GB RAM
- Must be in project root
- env vars: API_BASE_URL, MODEL_NAME, HF_TOKEN, ENV_URL
Compatible with:
- Groq free tier: API_BASE_URL=https://api.groq.com/openai/v1, MODEL_NAME=llama-3.3-70b-versatile
- Gemini Flash: API_BASE_URL=https://generativelanguage.googleapis.com/v1beta/openai, MODEL_NAME=gemini-2.5-flash
- OpenAI: API_BASE_URL=https://api.openai.com/v1, MODEL_NAME=gpt-4o-mini
"""
import os
import json
import time
import requests
from openai import OpenAI
# ── Config (from environment variables) ──────────────────────────────────────
# Every setting is read from the environment and falls back to the default
# shown here when the variable is not set.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.groq.com/openai/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "llama-3.3-70b-versatile")
HF_TOKEN = os.environ.get("HF_TOKEN", "")
# Endpoints are built as f"{ENV_URL}/reset" and f"{ENV_URL}/step", so a
# trailing slash in the configured value would produce "//reset"; strip it.
ENV_URL = os.environ.get("ENV_URL", "http://localhost:7860").rstrip("/")
# HF_TOKEN doubles as the API key; "dummy" is substituted when it is empty so
# the client is always constructed with a non-empty key.
client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN or "dummy")
# ── System prompt ─────────────────────────────────────────────────────────────
# Sent verbatim as the "system" message of every LLM request. The dashes below
# are real em dashes; the previous text contained mis-encoded byte sequences
# in their place, which would have been sent to the model as garbage.
SYSTEM_PROMPT = """You are a Python security engineer writing production-ready, secure Python code.
When given a task, write ONLY the Python function — no explanations, no markdown fences, no comments outside the function.
Your code MUST:
1. Solve the problem correctly — handle None, empty string, boundary values
2. Resist security attacks: SQL injection, path traversal, auth bypass, XSS
3. Use PARAMETERISED queries — NEVER string-format user input into SQL
4. Validate and sanitise ALL inputs before use
5. Use proper type hints on all function signatures
6. Have a docstring explaining what the function does
7. Use try/except with specific exception types (not bare except)
8. Follow the naming and error-handling conventions shown in CODEBASE CONTEXT
9. Import only well-known standard library or PyPI packages
CRITICAL SECURITY RULES:
- SQL: always use cursor.execute(sql, (param,)) — never f-strings or % formatting
- Paths: always use Path.resolve() and check prefix against safe base directory
- JWT: always specify algorithms=["HS256"] explicitly
- Auth: always use hmac.compare_digest() for constant-time comparison
- Hashing: use SHA-256 or stronger — never MD5/SHA1
- Never use eval(), exec(), or subprocess with shell=True
"""
def compress_graph(graph: dict, limit: int = 6000) -> str:
    """Return a compact JSON summary of a code graph for use in an LLM prompt.

    The summary keeps the graph-level ``conventions`` and, per component, its
    file, language, up to 20 function names, up to 15 de-duplicated top-level
    import names and the two convention flags. Everything else is dropped.

    If the summary is longer than *limit* characters, the ``imports`` lists
    are removed first; if it is still too long, components are included in
    their original order only for as long as the output fits, so the result
    remains valid JSON rather than being cut mid-token.

    Args:
        graph: Mapping with optional ``"conventions"`` and ``"components"``
            keys. ``None`` or any non-dict value is treated as empty.
        limit: Maximum length of the returned string, in characters.

    Returns:
        An ``indent=2`` JSON string of at most *limit* characters. It is valid
        JSON unless ``conventions`` alone exceeds *limit*, in which case it is
        hard-truncated as a last resort.
    """
    if not isinstance(graph, dict):
        graph = {}
    raw_components = graph.get("components")
    if not isinstance(raw_components, dict):
        raw_components = {}

    entries = {}
    for name, comp in raw_components.items():
        if not isinstance(comp, dict):
            continue
        comp_conventions = comp.get("conventions")
        if not isinstance(comp_conventions, dict):
            comp_conventions = {}

        # Function entries may be dicts carrying a "name" or plain names.
        functions = []
        for func in comp.get("functions") or []:
            func_name = func.get("name") if isinstance(func, dict) else func
            if func_name:
                functions.append(func_name)

        # Keep only the top-level package of each import; dict.fromkeys
        # removes duplicates (e.g. "os" and "os.path") while keeping order.
        top_level = (
            imp.split(".")[0]
            for imp in comp.get("imports") or []
            if isinstance(imp, str)
        )
        imports = list(dict.fromkeys(top_level))

        entries[name] = {
            "file": comp.get("file", ""),
            "language": comp.get("language", "py"),
            "functions": functions[:20],
            "imports": imports[:15],
            "uses_try_catch": comp_conventions.get("uses_try_catch", False),
            "uses_type_hints": comp_conventions.get("uses_type_hints", False),
        }

    slim = {"conventions": graph.get("conventions") or {}, "components": entries}
    result = json.dumps(slim, indent=2)
    if len(result) <= limit:
        return result

    # Over budget: drop the imports lists, then add components back one at a
    # time and stop before the first one that would push past the limit.
    for entry in entries.values():
        entry.pop("imports", None)
    kept = {}
    slim["components"] = kept
    result = json.dumps(slim, indent=2)
    for name, entry in entries.items():
        kept[name] = entry
        candidate = json.dumps(slim, indent=2)
        if len(candidate) > limit:
            break
        result = candidate
    # Only bites when the conventions alone are larger than the limit.
    return result[:limit]
def call_llm(messages: list, timeout_s: int = 60) -> str:
    """Send *messages* to the chat model, retrying when rate limited.

    Up to three attempts are made. An error whose text contains
    ``"rate_limit"`` or ``"429"`` triggers a wait of ``2 ** attempt`` seconds
    before the next attempt; any other error is re-raised immediately.

    Args:
        messages: Chat messages (``{"role": ..., "content": ...}`` dicts).
        timeout_s: Per-request timeout in seconds, passed to the client.

    Returns:
        The stripped text of the first choice, or ``""`` when the response
        has no text or every attempt was rate limited.

    Raises:
        Exception: Whatever the client raised, for non-rate-limit failures.
    """
    max_attempts = 3
    for attempt in range(max_attempts):
        try:
            resp = client.chat.completions.create(
                model=MODEL_NAME,
                messages=messages,
                max_tokens=1024,
                temperature=0.2,
                timeout=timeout_s,
            )
        except Exception as e:
            err_str = str(e).lower()
            if "rate_limit" not in err_str and "429" not in err_str:
                raise
            if attempt == max_attempts - 1:
                # No retry follows the last attempt, so sleeping is wasted time.
                break
            wait = 2 ** attempt
            print(f" Rate limited. Waiting {wait}s...")
            time.sleep(wait)
            continue

        if not resp.choices:
            return ""
        # content can be None, so guard before calling strip().
        content = resp.choices[0].message.content
        return (content or "").strip()
    return ""
def strip_markdown(code: str) -> str:
    """Return the code inside the first markdown fence, or *code* itself.

    A fence opened with ```` ```python ```` is preferred when present;
    otherwise the first ```` ``` ```` is used. The language tag on the opening
    fence line (``python``, ``py``, ``python3``, ...) is removed, and a fence
    that is never closed (for example when the reply was cut off) is still
    unwrapped. Text without any fence is returned stripped.

    Args:
        code: Raw model output; ``None`` or ``""`` yields ``""``.

    Returns:
        The extracted code with surrounding whitespace removed.
    """
    if not code:
        return ""

    fence = "```"
    start = code.find(fence + "python")
    if start == -1:
        start = code.find(fence)
    if start == -1:
        return code.strip()

    body = code[start + len(fence):]
    end = body.find(fence)
    if end != -1:
        body = body[:end]

    # Whatever shares the line with the opening fence is the language tag.
    # Drop it only when it looks like one, so a one-line ```x = 1``` snippet
    # keeps its content.
    first_line, newline, rest = body.partition("\n")
    tag = first_line.strip()
    looks_like_tag = (
        not tag
        or tag.replace("+", "").replace("-", "").replace("_", "").isalnum()
    )
    if newline and looks_like_tag:
        body = rest
    return body.strip()
def _build_user_prompt(episode: dict, context_str: str, prior_feedback: str = "") -> str:
    """Assemble the user message for one step, adding *prior_feedback* if any."""
    sections = [
        f"Task: {episode.get('problem_statement', '')}",
        f"Security targets: {episode.get('cwe_targets', [])}",
        "CODEBASE CONTEXT (follow these conventions exactly):",
        context_str,
        "Starter code to build from:",
        f"{episode.get('starter_code', '# Write your implementation here')}",
    ]
    if prior_feedback:
        sections.append("Feedback on your previous attempt (address every point):")
        sections.append(prior_feedback)
    sections.append(
        "Write the complete, secure Python function now. "
        "Return ONLY the code, no markdown:"
    )
    return "\n".join(sections)


def run_episode(difficulty: str = "medium") -> dict:
    """Run one full RL episode with up to 5 improvement steps.

    The environment is reset via ``POST {ENV_URL}/reset``. Each step then asks
    the LLM for a solution, strips markdown fences and submits the code via
    ``POST {ENV_URL}/step``. The reward and feedback of a step are printed,
    and the feedback is added to the next step's prompt so later attempts can
    react to it. The loop stops on the first LLM error, empty reply, submit
    error or ``done`` flag.

    Args:
        difficulty: Difficulty label sent to the reset endpoint.

    Returns:
        Dict with ``"task"`` (task id, or ``"unknown"`` if the reset failed),
        ``"scores"`` (reward per completed step), ``"final_score"`` (last
        reward, ``0.0`` if none) and ``"improved"`` (last reward greater than
        the first, with at least two steps).
    """
    failed = {"task": "unknown", "scores": [], "final_score": 0.0, "improved": False}

    # Reset environment. The broad except is deliberate: one unreachable or
    # misbehaving endpoint must not abort the remaining episodes.
    try:
        reset_resp = requests.post(
            f"{ENV_URL}/reset",
            json={"difficulty": difficulty},
            timeout=30,
        )
        reset_resp.raise_for_status()
        episode = reset_resp.json()
    except Exception as e:
        print(f" ERROR: Could not reset env: {e}")
        return failed

    # These keys are read unconditionally below, so check them up front
    # instead of failing later with a KeyError.
    required = ("session_id", "task_id", "problem_statement")
    if not isinstance(episode, dict) or any(key not in episode for key in required):
        print(" ERROR: Could not reset env: malformed reset response")
        return failed

    sid = episode["session_id"]
    task_id = episode["task_id"]
    scores_history = []
    prior_feedback = ""
    feedback_char_limit = 1500  # keeps the follow-up prompt bounded
    print(f"\n Task: {task_id} | CWEs: {episode.get('cwe_targets', [])}")

    for step_num in range(5):
        context_str = compress_graph(episode.get("codegraph") or {})
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {
                "role": "user",
                "content": _build_user_prompt(episode, context_str, prior_feedback),
            },
        ]
        try:
            code = call_llm(messages)
        except Exception as e:
            print(f" Step {step_num+1}: LLM error — {e}")
            break

        code = strip_markdown(code)
        if not code.strip():
            print(f" Step {step_num+1}: Empty response from LLM")
            break

        try:
            step_resp = requests.post(
                f"{ENV_URL}/step",
                json={
                    "session_id": sid,
                    "task_id": task_id,
                    "filename": f"solution_step{step_num}.py",
                    "code": code,
                },
                timeout=60,
            )
            step_resp.raise_for_status()
            result = step_resp.json()
        except Exception as e:
            print(f" Step {step_num+1}: Submit error — {e}")
            break
        if not isinstance(result, dict):
            print(f" Step {step_num+1}: Submit error — malformed step response")
            break

        # A missing or non-numeric reward counts as 0.0 so the formatting
        # below and the final comparison cannot fail.
        reward = result.get("total_reward", 0.0)
        if not isinstance(reward, (int, float)):
            reward = 0.0
        scores_history.append(reward)
        done = result.get("done", False)
        print(f" Step {step_num+1}: reward={reward:.4f} done={done}")

        feedback = result.get("feedback") or {}
        if isinstance(feedback, dict):
            feedback_lines = [f"{dim}: {fb}" for dim, fb in feedback.items()]
        else:
            feedback_lines = [str(feedback)]
        for line in feedback_lines:
            print(f" {line}")

        # Carry the grader's verdict into the next prompt.
        summary = "\n".join(
            [f"- reward: {reward:.4f}"] + [f"- {line}" for line in feedback_lines]
        )
        prior_feedback = summary[:feedback_char_limit]

        # Update context for next step, keeping the previous graph when the
        # response does not include a new one.
        new_graph = result.get("codegraph")
        if new_graph:
            episode["codegraph"] = new_graph
        if done:
            break

    final = scores_history[-1] if scores_history else 0.0
    improved = len(scores_history) > 1 and scores_history[-1] > scores_history[0]
    return {
        "task": task_id,
        "scores": scores_history,
        "final_score": final,
        "improved": improved,
    }
def main() -> None:
    """Run one episode per difficulty, print a summary and enforce the time limit.

    Episodes run in the order easy, medium, hard. After the per-task results
    and the mean final reward are printed, the total wall-clock time is
    checked against the 20-minute limit.

    Raises:
        SystemExit: With an explanatory message (non-zero exit status) when
            the run took 1200 seconds or longer.
    """
    time_limit_s = 1200  # Hackathon requirement: must complete in < 20 minutes
    start = time.time()
    results = []
    print("=" * 60)
    print("SecureCodeEnv V2 — Baseline Inference")
    print(f"Model: {MODEL_NAME}")
    print(f"Env: {ENV_URL}")
    print("=" * 60)
    for difficulty in ["easy", "medium", "hard"]:
        print(f"\n{'='*20} {difficulty.upper()} {'='*20}")
        results.append(run_episode(difficulty))
    elapsed = time.time() - start

    print("\n" + "=" * 60)
    print("FINAL RESULTS")
    print("=" * 60)
    for r in results:
        improved_str = "↑ improved" if r["improved"] else "→ flat"
        print(f" {r['task']}: {r['final_score']:.4f} [{improved_str}] steps={r['scores']}")
    avg = sum(r["final_score"] for r in results) / len(results) if results else 0
    print(f"\nMean final reward: {avg:.4f}")
    print(f"Total time: {elapsed:.1f}s")

    # An explicit check instead of `assert`, which is removed when Python
    # runs with -O and would then let an over-time run report success.
    if elapsed >= time_limit_s:
        raise SystemExit(f"Exceeded 20-minute time limit ({elapsed:.1f}s)")
    print("\n✅ Completed within time limit.")


if __name__ == "__main__":
    main()