"""Random-policy baseline for ContextCorruptionEnv.

Runs a number of episodes with a naive random policy (randomly flag documents,
then submit a placeholder answer) and writes reward statistics to
``baseline_results.json`` next to this file.
"""

import json
import random
from pathlib import Path
from typing import Optional

from environment.env import ContextCorruptionEnv
from environment.actions import ContextCorruptionAction, ActionType

NUM_EPISODES = 100
RESULTS_PATH = Path(__file__).parent / "baseline_results.json"

# Per-step probability of flagging a document instead of submitting an answer.
FLAG_PROBABILITY = 0.4


def _choose_action(obs):
    """Return the next action of the random baseline policy for observation *obs*.

    With probability FLAG_PROBABILITY -- and only while ``budget_remaining > 1``
    and at least one document is present -- flag a uniformly random document.
    Otherwise submit the placeholder answer "unknown" with confidence 0.5.
    The number of flags per episode is therefore bounded only by the budget.
    """
    # random.random() is drawn first on every step so the RNG stream is
    # consumed exactly as in the original policy.
    # NOTE(review): "> 1" presumably reserves budget for the final
    # submit_answer -- confirm against the environment's budget rules.
    if random.random() < FLAG_PROBABILITY and obs.budget_remaining > 1 and obs.documents:
        doc_id = random.randint(0, len(obs.documents) - 1)
        return ContextCorruptionAction(
            action_type=ActionType.flag_suspicious,
            doc_id=doc_id,
        )
    return ContextCorruptionAction(
        action_type=ActionType.submit_answer,
        answer="unknown",
        confidence=0.5,
    )


def _run_episode(env):
    """Play one episode in *env* with the random policy; return the final reward."""
    obs = env.reset()
    while True:
        obs = env.step(_choose_action(obs))
        if obs.done:
            return obs.reward


def run_baseline(
    num_episodes: Optional[int] = None,
    results_path: Optional[Path] = None,
    seed: Optional[int] = None,
):
    """Run the random baseline and persist its reward statistics as JSON.

    Args:
        num_episodes: Episodes to run; defaults to NUM_EPISODES. Must be >= 1.
        results_path: Output file; defaults to RESULTS_PATH (str or Path).
        seed: If given, seeds Python's global ``random`` module before running
            so the policy's choices are reproducible. NOTE(review): the
            environment's own randomness is covered only if it also draws from
            the global ``random`` module -- confirm.

    Returns:
        dict with keys "num_episodes", "avg_reward", "min_reward",
        "max_reward" (rounded to 4 places) and "all_rewards".

    Raises:
        ValueError: if num_episodes < 1 (the statistics would be undefined).
    """
    # Resolve defaults at call time so the module constants remain the
    # single source of truth (and can still be overridden/patched).
    if num_episodes is None:
        num_episodes = NUM_EPISODES
    results_path = RESULTS_PATH if results_path is None else Path(results_path)
    if num_episodes < 1:
        raise ValueError(f"num_episodes must be >= 1, got {num_episodes}")
    if seed is not None:
        random.seed(seed)

    env = ContextCorruptionEnv()
    rewards = []
    for ep in range(1, num_episodes + 1):
        reward = _run_episode(env)
        rewards.append(reward)
        if ep % 10 == 0:
            print(f"Episode {ep}/{num_episodes} | reward: {reward:.4f} | avg so far: {sum(rewards)/len(rewards):.4f}")

    avg = round(sum(rewards) / len(rewards), 4)
    minimum = round(min(rewards), 4)
    maximum = round(max(rewards), 4)

    results = {
        "num_episodes": num_episodes,
        "avg_reward": avg,
        "min_reward": minimum,
        "max_reward": maximum,
        "all_rewards": rewards,
    }
    results_path.write_text(json.dumps(results, indent=2), encoding="utf-8")
    print(f"\nBaseline results — avg: {avg} | min: {minimum} | max: {maximum}")
    print(f"Saved to {results_path}")
    return results


if __name__ == "__main__":
    run_baseline()