Rithwik Ravi commited on
Commit ·
005d862
1
Parent(s): f421da5
fix: anchor env/ in gitignore to prevent excluding src/env package
Browse files- .gitignore +1 -1
- src/env/__init__.py +0 -0
- src/env/guardrail.py +99 -0
- src/env/models.py +46 -0
- src/env/reward.py +14 -0
.gitignore
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
# Environments
|
| 2 |
.venv/
|
| 3 |
venv/
|
| 4 |
-
env/
|
| 5 |
|
| 6 |
# Python Cache
|
| 7 |
__pycache__/
|
|
|
|
| 1 |
# Environments
|
| 2 |
.venv/
|
| 3 |
venv/
|
| 4 |
+
/env/
|
| 5 |
|
| 6 |
# Python Cache
|
| 7 |
__pycache__/
|
src/env/__init__.py
ADDED
|
File without changes
|
src/env/guardrail.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import re
|
| 3 |
+
import math
|
| 4 |
+
from typing import List, Any
|
| 5 |
+
from collections import Counter
|
| 6 |
+
from pydantic import ValidationError
|
| 7 |
+
from src.env.models import GuardrailGraph, Action, Observation, extract_and_clean_json
|
| 8 |
+
|
| 9 |
+
def calculate_entropy(text: str) -> float:
|
| 10 |
+
if not text:
|
| 11 |
+
return 0.0
|
| 12 |
+
probabilities = [n_x/len(text) for x, n_x in Counter(text).items()]
|
| 13 |
+
return -sum([p * math.log2(p) for p in probabilities])
|
| 14 |
+
|
| 15 |
+
def evaluate_node(node: dict, text: str) -> bool:
|
| 16 |
+
# Check if it's a SemanticFilter
|
| 17 |
+
if "filter_type" in node:
|
| 18 |
+
filter_type = node["filter_type"]
|
| 19 |
+
val = node.get("value")
|
| 20 |
+
if filter_type == "substring":
|
| 21 |
+
return str(val).lower() in text.lower()
|
| 22 |
+
elif filter_type == "regex_pattern":
|
| 23 |
+
try:
|
| 24 |
+
return bool(re.search(str(val), text))
|
| 25 |
+
except re.error:
|
| 26 |
+
return False
|
| 27 |
+
elif filter_type == "length_limit":
|
| 28 |
+
return len(text) > (val if isinstance(val, (int, float)) else float('inf'))
|
| 29 |
+
elif filter_type == "entropy_threshold":
|
| 30 |
+
return calculate_entropy(text) > (val if isinstance(val, (int, float)) else float('inf'))
|
| 31 |
+
elif filter_type == "keyword_match":
|
| 32 |
+
return str(val).lower() in text.lower().split()
|
| 33 |
+
return False
|
| 34 |
+
|
| 35 |
+
# Else it's a LogicNode
|
| 36 |
+
operator = node.get("operator")
|
| 37 |
+
children = node.get("children", [])
|
| 38 |
+
if operator == "AND":
|
| 39 |
+
if not children: return False
|
| 40 |
+
return all(evaluate_node(child, text) for child in children)
|
| 41 |
+
elif operator == "OR":
|
| 42 |
+
if not children: return False
|
| 43 |
+
return any(evaluate_node(child, text) for child in children)
|
| 44 |
+
elif operator == "NOT":
|
| 45 |
+
if not children: return False
|
| 46 |
+
return not evaluate_node(children[0], text)
|
| 47 |
+
return False
|
| 48 |
+
|
| 49 |
+
try:
|
| 50 |
+
from openenv import BaseEnvironment
|
| 51 |
+
except ImportError:
|
| 52 |
+
class BaseEnvironment:
|
| 53 |
+
pass
|
| 54 |
+
|
| 55 |
+
class GuardrailEnvironment(BaseEnvironment):
|
| 56 |
+
def __init__(self):
|
| 57 |
+
super().__init__()
|
| 58 |
+
self.state = None
|
| 59 |
+
|
| 60 |
+
def reset(self, adversarial_samples: List[str], benign_samples: List[str]) -> Observation:
|
| 61 |
+
self.state = Observation(
|
| 62 |
+
adversarial_samples=adversarial_samples,
|
| 63 |
+
benign_samples=benign_samples
|
| 64 |
+
)
|
| 65 |
+
return self.state
|
| 66 |
+
|
| 67 |
+
def step(self, action: Action) -> tuple[float, float, bool]:
|
| 68 |
+
"""
|
| 69 |
+
Returns (recall, fpr, syntax_error)
|
| 70 |
+
"""
|
| 71 |
+
try:
|
| 72 |
+
clean_json = extract_and_clean_json(action.ast_json)
|
| 73 |
+
parsed_ast = json.loads(clean_json)
|
| 74 |
+
ast_wrapper = GuardrailGraph.model_validate(parsed_ast)
|
| 75 |
+
except (json.JSONDecodeError, ValidationError):
|
| 76 |
+
return 0.0, 0.0, True
|
| 77 |
+
|
| 78 |
+
true_positives = 0
|
| 79 |
+
false_positives = 0
|
| 80 |
+
|
| 81 |
+
# We evaluate against raw dict to avoid recursive pydantic object overhead
|
| 82 |
+
root_node = ast_wrapper.model_dump().get("root", {})
|
| 83 |
+
|
| 84 |
+
# Evaluate Recall (TP rate on adversarial)
|
| 85 |
+
adv_total = len(self.state.adversarial_samples)
|
| 86 |
+
for text in self.state.adversarial_samples:
|
| 87 |
+
if evaluate_node(root_node, text):
|
| 88 |
+
true_positives += 1
|
| 89 |
+
|
| 90 |
+
# Evaluate FPR (FP rate on benign)
|
| 91 |
+
ben_total = len(self.state.benign_samples)
|
| 92 |
+
for text in self.state.benign_samples:
|
| 93 |
+
if evaluate_node(root_node, text):
|
| 94 |
+
false_positives += 1
|
| 95 |
+
|
| 96 |
+
recall = true_positives / adv_total if adv_total > 0 else 0.0
|
| 97 |
+
fpr = false_positives / ben_total if ben_total > 0 else 0.0
|
| 98 |
+
|
| 99 |
+
return recall, fpr, False
|
src/env/models.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Union, Literal, Optional
|
| 2 |
+
from pydantic import BaseModel, ConfigDict
|
| 3 |
+
import json
|
| 4 |
+
import re
|
| 5 |
+
|
| 6 |
+
class SemanticFilter(BaseModel):
|
| 7 |
+
filter_type: Literal["substring", "regex_pattern", "length_limit", "entropy_threshold", "keyword_match"]
|
| 8 |
+
value: Union[str, int, float]
|
| 9 |
+
|
| 10 |
+
class LogicNode(BaseModel):
|
| 11 |
+
operator: Literal["AND", "OR", "NOT"]
|
| 12 |
+
children: List[Union["LogicNode", SemanticFilter]]
|
| 13 |
+
|
| 14 |
+
class GuardrailGraph(BaseModel):
|
| 15 |
+
graph_id: str
|
| 16 |
+
description: str
|
| 17 |
+
root: LogicNode
|
| 18 |
+
|
| 19 |
+
LogicNode.model_rebuild()
|
| 20 |
+
|
| 21 |
+
class Observation(BaseModel):
|
| 22 |
+
adversarial_samples: List[str]
|
| 23 |
+
benign_samples: List[str]
|
| 24 |
+
|
| 25 |
+
class Action(BaseModel):
|
| 26 |
+
ast_json: str # The model outputs a JSON string representing the GuardrailGraph
|
| 27 |
+
baseline_ast_json: Optional[str] = None
|
| 28 |
+
|
| 29 |
+
class StepResult(BaseModel):
|
| 30 |
+
observation: Observation
|
| 31 |
+
reward: float
|
| 32 |
+
done: bool
|
| 33 |
+
info: dict
|
| 34 |
+
|
| 35 |
+
def extract_and_clean_json(text: str) -> str:
|
| 36 |
+
# Extract JSON blocks from markdown explicitly and strip trailing commas
|
| 37 |
+
text = text.strip()
|
| 38 |
+
match = re.search(r"```(?:json)?\s*(.*?)\s*```", text, re.DOTALL)
|
| 39 |
+
if match:
|
| 40 |
+
text = match.group(1)
|
| 41 |
+
|
| 42 |
+
# Replace trailing commas before closing braces/brackets
|
| 43 |
+
text = re.sub(r',\s*}', '}', text)
|
| 44 |
+
text = re.sub(r',\s*]', ']', text)
|
| 45 |
+
|
| 46 |
+
return text.strip()
|
src/env/reward.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
class LogBarrierReward:
|
| 4 |
+
def calculate(self, recall: float, fpr: float, syntax_error: bool = False) -> float:
|
| 5 |
+
"""
|
| 6 |
+
Recall = True Positive Rate on adversarial
|
| 7 |
+
FPR = False Positive Rate on benign
|
| 8 |
+
Reward = (1.0 * Recall) - (2.0 * math.log1p(FPR))
|
| 9 |
+
"""
|
| 10 |
+
if syntax_error:
|
| 11 |
+
return -10.0
|
| 12 |
+
|
| 13 |
+
reward = (1.0 * recall) - (2.0 * math.log1p(fpr))
|
| 14 |
+
return float(reward)
|