File size: 3,292 Bytes
dada368
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from .models import Observation, State, Action
from .injection.injector import HallucinationInjector
from .reward.rewards import reward_task_accuracy, reward_detection_f1, reward_format_compliance, reward_antipropagation
import json
import uuid
import random

class InjectionParams:
    """Plain record describing one requested injection.

    The three constructor arguments are stored unchanged as attributes of
    the same name; no validation or conversion is performed.
    """

    def __init__(self, injection_type, confidence_level, magnitude):
        # Bind all fields in a single statement, in constructor order.
        (
            self.injection_type,
            self.confidence_level,
            self.magnitude,
        ) = (injection_type, confidence_level, magnitude)

class PropagationShieldEnvironment:
    """Single-step environment: serve a poisoned context, then score the reply.

    ``reset`` samples a task from the loaded fixtures, runs its clean context
    through the injector and returns the first observation together with the
    hidden episode state.  ``step`` parses the agent's JSON reply, computes
    the four reward components and always ends the episode.
    """

    def __init__(self):
        """Create the injector and load the task fixtures from disk.

        Raises:
            OSError: If the fixtures file cannot be opened.
            json.JSONDecodeError: If the fixtures file is not valid JSON.
        """
        self.injector = HallucinationInjector()
        self.state = None  # populated by reset(); step() requires it
        self.timeout_seconds = 30  # not read anywhere inside this class

        # Load the task fixtures. The path is relative to the process's
        # working directory. Decode explicitly as UTF-8 so the result does
        # not depend on the platform's default encoding.
        with open(
            'env/tasks/data/fact_retrieval_tasks.json', 'r', encoding='utf-8'
        ) as f:
            self.tasks = json.load(f)

    def reset(self, config: "dict | None" = None) -> tuple[Observation, State]:
        """Start a new episode and return its first observation and state.

        Args:
            config: Optional settings. Only the ``'difficulty'`` key is read
                (defaults to ``'EASY'``). ``None`` is treated as an empty
                dict.

        Returns:
            ``(observation, state)``. The observation carries the task query
            and the context returned by the injector; the state records the
            injected indices and the task's ``source_data`` as ground truth.
            The same state object is also stored on ``self.state``.
        """
        # A None default avoids sharing one mutable dict across all calls.
        if config is None:
            config = {}
        difficulty = config.get('difficulty', 'EASY')
        # Uniform pick over all fixtures; difficulty does not affect sampling.
        task = random.choice(self.tasks)

        # Mock adversary: choose injection parameters from the difficulty.
        # Only 'EASY' is handled; any other value leaves ``params`` empty.
        params = []
        if difficulty == 'EASY':
            params.append(InjectionParams("FACTUAL_FABRICATION", "low", 2.0))

        poisoned_ctx, injected_idx = self.injector.inject(task["clean_context"], params)

        self.state = State(
            task_id=task["task_id"],
            injected_indices=injected_idx,
            # NOTE(review): hard-coded even when ``params`` is empty (any
            # non-'EASY' difficulty) -- confirm whether this should be
            # derived from ``params`` instead.
            injection_types=["FACTUAL_FABRICATION"],
            ground_truth=task["source_data"],
            difficulty=difficulty,
            episode_id=str(uuid.uuid4())
        )

        obs = Observation(
            query=task["query"],
            context=poisoned_ctx,
            step=0,
            difficulty=difficulty,
            n_passages=len(poisoned_ctx)
        )
        return obs, self.state

    @staticmethod
    def _parse_action(raw_text):
        """Extract ``(flagged_indices, answer)`` from the agent's raw reply.

        Markdown code fences are stripped before JSON decoding. Any malformed
        reply (invalid JSON, wrong top-level type, malformed flag entries)
        yields ``([], "")``.
        """
        try:
            payload = json.loads(
                raw_text.replace('```json', '').replace('```', '').strip()
            )
            flags = [
                f.get('passage_index', -1)
                for f in payload.get('suspicion_flags', [])
            ]
            answer = payload.get('answer', '')
        except (ValueError, TypeError, AttributeError, RecursionError):
            # ValueError covers json.JSONDecodeError. The others cover
            # non-string input, non-dict payloads or flag entries, and
            # pathologically nested JSON. KeyboardInterrupt and SystemExit
            # are deliberately not caught.
            return [], ""
        return flags, answer

    def step(self, action: Action) -> tuple[Observation, dict, bool, dict]:
        """Score the agent's reply and finish the episode.

        Args:
            action: The agent's action; its ``raw_text`` is expected to hold
                a JSON object with ``answer`` and ``suspicion_flags`` keys,
                optionally wrapped in Markdown code fences.

        Returns:
            ``(observation, rewards, done, info)``. ``observation`` is an
            empty terminal observation, ``rewards`` maps each component name
            and ``'total'`` to its value, ``done`` is always ``True`` and
            ``info`` repeats the rewards together with the injected indices
            and the episode id.

        Raises:
            RuntimeError: If called before ``reset``.
        """
        if self.state is None:
            raise RuntimeError("reset() must be called before step()")

        # Parse the agent's action; malformed replies score as empty.
        agent_flags, agent_answer = self._parse_action(action.raw_text)

        # Compute the 4 rewards
        r_task = reward_task_accuracy(agent_answer, self.state.ground_truth)
        r_detect = reward_detection_f1(agent_flags, self.state.injected_indices)
        r_format = reward_format_compliance(action.raw_text, len(self.state.injected_indices))
        r_anti = reward_antipropagation(r_task, r_detect)

        rewards = {
            "task": r_task,
            "detection": r_detect,
            "format": r_format,
            "antiprop": r_anti,
            "total": r_task + r_detect + r_format + r_anti
        }

        done = True  # every episode lasts exactly one step
        info = {
            'rewards': rewards,
            'injected_indices': self.state.injected_indices,
            'episode_id': self.state.episode_id
        }

        obs = Observation(query='', context=[], step=1, difficulty=self.state.difficulty, n_passages=0)
        return obs, rewards, done, info