Rithwik Ravi committed on
Commit
005d862
·
1 Parent(s): f421da5

fix: anchor env/ in gitignore to prevent excluding src/env package

Browse files
.gitignore CHANGED
@@ -1,7 +1,7 @@
1
  # Environments
2
  .venv/
3
  venv/
4
- env/
5
 
6
  # Python Cache
7
  __pycache__/
 
1
  # Environments
2
  .venv/
3
  venv/
4
+ /env/
5
 
6
  # Python Cache
7
  __pycache__/
src/env/__init__.py ADDED
File without changes
src/env/guardrail.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+ import math
4
+ from typing import List, Any
5
+ from collections import Counter
6
+ from pydantic import ValidationError
7
+ from src.env.models import GuardrailGraph, Action, Observation, extract_and_clean_json
8
+
9
def calculate_entropy(text: str) -> float:
    """Return the Shannon entropy of *text* in bits per character.

    Every distinct character is weighted by its relative frequency in
    *text*.

    Args:
        text: The string to measure.

    Returns:
        ``0.0`` for an empty string, otherwise the entropy as a
        non-negative float.
    """
    if not text:
        return 0.0
    total = len(text)
    # Negate each term instead of the finished sum: a single-symbol string
    # would otherwise be reported as -0.0 rather than 0.0.
    return sum(
        -(count / total) * math.log2(count / total)
        for count in Counter(text).values()
    )
14
+
15
def _evaluate_filter(node: dict, text: str) -> bool:
    """Evaluate one leaf filter (a dict carrying "filter_type") against *text*."""
    filter_type = node["filter_type"]
    val = node.get("value")
    if val is None:
        # A filter without a value can never match.  Without this guard the
        # value would be stringified and "none"/"None" searched for in text.
        return False

    if filter_type == "substring":
        # Case-insensitive containment.
        return str(val).lower() in text.lower()
    if filter_type == "regex_pattern":
        # NOTE(review): the pattern is caller-supplied; if it can come from
        # untrusted input, a pathological pattern may backtrack for a very
        # long time.  re.search is kept as-is here.
        try:
            return bool(re.search(str(val), text))
        except re.error:
            # An uncompilable pattern simply does not match.
            return False
    if filter_type == "length_limit":
        # Non-numeric limits never trigger.
        return isinstance(val, (int, float)) and len(text) > val
    if filter_type == "entropy_threshold":
        return isinstance(val, (int, float)) and calculate_entropy(text) > val
    if filter_type == "keyword_match":
        # Case-insensitive match against whitespace-separated tokens.
        return str(val).lower() in text.lower().split()
    return False


def evaluate_node(node: dict, text: str) -> bool:
    """Recursively evaluate a guardrail AST node against *text*.

    A node holding a "filter_type" key is treated as a leaf filter; any
    other node is treated as a logic node with "operator" and "children".

    Args:
        node: Plain-dict form of the node.
        text: The sample being checked.

    Returns:
        True when the node matches *text*.  Logic nodes without children,
        unknown operators, unknown filter types and filters without a value
        all evaluate to False.  "NOT" negates its first child only.
    """
    if "filter_type" in node:
        return _evaluate_filter(node, text)

    children = node.get("children", [])
    if not children:
        return False

    operator = node.get("operator")
    if operator == "AND":
        return all(evaluate_node(child, text) for child in children)
    if operator == "OR":
        return any(evaluate_node(child, text) for child in children)
    if operator == "NOT":
        return not evaluate_node(children[0], text)
    return False
48
+
49
try:
    from openenv import BaseEnvironment
except ImportError:
    # openenv is optional: fall back to an empty base class so the
    # environment class defined in this module can still be created.
    class BaseEnvironment:
        """Empty stand-in used when ``openenv`` cannot be imported."""
54
+
55
class GuardrailEnvironment(BaseEnvironment):
    """Scores a candidate guardrail AST against labelled text samples.

    ``reset`` stores the adversarial and benign samples; ``step`` parses the
    JSON AST carried by an action and reports the share of each sample set
    that the AST flags.
    """

    def __init__(self):
        super().__init__()
        # Current Observation; stays None until reset() is called.
        self.state = None

    def reset(self, adversarial_samples: List[str], benign_samples: List[str]) -> Observation:
        """Store the sample sets used by later ``step`` calls.

        Args:
            adversarial_samples: Texts the guardrail is expected to flag.
            benign_samples: Texts the guardrail is expected to let through.

        Returns:
            The new Observation, which is also kept on ``self.state``.
        """
        self.state = Observation(
            adversarial_samples=adversarial_samples,
            benign_samples=benign_samples,
        )
        return self.state

    def step(self, action: Action) -> tuple[float, float, bool]:
        """
        Returns (recall, fpr, syntax_error)

        recall is the fraction of adversarial samples the AST flags and fpr
        the fraction of benign samples it flags; either is 0.0 when its
        sample list is empty.  When ``action.ast_json`` cannot be decoded as
        JSON or does not validate as a GuardrailGraph, the result is
        ``(0.0, 0.0, True)``.

        Raises:
            RuntimeError: if a valid AST is submitted before ``reset``.
        """
        try:
            clean_json = extract_and_clean_json(action.ast_json)
            parsed_ast = json.loads(clean_json)
            ast_wrapper = GuardrailGraph.model_validate(parsed_ast)
        except (json.JSONDecodeError, ValidationError):
            return 0.0, 0.0, True

        if self.state is None:
            # Fail with a clear message instead of an AttributeError on None.
            raise RuntimeError("step() called before reset(); no samples are loaded")

        # We evaluate against raw dict to avoid recursive pydantic object overhead
        root_node = ast_wrapper.model_dump().get("root", {})

        adversarial = self.state.adversarial_samples
        benign = self.state.benign_samples

        # Recall: true-positive rate on the adversarial samples.
        true_positives = sum(1 for text in adversarial if evaluate_node(root_node, text))
        # FPR: false-positive rate on the benign samples.
        false_positives = sum(1 for text in benign if evaluate_node(root_node, text))

        recall = true_positives / len(adversarial) if adversarial else 0.0
        fpr = false_positives / len(benign) if benign else 0.0

        return recall, fpr, False
src/env/models.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Union, Literal, Optional
2
+ from pydantic import BaseModel, ConfigDict
3
+ import json
4
+ import re
5
+
6
class SemanticFilter(BaseModel):
    """Leaf node of a guardrail graph: one named check plus its argument."""

    # Name of the check to run; restricted to the listed literals.
    filter_type: Literal[
        "substring",
        "regex_pattern",
        "length_limit",
        "entropy_threshold",
        "keyword_match",
    ]
    # Argument of the check: text or a number.
    value: Union[str, int, float]
9
+
10
class LogicNode(BaseModel):
    """Interior node of a guardrail graph: a boolean operator over child nodes."""

    # Boolean operator applied to the children.
    operator: Literal["AND", "OR", "NOT"]
    # Children are either nested LogicNodes (forward reference) or leaf filters.
    children: List[Union["LogicNode", SemanticFilter]]
13
+
14
class GuardrailGraph(BaseModel):
    """Top-level guardrail document: identifying metadata plus the root logic node."""

    # Identifier of the graph.
    graph_id: str
    # Free-text description of the graph.
    description: str
    # Root of the AST; must be a LogicNode, not a bare filter.
    root: LogicNode
18
+
19
# Rebuild the model so the "LogicNode" string forward reference in
# LogicNode.children is resolved now that the class exists.
LogicNode.model_rebuild()
20
+
21
class Observation(BaseModel):
    """The two labelled sample sets a guardrail is evaluated against."""

    # Adversarial texts.
    adversarial_samples: List[str]
    # Benign texts.
    benign_samples: List[str]
24
+
25
class Action(BaseModel):
    """A submitted guardrail, carried as JSON text."""

    # The model outputs a JSON string representing the GuardrailGraph
    ast_json: str
    # Optional second JSON string; defaults to None when not supplied.
    baseline_ast_json: Optional[str] = None
28
+
29
class StepResult(BaseModel):
    """Bundle of observation, reward, done flag and info dict for one step."""

    # Observation attached to this result.
    observation: Observation
    # Scalar reward.
    reward: float
    # Completion flag.
    done: bool
    # Free-form extra data.
    info: dict
34
+
35
def extract_and_clean_json(text: str) -> str:
    """Extract JSON text from an optional Markdown fence and drop trailing commas.

    If *text* contains a ``` or ```json fenced block, only the content of the
    first such block is kept.  Commas that directly precede a closing ``}``
    or ``]`` (ignoring whitespace) are then removed.  Commas inside
    double-quoted string literals are left alone, so string values such as
    regex patterns are not altered.

    Args:
        text: Raw text that should contain a JSON document.

    Returns:
        The cleaned text, stripped of surrounding whitespace.  The result is
        not validated as JSON.
    """
    text = text.strip()
    fenced = re.search(r"```(?:json)?\s*(.*?)\s*```", text, re.DOTALL)
    if fenced:
        text = fenced.group(1)

    def _drop_trailing_comma(match: "re.Match") -> str:
        closer = match.group(1)
        # group(1) is only set for the trailing-comma alternative; a matched
        # string literal is returned unchanged.
        return closer if closer is not None else match.group(0)

    # First alternative consumes a whole string literal (honouring backslash
    # escapes) so that its contents are skipped; the second matches a comma
    # followed only by whitespace and a closing brace/bracket.
    text = re.sub(r'"(?:\\.|[^"\\])*"|,\s*([}\]])', _drop_trailing_comma, text)

    return text.strip()
src/env/reward.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
class LogBarrierReward:
    """Reward that trades recall off against a logarithmic false-positive penalty.

    With the default settings the reward is
    ``1.0 * recall - 2.0 * log1p(fpr)``, and ``-10.0`` on a syntax error.
    The weights and the penalty can be overridden per instance.
    """

    # Class-level defaults, so the values exist even on an instance that was
    # created without running __init__.
    recall_weight: float = 1.0
    fpr_weight: float = 2.0
    syntax_error_penalty: float = -10.0

    def __init__(
        self,
        *,
        recall_weight: float = 1.0,
        fpr_weight: float = 2.0,
        syntax_error_penalty: float = -10.0,
    ) -> None:
        """Configure the reward.

        Args:
            recall_weight: Multiplier applied to recall.
            fpr_weight: Multiplier applied to ``log1p(fpr)``.
            syntax_error_penalty: Reward returned when ``syntax_error`` is set.
        """
        self.recall_weight = recall_weight
        self.fpr_weight = fpr_weight
        self.syntax_error_penalty = syntax_error_penalty

    def calculate(self, recall: float, fpr: float, syntax_error: bool = False) -> float:
        """
        Recall = True Positive Rate on adversarial
        FPR = False Positive Rate on benign
        Reward = (recall_weight * Recall) - (fpr_weight * math.log1p(FPR))

        When ``syntax_error`` is true the penalty is returned and recall/fpr
        are ignored.
        """
        if syntax_error:
            return float(self.syntax_error_penalty)

        reward = (self.recall_weight * recall) - (self.fpr_weight * math.log1p(fpr))
        return float(reward)