"""
inference.py — Baseline inference script (REQUIRED by hackathon).

CRITICAL requirements:
- Must use OpenAI client (hackathon rule — Groq/Gemini both support it)
- Must complete in < 20 minutes on 2 vCPU / 8GB RAM
- Must be in project root
- env vars: API_BASE_URL, MODEL_NAME, HF_TOKEN, ENV_URL

Compatible with:
- Groq free tier: API_BASE_URL=https://api.groq.com/openai/v1, MODEL_NAME=llama-3.3-70b-versatile
- Gemini Flash:   API_BASE_URL=https://generativelanguage.googleapis.com/v1beta/openai, MODEL_NAME=gemini-2.5-flash
- OpenAI:         API_BASE_URL=https://api.openai.com/v1, MODEL_NAME=gpt-4o-mini
"""
import os
import json
import sys
import time
from typing import Optional

import requests
from openai import OpenAI, RateLimitError

# ── Config (from environment variables) ──────────────────────────────────────
API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.groq.com/openai/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "llama-3.3-70b-versatile")
HF_TOKEN = os.environ.get("HF_TOKEN", "")
# Trailing slash stripped so f"{ENV_URL}/reset" never yields a double slash.
ENV_URL = os.environ.get("ENV_URL", "http://localhost:7860").rstrip("/")

# Hard wall-clock budget for the whole run (hackathon rule: < 20 minutes).
TIME_LIMIT_S = 1200
# No new step is started once less than this much time remains before the
# deadline. Best-effort headroom for one LLM call (the OpenAI client also
# retries internally) plus one /step request.
STEP_HEADROOM_S = 240
# Maximum number of submit/feedback rounds per episode.
MAX_STEPS = 5

client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN or "dummy")

# ── System prompt ─────────────────────────────────────────────────────────────
SYSTEM_PROMPT = """You are a Python security engineer writing production-ready, secure Python code.

When given a task, write ONLY the Python function — no explanations, no markdown fences, no comments outside the function.

Your code MUST:
1. Solve the problem correctly — handle None, empty string, boundary values
2. Resist security attacks: SQL injection, path traversal, auth bypass, XSS
3. Use PARAMETERISED queries — NEVER string-format user input into SQL
4. Validate and sanitise ALL inputs before use
5. Use proper type hints on all function signatures
6. Have a docstring explaining what the function does
7. Use try/except with specific exception types (not bare except)
8. Follow the naming and error-handling conventions shown in CODEBASE CONTEXT
9. Import only well-known standard library or PyPI packages

CRITICAL SECURITY RULES:
- SQL: always use cursor.execute(sql, (param,)) — never f-strings or % formatting
- Paths: always use Path.resolve() and confirm the result lies inside the resolved safe base directory (Path.is_relative_to or os.path.commonpath) — never a plain string-prefix check, which lets "/safe_evil" pass for "/safe"
- JWT: always specify algorithms=["HS256"] explicitly
- Auth: always use hmac.compare_digest() for constant-time comparison
- Hashing: use SHA-256 or stronger — never MD5/SHA1; for passwords use a salted slow KDF (hashlib.pbkdf2_hmac or hashlib.scrypt)
- Never use eval(), exec(), or subprocess with shell=True
"""


def compress_graph(graph: dict, limit: int = 6000) -> str:
    """
    Semantic compression: keep signatures and conventions, drop function bodies.

    V1 used [:2000] blind truncation — agents couldn't see the patterns they needed.
    V2 keeps what matters, drops what doesn't.

    Args:
        graph: Code-graph dict from the environment. ``None`` and missing or
            ``None`` nested fields are tolerated. Non-dict components are skipped.
        limit: Maximum length of the returned string.

    Returns:
        Pretty-printed JSON with the top-level conventions and, per component,
        its file, language, up to 20 function names, up to 15 top-level import
        roots, and two convention flags. If that exceeds ``limit``, the
        per-component imports are dropped and the text is hard-truncated to
        ``limit`` characters. The truncated text may no longer be valid JSON.
    """
    graph = graph or {}
    slim = {
        "conventions": graph.get("conventions", {}),
        "components": {},
    }
    for name, comp in (graph.get("components") or {}).items():
        if not isinstance(comp, dict):
            continue
        conventions = comp.get("conventions") or {}
        functions = [
            f.get("name", "") if isinstance(f, dict) else f
            for f in comp.get("functions") or []
        ]
        # Keep only the package root ("os.path" -> "os") to save space.
        imports = [str(i).split(".")[0] for i in comp.get("imports") or []]
        slim["components"][name] = {
            "file": comp.get("file", ""),
            "language": comp.get("language", "py"),
            "functions": functions[:20],
            "imports": imports[:15],
            "uses_try_catch": conventions.get("uses_try_catch", False),
            "uses_type_hints": conventions.get("uses_type_hints", False),
        }

    result = json.dumps(slim, indent=2)
    if len(result) > limit:
        # Imports are the least useful field; drop them before truncating.
        for component in slim["components"].values():
            component.pop("imports", None)
        result = json.dumps(slim, indent=2)[:limit]
    return result


def call_llm(messages: list, timeout_s: int = 60) -> str:
    """Call the LLM, retrying with exponential backoff on rate limits.

    Args:
        messages: Chat messages in OpenAI format.
        timeout_s: Per-request timeout in seconds, forwarded to the client.

    Returns:
        The stripped completion text. Returns "" if the model returned no
        content, or if every attempt was rate-limited.

    Raises:
        Exception: Any non-rate-limit error from the client is re-raised.
    """
    max_attempts = 3
    for attempt in range(max_attempts):
        try:
            resp = client.chat.completions.create(
                model=MODEL_NAME,
                messages=messages,
                max_tokens=1024,
                temperature=0.2,
                timeout=timeout_s,
            )
        except Exception as e:
            err_str = str(e).lower()
            rate_limited = (
                isinstance(e, RateLimitError)
                or "rate_limit" in err_str
                or "429" in err_str
            )
            if not rate_limited:
                raise
            if attempt == max_attempts - 1:
                # No attempt follows, so sleeping here would only waste budget.
                print(f" Rate limited. Giving up after {max_attempts} attempts.")
                break
            wait = 2 ** attempt
            print(f" Rate limited. Waiting {wait}s...")
            time.sleep(wait)
            continue

        # content can be None (e.g. refusals / empty choices on some providers).
        content = resp.choices[0].message.content if resp.choices else None
        return (content or "").strip()
    return ""


def strip_markdown(code: str) -> str:
    """Strip markdown code fences if the LLM added them.

    Returns the contents of the first fenced block, preferring a ```python
    fence, with any language tag on the fence line removed (```py,
    ```python3, ```Python, ...). An unclosed fence, e.g. output cut off by
    max_tokens, yields everything after the opening fence. Text without a
    fence is returned stripped.
    """
    if not code:
        return ""
    fence = "```python" if "```python" in code else "```"
    if fence not in code:
        return code.strip()

    body = code.split(fence, 1)[1]
    # Drop a language tag left on the fence line ("py", "python3", or the
    # "3" remaining after "```python").
    tag, newline, rest = body.partition("\n")
    tag = tag.strip()
    if newline and all(ch.isalnum() or ch in "+-_." for ch in tag):
        body = rest
    return body.split("```", 1)[0].strip()


def _format_feedback(feedback) -> str:
    """Render /step feedback as '- dim: text' lines; non-dict payloads are stringified."""
    if isinstance(feedback, dict):
        lines = [f"- {dim}: {fb}" for dim, fb in feedback.items()]
        return "\n".join(lines) if lines else "(no feedback)"
    return str(feedback) if feedback else "(no feedback)"


def _build_messages(
    problem: str,
    cwe_targets: list,
    context_str: str,
    starter_code: str,
    previous: Optional[tuple] = None,
) -> list:
    """Assemble the chat messages for one step.

    ``previous`` is ``(code, reward, feedback)`` from the last submission.
    When given, it is shown to the model so the next attempt can fix the
    reported issues.
    """
    parts = [
        f"Task: {problem}",
        f"Security targets: {cwe_targets}",
        f"CODEBASE CONTEXT (follow these conventions exactly):\n{context_str}",
        f"Starter code to build from:\n{starter_code}",
    ]
    if previous is not None:
        prev_code, prev_reward, prev_feedback = previous
        parts.append(f"Your previous attempt (reward={prev_reward:.4f}):\n{prev_code}")
        parts.append(
            "Grader feedback on that attempt — fix every issue it raises:\n"
            + _format_feedback(prev_feedback)
        )
    parts.append(
        "Write the complete, secure Python function now. "
        "Return ONLY the code, no markdown:"
    )
    return [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": "\n\n".join(parts)},
    ]


def run_episode(difficulty: str = "medium", deadline: Optional[float] = None) -> dict:
    """Run one full RL episode with up to MAX_STEPS improvement steps.

    Each step sends the task, the compressed code graph and, after the first
    step, the previous attempt plus its grader feedback to the LLM. It then
    submits the returned code to the environment's /step endpoint.

    Args:
        difficulty: Forwarded to the environment's /reset endpoint.
        deadline: Optional ``time.time()`` timestamp. No new step is started
            once fewer than STEP_HEADROOM_S seconds remain before it.

    Returns:
        dict with ``task`` (the task id, or "unknown" if reset failed),
        ``scores`` (reward per submitted step), ``final_score`` (last reward,
        or 0.0 if none) and ``improved`` (True when the last reward beats the
        first).
    """
    # Reset environment
    try:
        reset_resp = requests.post(
            f"{ENV_URL}/reset",
            json={"difficulty": difficulty},
            timeout=30,
        )
        reset_resp.raise_for_status()
        episode = reset_resp.json()
        # Read the required keys here so a malformed reply is reported
        # instead of aborting the whole run with an uncaught KeyError.
        sid = episode["session_id"]
        task_id = episode["task_id"]
        problem = episode["problem_statement"]
    except Exception as e:
        print(f" ERROR: Could not reset env: {e}")
        return {"task": "unknown", "scores": [], "final_score": 0.0, "improved": False}

    cwe_targets = episode.get("cwe_targets", [])
    starter_code = episode.get("starter_code") or "# Write your implementation here"
    codegraph = episode.get("codegraph") or {}
    scores_history = []
    previous = None  # (code, reward, feedback) of the last submitted attempt

    print(f"\n Task: {task_id} | CWEs: {cwe_targets}")

    for step_num in range(MAX_STEPS):
        if deadline is not None and time.time() > deadline - STEP_HEADROOM_S:
            print(f" Step {step_num+1}: skipped — time budget nearly exhausted")
            break

        messages = _build_messages(
            problem, cwe_targets, compress_graph(codegraph), starter_code, previous
        )

        try:
            code = call_llm(messages)
        except Exception as e:
            print(f" Step {step_num+1}: LLM error — {e}")
            break

        code = strip_markdown(code)
        if not code:
            print(f" Step {step_num+1}: Empty response from LLM")
            break

        try:
            step_resp = requests.post(
                f"{ENV_URL}/step",
                json={
                    "session_id": sid,
                    "task_id": task_id,
                    "filename": f"solution_step{step_num}.py",
                    "code": code,
                },
                timeout=60,
            )
            step_resp.raise_for_status()
            result = step_resp.json()
            reward = float(result.get("total_reward") or 0.0)
        except Exception as e:
            print(f" Step {step_num+1}: Submit error — {e}")
            break

        scores_history.append(reward)
        done = bool(result.get("done", False))
        feedback = result.get("feedback") or {}
        print(f" Step {step_num+1}: reward={reward:.4f} done={done}")
        if isinstance(feedback, dict):
            for dim, fb in feedback.items():
                print(f" {dim}: {fb}")
        else:
            print(f" feedback: {feedback}")

        # Carry the refreshed code graph forward; keep the old one if the
        # environment did not send a new graph.
        codegraph = result.get("codegraph") or codegraph
        previous = (code, reward, feedback)

        if done:
            break

    final = scores_history[-1] if scores_history else 0.0
    improved = len(scores_history) > 1 and scores_history[-1] > scores_history[0]
    return {
        "task": task_id,
        "scores": scores_history,
        "final_score": final,
        "improved": improved,
    }


def main() -> None:
    """Run one episode per difficulty, print a summary, and enforce the time limit."""
    start = time.time()
    deadline = start + TIME_LIMIT_S
    results = []

    print("=" * 60)
    print("SecureCodeEnv V2 — Baseline Inference")
    print(f"Model: {MODEL_NAME}")
    print(f"Env: {ENV_URL}")
    print("=" * 60)

    for difficulty in ("easy", "medium", "hard"):
        print(f"\n{'='*20} {difficulty.upper()} {'='*20}")
        results.append(run_episode(difficulty, deadline=deadline))

    elapsed = time.time() - start

    print("\n" + "=" * 60)
    print("FINAL RESULTS")
    print("=" * 60)
    for r in results:
        improved_str = "↑ improved" if r["improved"] else "→ flat"
        print(f" {r['task']}: {r['final_score']:.4f} [{improved_str}] steps={r['scores']}")

    avg = sum(r["final_score"] for r in results) / len(results) if results else 0.0
    print(f"\nMean final reward: {avg:.4f}")
    print(f"Total time: {elapsed:.1f}s")

    # Hackathon requirement: must complete in < 20 minutes. Checked explicitly
    # (not with assert) so the check survives `python -O`.
    if elapsed >= TIME_LIMIT_S:
        sys.exit(f"Exceeded 20-minute time limit ({elapsed:.1f}s)")
    print("\n✅ Completed within time limit.")


if __name__ == "__main__":
    main()