Spaces:
Sleeping
Sleeping
| from .models import Observation, State, Action | |
| from .injection.injector import HallucinationInjector | |
| from .reward.rewards import reward_task_accuracy, reward_detection_f1, reward_format_compliance, reward_antipropagation | |
| import json | |
| import uuid | |
| import random | |
class InjectionParams:
    """Parameters describing one hallucination injection request.

    Instances are plain attribute holders that are handed to the injector.

    Args:
        injection_type: Name of the injection strategy
            (e.g. ``"FACTUAL_FABRICATION"``).
        confidence_level: Confidence label for the injected text
            (e.g. ``"low"``).
        magnitude: Numeric strength of the injection (e.g. ``2.0``).
    """

    def __init__(self, injection_type: str, confidence_level: str, magnitude: float):
        # Type hints are advisory only; values are stored exactly as given.
        self.injection_type = injection_type
        self.confidence_level = confidence_level
        self.magnitude = magnitude

    def __repr__(self) -> str:
        # Readable representation so logged/printed params are debuggable.
        return (
            f"{type(self).__name__}("
            f"injection_type={self.injection_type!r}, "
            f"confidence_level={self.confidence_level!r}, "
            f"magnitude={self.magnitude!r})"
        )
class PropagationShieldEnvironment:
    """Single-step environment that scores an agent on poisoned context.

    ``reset`` picks a random task, passes its clean context through the
    injector and returns the first observation together with the hidden
    state.  ``step`` parses the agent's JSON reply, computes the four reward
    components and always ends the episode.
    """

    def __init__(self):
        self.injector = HallucinationInjector()
        self.state = None  # populated by reset()
        self.timeout_seconds = 30
        # Load the task fixtures.  The path is relative to the current
        # working directory of the process.
        with open('env/tasks/data/fact_retrieval_tasks.json', 'r', encoding='utf-8') as f:
            self.tasks = json.load(f)

    def reset(self, config: "dict | None" = None) -> tuple[Observation, State]:
        """Start a new episode.

        Args:
            config: Optional settings; only the ``'difficulty'`` key is read
                (defaults to ``'EASY'``).

        Returns:
            The initial observation and the newly created episode state.
        """
        # A fresh dict per call: a ``{}`` default would be shared between calls.
        config = {} if config is None else config
        difficulty = config.get('difficulty', 'EASY')
        task = random.choice(self.tasks)  # Simplified sampler

        # Mock adversary: decide injection parameters based on difficulty.
        params = []
        if difficulty == 'EASY':
            params.append(InjectionParams("FACTUAL_FABRICATION", "low", 2.0))

        poisoned_ctx, injected_idx = self.injector.inject(task["clean_context"], params)

        self.state = State(
            task_id=task["task_id"],
            injected_indices=injected_idx,
            # Record the types that were actually requested instead of a
            # hard-coded list, so the state stays consistent when no
            # parameters were generated for the chosen difficulty.
            injection_types=[p.injection_type for p in params],
            ground_truth=task["source_data"],
            difficulty=difficulty,
            episode_id=str(uuid.uuid4()),
        )
        obs = Observation(
            query=task["query"],
            context=poisoned_ctx,
            step=0,
            difficulty=difficulty,
            n_passages=len(poisoned_ctx),
        )
        return obs, self.state

    @staticmethod
    def _parse_action(action) -> tuple[list, str]:
        """Extract ``(suspicion flag indices, answer)`` from the agent reply.

        Markdown code fences are stripped before JSON decoding.  Any
        malformed reply yields ``([], "")`` so scoring can still proceed.
        """
        try:
            payload = action.raw_text.replace('```json', '').replace('```', '').strip()
            parsed = json.loads(payload)
            flags = [f.get('passage_index', -1) for f in parsed.get('suspicion_flags', [])]
            answer = parsed.get('answer', '')
        except (AttributeError, TypeError, ValueError, RecursionError):
            # ValueError covers json.JSONDecodeError; AttributeError/TypeError
            # cover valid JSON of an unexpected shape (e.g. a list instead of
            # an object, or non-dict flag entries).  Unlike a bare ``except``
            # this lets KeyboardInterrupt/SystemExit propagate.
            return [], ""
        return flags, answer

    def step(self, action: Action) -> tuple[Observation, dict, bool, dict]:
        """Score the agent's action and finish the episode.

        Args:
            action: Agent action; its ``raw_text`` is expected to hold a JSON
                object with ``'answer'`` and ``'suspicion_flags'`` entries.

        Returns:
            ``(observation, rewards, done, info)`` where ``done`` is always
            ``True`` and ``rewards`` holds the four components plus their sum
            under ``'total'``.

        Raises:
            RuntimeError: If called before ``reset()``.
        """
        if self.state is None:
            raise RuntimeError("step() called before reset()")

        # Parse the agent's action.
        agent_flags, agent_answer = self._parse_action(action)

        # Compute the 4 rewards.
        r_task = reward_task_accuracy(agent_answer, self.state.ground_truth)
        r_detect = reward_detection_f1(agent_flags, self.state.injected_indices)
        r_format = reward_format_compliance(action.raw_text, len(self.state.injected_indices))
        r_anti = reward_antipropagation(r_task, r_detect)
        rewards = {
            "task": r_task,
            "detection": r_detect,
            "format": r_format,
            "antiprop": r_anti,
            "total": r_task + r_detect + r_format + r_anti,
        }

        done = True  # episodes are exactly one step long
        info = {
            'rewards': rewards,
            'injected_indices': self.state.injected_indices,
            'episode_id': self.state.episode_id,
        }
        # Terminal observation carries no query or context.
        obs = Observation(query='', context=[], step=1, difficulty=self.state.difficulty, n_passages=0)
        return obs, rewards, done, info