# context-corruption-env / eval / baseline_eval.py
# Last commit: 7a8a0f0 (Siddh12334) — "feat: rewrite env to be fully openenv-core compliant"
import json
import random
from pathlib import Path
from environment.env import ContextCorruptionEnv
from environment.actions import ContextCorruptionAction, ActionType
# How many episodes the random baseline plays by default.
NUM_EPISODES: int = 100
# The JSON summary is written alongside this script.
RESULTS_PATH: Path = Path(__file__).parent.joinpath("baseline_results.json")
def _random_action(obs, rng):
    """Choose the next action for the random baseline policy.

    With probability 0.4, and only while more than one unit of budget
    remains and at least one document is present, flag a uniformly random
    document as suspicious. Otherwise submit the fixed answer ``"unknown"``
    with confidence 0.5 (this is the only action that ends an episode in
    this policy).

    Args:
        obs: Current observation. Reads ``budget_remaining`` and
            ``documents``.
        rng: Source of randomness exposing ``random()`` and ``randint()``
            (the ``random`` module or a ``random.Random`` instance).

    Returns:
        A ``ContextCorruptionAction``.
    """
    # The coin flip is drawn first and the document check comes last, so the
    # RNG consumption order matches the original inline policy exactly.
    if rng.random() < 0.4 and obs.budget_remaining > 1 and obs.documents:
        # NOTE(review): assumes doc ids are positional indices
        # 0..len(documents)-1 -- confirm against the env's validation.
        doc_id = rng.randint(0, len(obs.documents) - 1)
        return ContextCorruptionAction(
            action_type=ActionType.flag_suspicious,
            doc_id=doc_id,
        )
    return ContextCorruptionAction(
        action_type=ActionType.submit_answer,
        answer="unknown",
        confidence=0.5,
    )


def run_baseline(*, num_episodes=None, seed=None, results_path=None):
    """Run a random-policy baseline on ``ContextCorruptionEnv`` and save stats.

    Each episode is played with ``_random_action`` until the env reports
    ``done``. The reward on the final observation is recorded. Progress is
    printed every 10 episodes, and a summary (episode count, average, min,
    max, and every per-episode reward) is written as JSON.

    Args:
        num_episodes: Episodes to play. Defaults to ``NUM_EPISODES``.
        seed: Optional seed for a private ``random.Random``. When ``None``
            the global ``random`` module is used, as before, so callers that
            seed it themselves see unchanged behavior.
        results_path: Where to write the JSON summary (``str`` or ``Path``).
            Defaults to ``RESULTS_PATH``.

    Returns:
        The results dict that was written to disk.

    Raises:
        ValueError: If ``num_episodes`` is less than 1. Average, min and max
            are undefined for zero episodes.
    """
    if num_episodes is None:
        num_episodes = NUM_EPISODES
    if num_episodes < 1:
        raise ValueError(f"num_episodes must be >= 1, got {num_episodes}")
    if results_path is None:
        results_path = RESULTS_PATH
    results_path = Path(results_path)
    rng = random if seed is None else random.Random(seed)

    env = ContextCorruptionEnv()
    rewards = []
    for ep in range(num_episodes):
        obs = env.reset()
        done = False
        # NOTE(review): there is no step cap. This assumes the env reports
        # done=True after submit_answer (or when the budget runs out).
        while not done:
            obs = env.step(_random_action(obs, rng))
            done = obs.done
        # NOTE(review): assumes the terminal observation always carries a
        # numeric reward (not None) -- confirm against the env.
        rewards.append(obs.reward)
        if (ep + 1) % 10 == 0:
            print(
                f"Episode {ep + 1}/{num_episodes} | reward: {obs.reward:.4f}"
                f" | avg so far: {sum(rewards)/len(rewards):.4f}"
            )

    avg = round(sum(rewards) / len(rewards), 4)
    minimum = round(min(rewards), 4)
    maximum = round(max(rewards), 4)
    results = {
        "num_episodes": num_episodes,
        "avg_reward": avg,
        "min_reward": minimum,
        "max_reward": maximum,
        "all_rewards": rewards,
    }
    results_path.write_text(json.dumps(results, indent=2), encoding="utf-8")
    print(f"\nBaseline results — avg: {avg} | min: {minimum} | max: {maximum}")
    print(f"Saved to {results_path}")
    return results
if __name__ == "__main__":
    # Runs only when this file is executed as a script, not when imported.
    run_baseline()