# File size: 4,153 Bytes
import json
import re
import math
import time
from typing import List, Any
from collections import Counter
from pydantic import ValidationError
from src.env.models import GuardrailGraph, Action, Observation, extract_and_clean_json
def calculate_entropy(text: str) -> float:
    """Return the Shannon entropy of ``text`` in bits per character.

    Every distinct character contributes ``p * log2(p)``, where ``p`` is the
    character's relative frequency in ``text``; the result is the negated sum
    of those contributions. An empty string yields ``0.0``.
    """
    if not text:
        return 0.0

    total_chars = len(text)
    frequencies = Counter(text)

    weighted_logs = []
    for char in frequencies:
        share = frequencies[char] / total_chars
        weighted_logs.append(share * math.log2(share))

    return -sum(weighted_logs)
# Limits applied to a single evaluate_node() call tree.
_MAX_DEPTH = 10  # deepest nesting level that is still evaluated
_TIME_BUDGET_SECONDS = 0.05  # 50ms wall-clock budget per evaluation


class _EvaluationAborted(Exception):
    """Raised internally when the time budget or the depth limit is exceeded."""


def _evaluate_filter(node: dict, text: str) -> bool:
    """Evaluate one SemanticFilter leaf (a dict with ``filter_type``) on ``text``."""
    filter_type = node["filter_type"]
    val = node.get("value")

    # Numeric filters: a non-numeric threshold can never be exceeded.
    if filter_type == "length_limit":
        return isinstance(val, (int, float)) and len(text) > val
    if filter_type == "entropy_threshold":
        return isinstance(val, (int, float)) and calculate_entropy(text) > val

    if filter_type not in ("substring", "regex_pattern", "keyword_match"):
        return False

    # Text filters: a missing value must not be stringified to "None",
    # otherwise it would match any text that happens to contain that word.
    if val is None:
        return False
    needle = str(val)

    if filter_type == "substring":
        return needle.lower() in text.lower()
    if filter_type == "regex_pattern":
        try:
            return bool(re.search(needle, text))
        except re.error:
            # An invalid pattern is treated as "no match".
            return False
    # keyword_match: case-insensitive match against whitespace-separated words.
    return needle.lower() in text.lower().split()


def _evaluate_tree(node: dict, text: str, depth: int, start_time: float) -> bool:
    """Recursively evaluate ``node``; raise ``_EvaluationAborted`` on a limit."""
    # Limits raise instead of returning False so that an enclosing NOT
    # cannot invert an aborted evaluation into a positive match.
    if time.time() - start_time > _TIME_BUDGET_SECONDS:
        raise _EvaluationAborted("time budget exceeded")
    if depth > _MAX_DEPTH:
        raise _EvaluationAborted("maximum depth exceeded")

    if "filter_type" in node:
        return _evaluate_filter(node, text)

    # Otherwise the node is a LogicNode.
    operator = node.get("operator")
    children = node.get("children", [])
    if not children:
        # Every operator evaluates to False when it has no children.
        return False

    if operator == "AND":
        return all(_evaluate_tree(child, text, depth + 1, start_time) for child in children)
    if operator == "OR":
        return any(_evaluate_tree(child, text, depth + 1, start_time) for child in children)
    if operator == "NOT":
        # Only the first child is considered.
        return not _evaluate_tree(children[0], text, depth + 1, start_time)
    return False


def evaluate_node(node: dict, text: str, depth: int = 0, start_time: float = None) -> bool:
    """Return whether the rule tree rooted at ``node`` matches ``text``.

    A node containing ``filter_type`` is a leaf filter (``substring``,
    ``regex_pattern``, ``length_limit``, ``entropy_threshold`` or
    ``keyword_match``) compared against its ``value``. Any other node is a
    logic node whose ``operator`` (``AND``, ``OR`` or ``NOT``) combines its
    ``children``. Unknown filter types, unknown operators and logic nodes
    without children evaluate to ``False``.

    Args:
        node: Rule tree as plain dictionaries.
        text: Text to test.
        depth: Nesting level of ``node``; levels above 10 are not evaluated.
        start_time: ``time.time()`` timestamp at which the evaluation began.
            Defaults to the moment of the call.

    Returns:
        ``True`` when the tree matches. ``False`` when it does not match, or
        when the evaluation exceeded the 50ms budget or the depth limit; in
        that case the whole evaluation is ``False`` regardless of enclosing
        ``NOT`` operators.
    """
    if start_time is None:
        start_time = time.time()
    try:
        return _evaluate_tree(node, text, depth, start_time)
    except _EvaluationAborted:
        return False
# Use the real ``BaseEnvironment`` when the ``openenv`` package can be
# imported; otherwise fall back to an empty placeholder class so that the
# environment class defined below can still subclass it.
try:
    from openenv import BaseEnvironment
except ImportError:
    class BaseEnvironment:
        # Empty stand-in: defines no attributes or methods of its own.
        pass
class GuardrailEnvironment(BaseEnvironment):
    """Scores guardrail rule trees against adversarial and benign samples.

    ``reset`` stores the samples to score against; ``step`` parses the rule
    tree carried by an action and reports how many samples of each kind the
    tree matches.
    """

    def __init__(self):
        super().__init__()
        # Set by reset(); remains None until then.
        self.state = None

    def reset(self, adversarial_samples: List[str], benign_samples: List[str]) -> Observation:
        """Build an observation from the given samples, store it and return it."""
        observation = Observation(
            adversarial_samples=adversarial_samples,
            benign_samples=benign_samples,
        )
        self.state = observation
        return observation

    def step(self, action: Action) -> tuple[float, float, bool]:
        """Evaluate the rule tree in ``action.ast_json`` against the stored samples.

        Returns:
            ``(recall, fpr, syntax_error)``. ``recall`` is the fraction of
            adversarial samples the tree matches and ``fpr`` the fraction of
            benign samples it matches; each is ``0.0`` when its sample list
            is empty. ``syntax_error`` is ``True`` (with both rates ``0.0``)
            when the JSON cannot be decoded or validated, or when any
            exception is raised while evaluating the samples.
        """
        failure = (0.0, 0.0, True)

        try:
            cleaned = extract_and_clean_json(action.ast_json)
            graph = GuardrailGraph.model_validate(json.loads(cleaned))
        except (json.JSONDecodeError, ValidationError):
            return failure

        # Evaluate on the dumped plain dict rather than on the nested
        # pydantic objects.
        rule_root = graph.model_dump().get("root", {})

        try:
            # Recall: matches among the adversarial samples.
            adversarial = self.state.adversarial_samples
            adv_total = len(adversarial)
            true_positives = sum(
                1 for sample in adversarial if evaluate_node(rule_root, sample)
            )

            # False-positive rate: matches among the benign samples.
            benign = self.state.benign_samples
            ben_total = len(benign)
            false_positives = sum(
                1 for sample in benign if evaluate_node(rule_root, sample)
            )
        except Exception:
            # Catch-all guard: any failure during evaluation (for example a
            # RecursionError) is reported the same way as a syntax error.
            return failure

        recall = true_positives / adv_total if adv_total > 0 else 0.0
        fpr = false_positives / ben_total if ben_total > 0 else 0.0
        return recall, fpr, False
|