# propagationshield-env / env / environment.py
# pragunk's picture
# Upload 14 files
# dada368 verified
from .models import Observation, State, Action
from .injection.injector import HallucinationInjector
from .reward.rewards import reward_task_accuracy, reward_detection_f1, reward_format_compliance, reward_antipropagation
import json
import uuid
import random
class InjectionParams:
    """Plain value holder describing one injection request.

    The three constructor arguments are stored unchanged as same-named
    attributes. In this file the only instance created is
    ``InjectionParams("FACTUAL_FABRICATION", "low", 2.0)``, which is then
    passed (inside a list) to the injector.
    """

    def __init__(self, injection_type, confidence_level, magnitude):
        # Store all three settings verbatim; no validation or conversion.
        (
            self.injection_type,
            self.confidence_level,
            self.magnitude,
        ) = (injection_type, confidence_level, magnitude)
class PropagationShieldEnvironment:
    """Single-step environment that serves a task with an injected context.

    ``reset`` samples a task, runs its clean context through the
    ``HallucinationInjector`` and returns the resulting observation together
    with the hidden episode state.  ``step`` parses the agent's JSON reply,
    scores it with the four reward functions and always ends the episode.
    """

    # Default fixture location. It is a relative path, so it is resolved
    # against the current working directory of the process.
    DEFAULT_TASKS_PATH = 'env/tasks/data/fact_retrieval_tasks.json'

    def __init__(self, tasks_path: str = DEFAULT_TASKS_PATH):
        """Create the injector and load the task fixtures.

        Args:
            tasks_path: Path of the JSON file holding the task list.
                Defaults to the fixture path the environment has always used.

        Raises:
            OSError: If the fixture file cannot be opened.
            json.JSONDecodeError: If the fixture file is not valid JSON.
        """
        self.injector = HallucinationInjector()
        self.state = None  # populated by reset()
        self.timeout_seconds = 30  # not referenced elsewhere in this class
        # Explicit encoding so loading does not depend on the platform locale.
        with open(tasks_path, 'r', encoding='utf-8') as f:
            self.tasks = json.load(f)

    def reset(self, config: dict = None) -> tuple[Observation, State]:
        """Start a new episode and return ``(observation, state)``.

        Args:
            config: Optional settings. Only the ``'difficulty'`` key is read;
                it defaults to ``'EASY'``. ``None`` is treated as ``{}``.

        Returns:
            The initial observation (query plus the context returned by the
            injector) and the freshly created state, which is also stored on
            ``self.state``.
        """
        # A fresh dict per call avoids sharing one mutable default object.
        if config is None:
            config = {}
        difficulty = config.get('difficulty', 'EASY')
        task = random.choice(self.tasks)  # Simplified sampler

        # Mock adversary: decide injection parameters based on difficulty.
        # Only 'EASY' adds a parameter set; any other value passes an empty
        # list to the injector.
        params = []
        if difficulty == 'EASY':
            params.append(InjectionParams("FACTUAL_FABRICATION", "low", 2.0))
        poisoned_ctx, injected_idx = self.injector.inject(task["clean_context"], params)

        self.state = State(
            task_id=task["task_id"],
            injected_indices=injected_idx,
            # NOTE(review): recorded unconditionally, even when `params` is
            # empty (non-EASY difficulty) - confirm this is intended.
            injection_types=["FACTUAL_FABRICATION"],
            ground_truth=task["source_data"],
            difficulty=difficulty,
            episode_id=str(uuid.uuid4()),
        )
        obs = Observation(
            query=task["query"],
            context=poisoned_ctx,
            step=0,
            difficulty=difficulty,
            n_passages=len(poisoned_ctx),
        )
        return obs, self.state

    @staticmethod
    def _parse_action(raw_text) -> tuple[list, str]:
        """Extract ``(flagged passage indices, answer)`` from the agent's reply.

        Markdown code fences are stripped before the text is decoded as JSON.
        Parsing is best-effort and all-or-nothing: if the text is not valid
        JSON, or does not have the expected object/list-of-objects shape,
        ``([], "")`` is returned.
        """
        try:
            cleaned = raw_text.replace('```json', '').replace('```', '').strip()
            parsed = json.loads(cleaned)
            flags = [f.get('passage_index', -1) for f in parsed.get('suspicion_flags', [])]
            answer = parsed.get('answer', '')
        except (ValueError, AttributeError, TypeError, RecursionError):
            # ValueError covers json.JSONDecodeError; AttributeError/TypeError
            # cover wrongly shaped payloads (non-dict top level or flag
            # entries, non-iterable flags, non-string text); RecursionError
            # covers pathologically nested JSON.
            return [], ""
        return flags, answer

    def step(self, action: Action) -> tuple[Observation, dict, bool, dict]:
        """Score the agent's action and finish the episode.

        Args:
            action: The agent's action; its ``raw_text`` is expected to hold
                a JSON object with ``suspicion_flags`` and ``answer`` keys.

        Returns:
            ``(observation, rewards, done, info)`` where ``observation`` is an
            empty terminal observation with ``step=1``, ``rewards`` maps
            ``task``/``detection``/``format``/``antiprop``/``total`` to their
            scores, ``done`` is always ``True`` and ``info`` carries the
            rewards, the injected indices and the episode id.

        Raises:
            RuntimeError: If called before ``reset``.
        """
        if self.state is None:
            raise RuntimeError("step() called before reset()")

        # Parse the agent's action
        agent_flags, agent_answer = self._parse_action(action.raw_text)

        # Compute the 4 rewards
        r_task = reward_task_accuracy(agent_answer, self.state.ground_truth)
        r_detect = reward_detection_f1(agent_flags, self.state.injected_indices)
        r_format = reward_format_compliance(action.raw_text, len(self.state.injected_indices))
        r_anti = reward_antipropagation(r_task, r_detect)
        rewards = {
            "task": r_task,
            "detection": r_detect,
            "format": r_format,
            "antiprop": r_anti,
            "total": r_task + r_detect + r_format + r_anti,
        }

        done = True  # every episode lasts exactly one step
        info = {
            'rewards': rewards,
            'injected_indices': self.state.injected_indices,
            'episode_id': self.state.episode_id,
        }
        obs = Observation(query='', context=[], step=1, difficulty=self.state.difficulty, n_passages=0)
        return obs, rewards, done, info