vuln-patch-env / environment.py
rohitc1612's picture
feat: pivot to supply chain security domain (PyPI token, yaml.load, CI injection)
2e8f87e
import ast
import random
import string
from typing import Optional
from pydantic import BaseModel
# 1. OpenEnv Typed Models
class Observation(BaseModel):
step: int
vulnerability_type: str
current_code: str
linter_output: str
class Action(BaseModel):
action_type: str # e.g., "run_scan" or "submit_patch"
patched_code: Optional[str] = ""
class Reward(BaseModel):
value: float
class Info(BaseModel):
error: Optional[str] = None
# AST Checkers for Robust Grading
def uses_os_getenv(code: str) -> bool:
"""
Returns True if the code uses any safe environment-variable access pattern:
- os.getenv("KEY") -> ast.Call with Attribute attr='getenv'
- os.environ.get("KEY") -> ast.Call with chained Attribute: environ -> get
- os.environ["KEY"] -> ast.Subscript on os.environ
- bare os.environ reference -> ast.Attribute with attr='environ'
"""
try:
tree = ast.parse(code)
for node in ast.walk(tree):
# Pattern 1: os.getenv(...) or getenv(...)
if isinstance(node, ast.Call):
func = node.func
if isinstance(func, ast.Attribute) and func.attr == "getenv":
return True
if isinstance(func, ast.Name) and func.id == "getenv":
return True
# Pattern 2: os.environ.get(...) or environ.get(...)
if isinstance(func, ast.Attribute) and func.attr == "get":
if isinstance(func.value, ast.Attribute) and func.value.attr == "environ":
return True
if isinstance(func.value, ast.Name) and func.value.id == "environ":
return True
# Pattern 3: os.environ["KEY"] or environ["KEY"]
if isinstance(node, ast.Subscript):
val = node.value
if isinstance(val, ast.Attribute) and val.attr == "environ":
return True
if isinstance(val, ast.Name) and val.id == "environ":
return True
# Pattern 4: bare os.environ reference (e.g. env = os.environ)
if isinstance(node, ast.Attribute) and node.attr == "environ":
return True
if isinstance(node, ast.Name) and node.id == "environ":
return True
except SyntaxError:
pass
return False
def uses_safe_yaml(code: str) -> bool:
"""
Returns True if the agent securely parses YAML using safe_load.
"""
try:
tree = ast.parse(code)
for node in ast.walk(tree):
if isinstance(node, ast.Call):
func = node.func
if isinstance(func, ast.Attribute) and func.attr == "safe_load":
return True
if isinstance(func, ast.Name) and func.id == "safe_load":
return True
except SyntaxError:
pass
return False
def uses_safe_subprocess(code: str) -> bool:
try:
tree = ast.parse(code)
for node in ast.walk(tree):
if isinstance(node, ast.Call):
if hasattr(node.func, "attr") and node.func.attr in (
"run",
"Popen",
"call",
"check_call",
"check_output",
):
# Ensure shell=True is NOT passed
is_shell = False
for kw in node.keywords:
if (
kw.arg == "shell"
and isinstance(kw.value, ast.Constant)
and kw.value.value is True
):
is_shell = True
if not is_shell:
# Ensure first argument is a list or a variable reference
if node.args:
arg0 = node.args[0]
if isinstance(arg0, (ast.List, ast.Name)):
return True
except SyntaxError:
pass
return False
# 2. Environment Implementation
class VulnPatchEnv:
def __init__(self, task="easy"):
self.task = task
self.step_count = 0
self.done = False
self.code = ""
self.linter = ""
self.target_secret = ""
self.reset()
def reset(self) -> Observation:
random.seed(42) # Fixed seed for reproducible baseline scores
self.step_count = 0
self.done = False
self.linter = "Not run yet."
if self.task == "easy":
key_val = "sk-" + "".join(
random.choices(string.ascii_uppercase + string.digits, k=16)
)
# A real-looking build deployment script finding
self.code = f"""import os
import requests
def upload_package_to_pypi(package_data):
# Fallback legacy token (CVE-798 Supply Chain Leak)
PYPI_API_TOKEN = '{key_val}'
headers = {{"Authorization": f"Bearer {{PYPI_API_TOKEN}}"}}
response = requests.post("https://upload.pypi.org/legacy/", data=package_data, headers=headers)
return response.status_code
"""
self.target_secret = key_val
elif self.task == "medium":
# Realistic yaml config loader
self.code = """import yaml
import os
def load_project_config(config_path):
\"\"\"Reads the project.yaml configuration file for the build properties.\"\"\"
if not os.path.exists(config_path):
return {}
with open(config_path, 'r') as f:
# Vulnerable to arbitrary code execution (CVE-502)
config_data = yaml.load(f.read())
return config_data
"""
self.target_secret = ""
elif self.task == "hard":
self.code = """import os
import subprocess
from setuptools import setup, find_packages
# Example extracted from vulnerable setup.py / CI workflow
def run_pre_build_validation(repo_url):
print("Running arbitrary validations before build...")
# Vulnerable to command injection (CWE-78) via unsanitized strings
os.system(f"git clone {repo_url} /tmp/repo_check")
setup(
name="my_secure_package",
version="1.0.0",
packages=find_packages(),
)
"""
self.target_secret = ""
else:
self.code = "Unknown task."
self.target_secret = ""
return self.state()
def state(self) -> Observation:
return Observation(
step=self.step_count,
vulnerability_type=self.task,
current_code=self.code,
linter_output=self.linter,
)
def step(self, action: Action) -> tuple[Observation, Reward, bool, Info]:
self.step_count += 1
reward_val = 0.0
if action.action_type == "run_scan":
self.linter = "SECURITY SCAN (Supply Chain): Vulnerability detected. Fix hardcoded secrets, insecure deserialization (yaml), or CI command injection."
reward_val = 0.1 # Incremental progress signal
elif action.action_type == "submit_patch":
patched = action.patched_code if action.patched_code else ""
# Hybrid AST/String Grading for robustness against formatting
if self.task == "easy":
if self.target_secret and self.target_secret not in patched:
reward_val += 0.5
if uses_os_getenv(patched):
reward_val += 0.5
elif self.task == "medium":
if "yaml.load(" not in patched and "yaml.load (" not in patched:
reward_val += 0.3
if uses_safe_yaml(patched):
reward_val += 0.7
elif "yaml.safe_load" in patched or "safe_load" in patched:
reward_val += 0.4
elif self.task == "hard":
if "os.system" not in patched:
reward_val += 0.3
if uses_safe_subprocess(patched):
reward_val += 0.7
elif (
"subprocess" in patched and "[" in patched and "]" in patched
): # Fallback text format check
reward_val += 0.4
self.done = True
# Hard limit to prevent infinite loops (Penalize logic per OpenEnv spec requirement)
if self.step_count >= 5 and not self.done:
self.done = True
reward_val -= 0.2
# Clamp reward strictly within open interval (0, 1) — 0.0 and 1.0 are not allowed
reward_val = min(max(reward_val, 0.01), 0.99)
return self.state(), Reward(value=reward_val), self.done, Info()
def close(self) -> None:
"""No-op cleanup method required by the OpenEnv spec."""
pass