File size: 4,153 Bytes
005d862
 
 
80b34d1
005d862
 
 
 
 
 
 
 
 
 
 
80b34d1
 
 
 
 
 
 
 
 
 
005d862
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80b34d1
005d862
 
80b34d1
005d862
 
80b34d1
005d862
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80b34d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
005d862
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import json
import re
import math
import time
from typing import List, Any
from collections import Counter
from pydantic import ValidationError
from src.env.models import GuardrailGraph, Action, Observation, extract_and_clean_json

def calculate_entropy(text: str) -> float:
    """Return the Shannon entropy of ``text`` in bits per character.

    The entropy is taken over the distribution of character frequencies in
    ``text``. An empty string has entropy ``0.0``.
    """
    if not text:
        return 0.0
    length = len(text)
    # Negate each term rather than the total so that a single-symbol string
    # yields +0.0 instead of -0.0. IEEE negation is exact, so every other
    # result is bit-for-bit unchanged.
    return sum(
        -(count / length) * math.log2(count / length)
        for count in Counter(text).values()
    )

# Wall-clock budget (seconds) for evaluating one tree against one text, and
# the maximum node depth that is still evaluated.
_EVAL_TIMEOUT_S = 0.05
_MAX_DEPTH = 10


def evaluate_node(node: dict, text: str, depth: int = 0, start_time: float = None) -> bool:
    """Return True if the guardrail AST rooted at ``node`` matches ``text``.

    ``node`` is a plain dict. It is either a semantic filter (it has a
    ``"filter_type"`` key and an optional ``"value"``) or a logic node with an
    ``"operator"`` (``"AND"``, ``"OR"`` or ``"NOT"``) and a ``"children"`` list.

    Evaluation is bounded in two ways:

    - a 50 ms budget, measured from ``start_time`` (defaults to now);
    - a maximum nesting depth of 10.

    When either guard trips, the affected subtree counts as *undetermined*
    rather than False. A ``NOT`` above it therefore cannot turn an aborted
    evaluation into a match. An undetermined overall result is reported as
    False.

    Args:
        node: AST node dict to evaluate.
        text: Input text to test.
        depth: Depth of ``node`` within the tree (0 for the root).
        start_time: ``time.time()`` timestamp that the budget is measured from.

    Returns:
        True only if the tree definitely matches ``text``.
    """
    if start_time is None:
        start_time = time.time()
    return _evaluate(node, text, depth, start_time) is True


def _evaluate(node: dict, text: str, depth: int, start_time: float) -> "bool | None":
    """Evaluate ``node`` with Kleene logic: True, False, or None if a guard aborted it."""
    # start_time must come from time.time(), since that is what it is compared against.
    if time.time() - start_time > _EVAL_TIMEOUT_S or depth > _MAX_DEPTH:
        return None

    if "filter_type" in node:
        return _evaluate_filter(node, text)

    # Otherwise this is a logic node.
    operator = node.get("operator")
    children = node.get("children") or []
    if not children:
        # An empty AND/OR/NOT, or any unknown operator, never matches.
        return False

    if operator == "NOT":
        # Only the first child is considered.
        inner = _evaluate(children[0], text, depth + 1, start_time)
        return None if inner is None else not inner

    if operator == "AND":
        outcome = True
        for child in children:
            result = _evaluate(child, text, depth + 1, start_time)
            if result is False:
                return False  # a definite False decides the conjunction
            if result is None:
                outcome = None
        return outcome

    if operator == "OR":
        outcome = False
        for child in children:
            result = _evaluate(child, text, depth + 1, start_time)
            if result is True:
                return True  # a definite True decides the disjunction
            if result is None:
                outcome = None
        return outcome

    return False


def _evaluate_filter(node: dict, text: str) -> bool:
    """Apply a single semantic-filter node to ``text``; unknown filter types never match."""
    filter_type = node["filter_type"]
    val = node.get("value")
    if filter_type == "substring":
        return str(val).lower() in text.lower()
    if filter_type == "regex_pattern":
        # NOTE(review): the pattern comes from the submitted AST, and `re` has
        # no timeout. Catastrophic backtracking cannot be interrupted by the
        # time budget, which is only checked between nodes.
        try:
            return bool(re.search(str(val), text))
        except re.error:
            return False
    if filter_type == "length_limit":
        # A non-numeric limit disables the filter (nothing exceeds infinity).
        return len(text) > (val if isinstance(val, (int, float)) else float("inf"))
    if filter_type == "entropy_threshold":
        return calculate_entropy(text) > (val if isinstance(val, (int, float)) else float("inf"))
    if filter_type == "keyword_match":
        # Whole-token match against the lower-cased, whitespace-split text.
        return str(val).lower() in text.lower().split()
    return False

# ``openenv`` is an optional dependency. When it is missing, fall back to an
# empty base class so that GuardrailEnvironment can still be defined.
# NOTE(review): confirm that openenv exports BaseEnvironment at top level;
# if it does not, this fallback is always the one used.
try:
    from openenv import BaseEnvironment
except ImportError:
    class BaseEnvironment:
        """Empty stand-in used when ``openenv`` is not installed."""

class GuardrailEnvironment(BaseEnvironment):
    """Score a candidate guardrail AST against labelled adversarial and benign texts.

    ``reset`` installs the sample sets for an episode. ``step`` parses and
    validates an action's AST JSON, then reports how well the resulting
    graph separates the two sets.
    """

    def __init__(self):
        super().__init__()
        # Current Observation; stays None until reset() is called.
        self.state = None

    def reset(self, adversarial_samples: List[str], benign_samples: List[str]) -> Observation:
        """Store a new pair of sample sets and return them as the observation."""
        self.state = Observation(
            adversarial_samples=adversarial_samples,
            benign_samples=benign_samples,
        )
        return self.state

    def step(self, action: Action) -> tuple[float, float, bool]:
        """Evaluate ``action.ast_json`` against the current samples.

        Returns:
            A tuple ``(recall, fpr, syntax_error)``:

            - ``recall`` is the fraction of adversarial samples the graph
              matches (0.0 if there are none).
            - ``fpr`` is the fraction of benign samples it matches (0.0 if
              there are none).
            - ``syntax_error`` is True when the AST fails JSON decoding or
              model validation, or when evaluating it raises. In that case
              ``recall`` and ``fpr`` are both 0.0.

        Raises:
            RuntimeError: if called before ``reset()``.
        """
        if self.state is None:
            # Without this check, the AttributeError on ``self.state`` would be
            # swallowed by the evaluation guard below and misreported as a
            # syntax error in the agent's action.
            raise RuntimeError("GuardrailEnvironment.step() called before reset()")

        try:
            clean_json = extract_and_clean_json(action.ast_json)
            parsed_ast = json.loads(clean_json)
            ast_wrapper = GuardrailGraph.model_validate(parsed_ast)
        except (json.JSONDecodeError, ValidationError):
            return 0.0, 0.0, True

        # Evaluate against the plain-dict dump to avoid recursive pydantic object overhead.
        root_node = ast_wrapper.model_dump().get("root", {})

        try:
            adversarial = self.state.adversarial_samples
            benign = self.state.benign_samples
            # True positives on adversarial samples, false positives on benign ones.
            true_positives = self._count_matches(root_node, adversarial)
            false_positives = self._count_matches(root_node, benign)
            adv_total = len(adversarial)
            ben_total = len(benign)
        except Exception:
            # Global guard against any unhandled exception while evaluating
            # the agent's AST (e.g. RecursionError escaping the depth check).
            return 0.0, 0.0, True

        recall = true_positives / adv_total if adv_total > 0 else 0.0
        fpr = false_positives / ben_total if ben_total > 0 else 0.0
        return recall, fpr, False

    @staticmethod
    def _count_matches(root_node: dict, samples: List[str]) -> int:
        """Return how many of ``samples`` the guardrail graph ``root_node`` matches."""
        return sum(1 for text in samples if evaluate_node(root_node, text))