Spaces:
Sleeping
Sleeping
| """ | |
| inference.py — Baseline inference script (REQUIRED by hackathon). | |
| CRITICAL requirements: | |
| - Must use the OpenAI client (hackathon rule — Groq and Gemini both expose OpenAI-compatible endpoints) | |
| - Must complete in < 20 minutes on 2 vCPU / 8GB RAM | |
| - Must be in project root | |
| - env vars: API_BASE_URL, MODEL_NAME, HF_TOKEN, ENV_URL | |
| Compatible with: | |
| - Groq free tier: API_BASE_URL=https://api.groq.com/openai/v1, MODEL_NAME=llama-3.3-70b-versatile | |
| - Gemini Flash: API_BASE_URL=https://generativelanguage.googleapis.com/v1beta/openai, MODEL_NAME=gemini-2.5-flash | |
| - OpenAI: API_BASE_URL=https://api.openai.com/v1, MODEL_NAME=gpt-4o-mini | |
| """ | |
| import os | |
| import json | |
| import time | |
| import requests | |
| from openai import OpenAI | |
# -- Config (all values come from environment variables) ---------------------
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "llama-3.3-70b-versatile")
HF_TOKEN = os.getenv("HF_TOKEN", "")
ENV_URL = os.getenv("ENV_URL", "http://localhost:7860")

# A placeholder key is used whenever HF_TOKEN is unset or empty.
client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN or "dummy")
# -- System prompt -------------------------------------------------------------
# Sent verbatim as the system message of every chat request. The em dashes
# below were previously mis-encoded (shown as a Greek beta) and reached the
# model as garbage characters.
SYSTEM_PROMPT = """You are a Python security engineer writing production-ready, secure Python code.
When given a task, write ONLY the Python function — no explanations, no markdown fences, no comments outside the function.
Your code MUST:
1. Solve the problem correctly — handle None, empty string, boundary values
2. Resist security attacks: SQL injection, path traversal, auth bypass, XSS
3. Use PARAMETERISED queries — NEVER string-format user input into SQL
4. Validate and sanitise ALL inputs before use
5. Use proper type hints on all function signatures
6. Have a docstring explaining what the function does
7. Use try/except with specific exception types (not bare except)
8. Follow the naming and error-handling conventions shown in CODEBASE CONTEXT
9. Import only well-known standard library or PyPI packages
CRITICAL SECURITY RULES:
- SQL: always use cursor.execute(sql, (param,)) — never f-strings or % formatting
- Paths: always use Path.resolve() and check prefix against safe base directory
- JWT: always specify algorithms=["HS256"] explicitly
- Auth: always use hmac.compare_digest() for constant-time comparison
- Hashing: use SHA-256 or stronger — never MD5/SHA1
- Never use eval(), exec(), or subprocess with shell=True
"""
def _as_dict(value: object) -> dict:
    """Return *value* if it is a dict, otherwise an empty dict."""
    return value if isinstance(value, dict) else {}


def compress_graph(graph: dict, limit: int = 6000) -> str:
    """
    Semantic compression: keep signatures and conventions, drop function bodies.

    V1 used [:2000] blind truncation — agents couldn't see the patterns they needed.
    V2 keeps what matters, drops what doesn't.

    Args:
        graph: Code-graph dict with optional ``conventions`` and ``components``
            keys. ``None`` or any non-dict value is treated as an empty graph;
            non-dict component entries are skipped.
        limit: Maximum length, in characters, of the returned string.

    Returns:
        A JSON string (``indent=2``) with top-level ``conventions`` and
        ``components`` keys. Each component keeps its file, language, up to 20
        function names, up to 15 top-level import roots and two convention
        flags. When the JSON exceeds ``limit``, imports are dropped first, then
        whole components from the end, so the result stays parseable JSON. A
        hard character cut is used only if even the bare conventions do not fit.
    """
    graph = _as_dict(graph)
    slim = {
        "conventions": graph.get("conventions", {}),
        "components": {},
    }
    for name, comp in _as_dict(graph.get("components")).items():
        if not isinstance(comp, dict):
            continue  # malformed entry: nothing to summarise
        comp_conventions = _as_dict(comp.get("conventions"))
        functions = list(comp.get("functions") or [])[:20]
        imports = list(comp.get("imports") or [])[:15]
        slim["components"][name] = {
            "file": comp.get("file", ""),
            "language": comp.get("language", "py"),
            # Entries may be dicts ({"name": ...}) or bare names; keep names only.
            "functions": [f.get("name", "") if isinstance(f, dict) else f for f in functions],
            # Keep only the top-level package ("os.path" -> "os").
            "imports": [str(i).split(".")[0] for i in imports],
            "uses_try_catch": comp_conventions.get("uses_try_catch", False),
            "uses_type_hints": comp_conventions.get("uses_type_hints", False),
        }

    result = json.dumps(slim, indent=2)
    if len(result) <= limit:
        return result

    # Over budget: imports are the least informative field, drop them first.
    for entry in slim["components"].values():
        entry.pop("imports", None)
    result = json.dumps(slim, indent=2)

    # Still over budget: drop whole components from the end instead of cutting
    # mid-document, so the agent always receives valid JSON.
    names = list(slim["components"])
    while len(result) > limit and names:
        del slim["components"][names.pop()]
        result = json.dumps(slim, indent=2)

    # Last resort: the conventions alone exceed the budget.
    return result[:max(limit, 0)]
def call_llm(messages: list, timeout_s: int = 60) -> str:
    """Call the LLM, retrying with exponential backoff when rate limited.

    Args:
        messages: Chat messages in OpenAI format (list of role/content dicts).
        timeout_s: Per-request timeout in seconds, forwarded to the client so a
            stalled request cannot eat into the overall time budget.

    Returns:
        The stripped text of the first choice. Returns ``""`` if the model sent
        no choices or no content, or if every attempt was rate limited.

    Raises:
        Exception: any error that is not a rate limit is re-raised unchanged.
    """
    attempts = 3
    for attempt in range(attempts):
        try:
            resp = client.chat.completions.create(
                model=MODEL_NAME,
                messages=messages,
                max_tokens=1024,
                temperature=0.2,
                timeout=timeout_s,
            )
        except Exception as e:
            err_str = str(e).lower()
            rate_limited = (
                getattr(e, "status_code", None) == 429
                or "rate_limit" in err_str
                or "rate limit" in err_str
                or "429" in err_str
            )
            if not rate_limited:
                raise
            if attempt == attempts - 1:
                break  # no retry follows, so sleeping would only waste time
            wait = 2 ** attempt
            print(f" Rate limited. Waiting {wait}s...")
            time.sleep(wait)
            continue
        # Some providers return no choices, or content=None (e.g. a filtered answer).
        if not resp.choices:
            return ""
        content = resp.choices[0].message.content
        return (content or "").strip()
    return ""
def _is_fence_tag(text: str) -> bool:
    """Return True if *text* looks like a code-fence info string (``py``, ``python3``, or empty)."""
    return all(ch.isalnum() or ch in "_+.-" for ch in text)


def strip_markdown(code: str) -> str:
    """Strip markdown code fences if the LLM added them.

    Prefers the first fence tagged as Python (```` ```python ````, ```` ```py ````,
    ```` ```python3 ````). Otherwise it uses the first fence in the reply. The
    info-string line after the opening fence is removed whatever the language
    tag is. An unclosed fence extends to the end of the text. Text without any
    fence is returned stripped, and ``None``/empty input yields ``""``.
    """
    if not code:
        return ""
    fence = "```"
    start = code.find(fence)
    if start == -1:
        return code.strip()

    # Look for a Python-tagged fence; fall back to the first fence found.
    probe = start
    while probe != -1:
        info = code[probe + len(fence):].partition("\n")[0].split(fence, 1)[0].strip()
        if info.lower().startswith("py") and _is_fence_tag(info):
            start = probe
            break
        probe = code.find(fence, probe + len(fence))

    # Content up to the closing fence; an unclosed fence runs to the end.
    inner = code[start + len(fence):].split(fence, 1)[0]
    first_line, newline, rest = inner.partition("\n")
    tag = first_line.strip()
    # Drop the info-string line, but keep real code that sits on the opening
    # line itself (e.g. ```x = 1```).
    if _is_fence_tag(tag) and (newline or tag.lower().startswith("py")):
        inner = rest
    return inner.strip()
def _format_feedback(feedback: object, max_chars: int = 2000) -> str:
    """Render grader feedback as prompt text, capped at *max_chars* characters."""
    if isinstance(feedback, dict):
        text = "\n".join(f"- {dim}: {fb}" for dim, fb in feedback.items())
    else:
        text = "" if feedback is None else str(feedback)
    return text[:max_chars]


def _build_messages(episode: dict, context_str: str, last_code: str = "", last_feedback: str = "") -> list:
    """Build the system + user chat messages for one step of an episode.

    The user prompt always contains the task, CWE targets, compressed code
    graph and starter code. When ``last_code`` is non-empty, the previous
    submission and the grader feedback on it are appended as well. This lets
    the model revise its answer instead of re-sampling from scratch.
    """
    starter = episode.get("starter_code") or "# Write your implementation here"
    sections = [
        f"Task: {episode.get('problem_statement', '')}",
        f"Security targets: {episode.get('cwe_targets', [])}",
        "CODEBASE CONTEXT (follow these conventions exactly):",
        context_str,
        "Starter code to build from:",
        str(starter),
    ]
    if last_code:
        sections += [
            "Your previous attempt:",
            last_code,
            "Grader feedback on that attempt (fix every issue it raises):",
            last_feedback or "(no feedback returned)",
        ]
    sections.append("Write the complete, secure Python function now. Return ONLY the code, no markdown:")
    return [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": "\n".join(sections)},
    ]


def run_episode(difficulty: str = "medium", max_steps: int = 5) -> dict:
    """Run one full RL episode with up to ``max_steps`` improvement steps.

    The episode proceeds as follows:
    1. Reset the environment via ``POST {ENV_URL}/reset``.
    2. Ask the LLM for a solution and submit it to ``POST {ENV_URL}/step``.
    3. Feed the submitted code and the grader's feedback into the next prompt.

    Steps 2-3 repeat until the env reports ``done``, ``max_steps`` rounds have
    run, or any request fails or returns an empty answer.

    Args:
        difficulty: Difficulty level sent to ``/reset``.
        max_steps: Maximum number of generate/submit rounds (default 5).

    Returns:
        A dict ``{"task", "scores", "final_score", "improved"}``. ``scores``
        holds the reward of every submitted step, ``final_score`` is the last
        reward (0.0 if none), and ``improved`` means the last reward beat the
        first one. If the reset fails, ``task`` is ``"unknown"`` and
        ``scores`` is empty.
    """
    try:
        reset_resp = requests.post(
            f"{ENV_URL}/reset",
            json={"difficulty": difficulty},
            timeout=30,
        )
        reset_resp.raise_for_status()
        episode = reset_resp.json()
        # Read required keys here so a malformed payload is reported, not fatal.
        sid = episode["session_id"]
        task_id = episode["task_id"]
    except Exception as e:
        print(f" ERROR: Could not reset env: {e}")
        return {"task": "unknown", "scores": [], "final_score": 0.0, "improved": False}

    scores_history = []
    last_code = ""
    last_feedback = ""
    print(f"\n Task: {task_id} | CWEs: {episode.get('cwe_targets', [])}")

    for step_num in range(max_steps):
        context_str = compress_graph(episode.get("codegraph", {}))
        messages = _build_messages(episode, context_str, last_code, last_feedback)
        try:
            code = call_llm(messages)
        except Exception as e:
            print(f" Step {step_num+1}: LLM error - {e}")
            break
        code = strip_markdown(code)
        if not code.strip():
            print(f" Step {step_num+1}: Empty response from LLM")
            break

        try:
            step_resp = requests.post(
                f"{ENV_URL}/step",
                json={
                    "session_id": sid,
                    "task_id": task_id,
                    "filename": f"solution_step{step_num}.py",
                    "code": code,
                },
                timeout=60,
            )
            step_resp.raise_for_status()
            result = step_resp.json()
            if not isinstance(result, dict):
                raise ValueError(f"unexpected /step payload type {type(result).__name__}")
        except Exception as e:
            print(f" Step {step_num+1}: Submit error - {e}")
            break

        try:
            reward = float(result.get("total_reward", 0.0) or 0.0)
        except (TypeError, ValueError):
            reward = 0.0
        scores_history.append(reward)
        done = result.get("done", False)
        print(f" Step {step_num+1}: reward={reward:.4f} done={done}")
        feedback = result.get("feedback") or {}
        if isinstance(feedback, dict):
            for dim, fb in feedback.items():
                print(f" {dim}: {fb}")

        # Carry this attempt and its feedback into the next prompt.
        last_code = code
        last_feedback = _format_feedback(feedback)
        # Update context for next step; keep the old graph if none was returned.
        episode["codegraph"] = result.get("codegraph") or episode.get("codegraph", {})
        if done:
            break

    final = scores_history[-1] if scores_history else 0.0
    improved = len(scores_history) > 1 and scores_history[-1] > scores_history[0]
    return {
        "task": task_id,
        "scores": scores_history,
        "final_score": final,
        "improved": improved,
    }
if __name__ == "__main__":
    # Hackathon requirement: the whole run must complete in under 20 minutes.
    TIME_LIMIT_S = 1200
    # monotonic() is immune to wall-clock adjustments during the run.
    start = time.monotonic()
    results = []
    print("=" * 60)
    # Status text is ASCII-only: the original non-ASCII glyphs were mis-encoded
    # and can also fail to print on non-UTF-8 consoles.
    print("SecureCodeEnv V2 - Baseline Inference")
    print(f"Model: {MODEL_NAME}")
    print(f"Env: {ENV_URL}")
    print("=" * 60)
    for difficulty in ["easy", "medium", "hard"]:
        print(f"\n{'='*20} {difficulty.upper()} {'='*20}")
        try:
            r = run_episode(difficulty)
        except Exception as e:  # top-level boundary: keep the other difficulties running
            print(f" ERROR: episode crashed: {e}")
            r = {"task": f"unknown ({difficulty})", "scores": [], "final_score": 0.0, "improved": False}
        results.append(r)
    elapsed = time.monotonic() - start

    print("\n" + "=" * 60)
    print("FINAL RESULTS")
    print("=" * 60)
    for r in results:
        improved_str = "+ improved" if r["improved"] else "= flat"
        print(f" {r['task']}: {r['final_score']:.4f} [{improved_str}] steps={r['scores']}")
    avg = sum(r["final_score"] for r in results) / len(results) if results else 0
    print(f"\nMean final reward: {avg:.4f}")
    print(f"Total time: {elapsed:.1f}s")

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    # SystemExit with a message prints it to stderr and exits with status 1.
    if elapsed >= TIME_LIMIT_S:
        raise SystemExit(f"Exceeded 20-minute time limit ({elapsed:.1f}s)")
    print("\nCompleted within time limit.")