"""Single-step environment: serve a task whose context went through the injector, then score the agent's reply."""

import json
import random
import uuid
from typing import Optional

from .models import Observation, State, Action
from .injection.injector import HallucinationInjector
from .reward.rewards import (
    reward_task_accuracy,
    reward_detection_f1,
    reward_format_compliance,
    reward_antipropagation,
)


class InjectionParams:
    """Plain container for the settings of one injection passed to the injector."""

    def __init__(self, injection_type, confidence_level, magnitude):
        self.injection_type = injection_type
        self.confidence_level = confidence_level
        self.magnitude = magnitude


class PropagationShieldEnvironment:
    """Environment with a ``reset``/``step`` pair; every episode ends after one step."""

    def __init__(self):
        """Create the injector and load the task fixtures from disk.

        Raises:
            OSError: if the fixture file cannot be opened.
            json.JSONDecodeError: if the fixture file is not valid JSON.
        """
        self.injector = HallucinationInjector()
        self.state = None  # populated by reset()
        self.timeout_seconds = 30  # not read anywhere in this file
        # Load the fixtures. The path is relative to the current working
        # directory, not to this module.
        with open('env/tasks/data/fact_retrieval_tasks.json', 'r', encoding='utf-8') as f:
            self.tasks = json.load(f)

    def reset(self, config: Optional[dict] = None) -> tuple[Observation, State]:
        """Start a new episode on a randomly chosen task.

        Args:
            config: optional settings; only the ``'difficulty'`` key is read
                (defaults to ``'EASY'``).

        Returns:
            The initial observation and the freshly created episode state.
        """
        # A ``{}`` default would be one dict object shared by every call.
        config = {} if config is None else config
        difficulty = config.get('difficulty', 'EASY')
        task = random.choice(self.tasks)  # Simplified sampler

        # Mock adversary: decide parameters based on difficulty. Only EASY
        # has a rule so far; other difficulties pass no params to the injector.
        params = []
        if difficulty == 'EASY':
            params.append(InjectionParams("FACTUAL_FABRICATION", "low", 2.0))

        poisoned_ctx, injected_idx = self.injector.inject(task["clean_context"], params)

        self.state = State(
            task_id=task["task_id"],
            injected_indices=injected_idx,
            # Record what was actually requested rather than a fixed label, so
            # the state stays truthful when no params were produced.
            injection_types=[p.injection_type for p in params],
            ground_truth=task["source_data"],
            difficulty=difficulty,
            episode_id=str(uuid.uuid4()),
        )
        obs = Observation(
            query=task["query"],
            context=poisoned_ctx,
            step=0,
            difficulty=difficulty,
            n_passages=len(poisoned_ctx),
        )
        return obs, self.state

    @staticmethod
    def _parse_action(raw_text):
        """Extract ``(flagged passage indices, answer)`` from the agent's raw reply.

        Markdown code-fence markers are stripped before JSON decoding. Any
        malformed payload yields ``([], "")``.
        """
        try:
            payload = raw_text.replace('```json', '').replace('```', '').strip()
            parsed = json.loads(payload)
            flags = [f.get('passage_index', -1) for f in parsed.get('suspicion_flags', [])]
            answer = parsed.get('answer', '')
        except (ValueError, AttributeError, TypeError):
            # ValueError: invalid JSON (JSONDecodeError is a subclass).
            # AttributeError / TypeError: valid JSON of an unexpected shape,
            # or raw_text that is not a string.
            return [], ""
        return flags, answer

    def step(self, action: Action) -> tuple[Observation, dict, bool, dict]:
        """Score the agent's action and finish the episode.

        Args:
            action: the agent's reply; only ``action.raw_text`` is read.

        Returns:
            ``(obs, rewards, done, info)`` where ``obs`` is an empty terminal
            observation, ``rewards`` maps each component name and ``"total"``
            to its value, ``done`` is always ``True``, and ``info`` carries
            the rewards, the injected indices and the episode id.

        Raises:
            RuntimeError: if called before ``reset``.
        """
        if self.state is None:
            raise RuntimeError("step() called before reset()")

        # Parse the agent's action
        agent_flags, agent_answer = self._parse_action(action.raw_text)

        # Compute the 4 rewards
        r_task = reward_task_accuracy(agent_answer, self.state.ground_truth)
        r_detect = reward_detection_f1(agent_flags, self.state.injected_indices)
        r_format = reward_format_compliance(action.raw_text, len(self.state.injected_indices))
        r_anti = reward_antipropagation(r_task, r_detect)

        rewards = {
            "task": r_task,
            "detection": r_detect,
            "format": r_format,
            "antiprop": r_anti,
            "total": r_task + r_detect + r_format + r_anti,
        }
        done = True
        info = {
            'rewards': rewards,
            'injected_indices': self.state.injected_indices,
            'episode_id': self.state.episode_id,
        }
        obs = Observation(
            query='',
            context=[],
            step=1,
            difficulty=self.state.difficulty,
            n_passages=0,
        )
        return obs, rewards, done, info