Spaces:
Sleeping
Sleeping
| import ast | |
| import random | |
| import string | |
| from typing import Optional | |
| from pydantic import BaseModel | |
| # 1. OpenEnv Typed Models | |
| class Observation(BaseModel): | |
| step: int | |
| vulnerability_type: str | |
| current_code: str | |
| linter_output: str | |
| class Action(BaseModel): | |
| action_type: str # e.g., "run_scan" or "submit_patch" | |
| patched_code: Optional[str] = "" | |
| class Reward(BaseModel): | |
| value: float | |
| class Info(BaseModel): | |
| error: Optional[str] = None | |
| # AST Checkers for Robust Grading | |
| def uses_os_getenv(code: str) -> bool: | |
| """ | |
| Returns True if the code uses any safe environment-variable access pattern: | |
| - os.getenv("KEY") -> ast.Call with Attribute attr='getenv' | |
| - os.environ.get("KEY") -> ast.Call with chained Attribute: environ -> get | |
| - os.environ["KEY"] -> ast.Subscript on os.environ | |
| - bare os.environ reference -> ast.Attribute with attr='environ' | |
| """ | |
| try: | |
| tree = ast.parse(code) | |
| for node in ast.walk(tree): | |
| # Pattern 1: os.getenv(...) or getenv(...) | |
| if isinstance(node, ast.Call): | |
| func = node.func | |
| if isinstance(func, ast.Attribute) and func.attr == "getenv": | |
| return True | |
| if isinstance(func, ast.Name) and func.id == "getenv": | |
| return True | |
| # Pattern 2: os.environ.get(...) or environ.get(...) | |
| if isinstance(func, ast.Attribute) and func.attr == "get": | |
| if isinstance(func.value, ast.Attribute) and func.value.attr == "environ": | |
| return True | |
| if isinstance(func.value, ast.Name) and func.value.id == "environ": | |
| return True | |
| # Pattern 3: os.environ["KEY"] or environ["KEY"] | |
| if isinstance(node, ast.Subscript): | |
| val = node.value | |
| if isinstance(val, ast.Attribute) and val.attr == "environ": | |
| return True | |
| if isinstance(val, ast.Name) and val.id == "environ": | |
| return True | |
| # Pattern 4: bare os.environ reference (e.g. env = os.environ) | |
| if isinstance(node, ast.Attribute) and node.attr == "environ": | |
| return True | |
| if isinstance(node, ast.Name) and node.id == "environ": | |
| return True | |
| except SyntaxError: | |
| pass | |
| return False | |
| def uses_safe_yaml(code: str) -> bool: | |
| """ | |
| Returns True if the agent securely parses YAML using safe_load. | |
| """ | |
| try: | |
| tree = ast.parse(code) | |
| for node in ast.walk(tree): | |
| if isinstance(node, ast.Call): | |
| func = node.func | |
| if isinstance(func, ast.Attribute) and func.attr == "safe_load": | |
| return True | |
| if isinstance(func, ast.Name) and func.id == "safe_load": | |
| return True | |
| except SyntaxError: | |
| pass | |
| return False | |
| def uses_safe_subprocess(code: str) -> bool: | |
| try: | |
| tree = ast.parse(code) | |
| for node in ast.walk(tree): | |
| if isinstance(node, ast.Call): | |
| if hasattr(node.func, "attr") and node.func.attr in ( | |
| "run", | |
| "Popen", | |
| "call", | |
| "check_call", | |
| "check_output", | |
| ): | |
| # Ensure shell=True is NOT passed | |
| is_shell = False | |
| for kw in node.keywords: | |
| if ( | |
| kw.arg == "shell" | |
| and isinstance(kw.value, ast.Constant) | |
| and kw.value.value is True | |
| ): | |
| is_shell = True | |
| if not is_shell: | |
| # Ensure first argument is a list or a variable reference | |
| if node.args: | |
| arg0 = node.args[0] | |
| if isinstance(arg0, (ast.List, ast.Name)): | |
| return True | |
| except SyntaxError: | |
| pass | |
| return False | |
| # 2. Environment Implementation | |
| class VulnPatchEnv: | |
| def __init__(self, task="easy"): | |
| self.task = task | |
| self.step_count = 0 | |
| self.done = False | |
| self.code = "" | |
| self.linter = "" | |
| self.target_secret = "" | |
| self.reset() | |
| def reset(self) -> Observation: | |
| random.seed(42) # Fixed seed for reproducible baseline scores | |
| self.step_count = 0 | |
| self.done = False | |
| self.linter = "Not run yet." | |
| if self.task == "easy": | |
| key_val = "sk-" + "".join( | |
| random.choices(string.ascii_uppercase + string.digits, k=16) | |
| ) | |
| # A real-looking build deployment script finding | |
| self.code = f"""import os | |
| import requests | |
| def upload_package_to_pypi(package_data): | |
| # Fallback legacy token (CVE-798 Supply Chain Leak) | |
| PYPI_API_TOKEN = '{key_val}' | |
| headers = {{"Authorization": f"Bearer {{PYPI_API_TOKEN}}"}} | |
| response = requests.post("https://upload.pypi.org/legacy/", data=package_data, headers=headers) | |
| return response.status_code | |
| """ | |
| self.target_secret = key_val | |
| elif self.task == "medium": | |
| # Realistic yaml config loader | |
| self.code = """import yaml | |
| import os | |
| def load_project_config(config_path): | |
| \"\"\"Reads the project.yaml configuration file for the build properties.\"\"\" | |
| if not os.path.exists(config_path): | |
| return {} | |
| with open(config_path, 'r') as f: | |
| # Vulnerable to arbitrary code execution (CVE-502) | |
| config_data = yaml.load(f.read()) | |
| return config_data | |
| """ | |
| self.target_secret = "" | |
| elif self.task == "hard": | |
| self.code = """import os | |
| import subprocess | |
| from setuptools import setup, find_packages | |
| # Example extracted from vulnerable setup.py / CI workflow | |
| def run_pre_build_validation(repo_url): | |
| print("Running arbitrary validations before build...") | |
| # Vulnerable to command injection (CWE-78) via unsanitized strings | |
| os.system(f"git clone {repo_url} /tmp/repo_check") | |
| setup( | |
| name="my_secure_package", | |
| version="1.0.0", | |
| packages=find_packages(), | |
| ) | |
| """ | |
| self.target_secret = "" | |
| else: | |
| self.code = "Unknown task." | |
| self.target_secret = "" | |
| return self.state() | |
| def state(self) -> Observation: | |
| return Observation( | |
| step=self.step_count, | |
| vulnerability_type=self.task, | |
| current_code=self.code, | |
| linter_output=self.linter, | |
| ) | |
| def step(self, action: Action) -> tuple[Observation, Reward, bool, Info]: | |
| self.step_count += 1 | |
| reward_val = 0.0 | |
| if action.action_type == "run_scan": | |
| self.linter = "SECURITY SCAN (Supply Chain): Vulnerability detected. Fix hardcoded secrets, insecure deserialization (yaml), or CI command injection." | |
| reward_val = 0.1 # Incremental progress signal | |
| elif action.action_type == "submit_patch": | |
| patched = action.patched_code if action.patched_code else "" | |
| # Hybrid AST/String Grading for robustness against formatting | |
| if self.task == "easy": | |
| if self.target_secret and self.target_secret not in patched: | |
| reward_val += 0.5 | |
| if uses_os_getenv(patched): | |
| reward_val += 0.5 | |
| elif self.task == "medium": | |
| if "yaml.load(" not in patched and "yaml.load (" not in patched: | |
| reward_val += 0.3 | |
| if uses_safe_yaml(patched): | |
| reward_val += 0.7 | |
| elif "yaml.safe_load" in patched or "safe_load" in patched: | |
| reward_val += 0.4 | |
| elif self.task == "hard": | |
| if "os.system" not in patched: | |
| reward_val += 0.3 | |
| if uses_safe_subprocess(patched): | |
| reward_val += 0.7 | |
| elif ( | |
| "subprocess" in patched and "[" in patched and "]" in patched | |
| ): # Fallback text format check | |
| reward_val += 0.4 | |
| self.done = True | |
| # Hard limit to prevent infinite loops (Penalize logic per OpenEnv spec requirement) | |
| if self.step_count >= 5 and not self.done: | |
| self.done = True | |
| reward_val -= 0.2 | |
| # Clamp reward strictly within open interval (0, 1) — 0.0 and 1.0 are not allowed | |
| reward_val = min(max(reward_val, 0.01), 0.99) | |
| return self.state(), Reward(value=reward_val), self.done, Info() | |
| def close(self) -> None: | |
| """No-op cleanup method required by the OpenEnv spec.""" | |
| pass | |