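"""Random-action baseline for ``ContextCorruptionEnv``.

Plays episodes with randomly chosen flag/submit actions and writes summary
reward statistics to ``baseline_results.json`` next to this script.
"""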
import json
import random
from pathlib import Path

from environment.actions import ActionType, ContextCorruptionAction
from environment.env import ContextCorruptionEnv
# Number of episodes played by run_baseline().
NUM_EPISODES = 100
# The JSON summary is written in the same directory as this script.
RESULTS_PATH = Path(__file__).parent / "baseline_results.json"
def run_baseline(num_episodes=None, results_path=None):
    """Run a random-action baseline and save summary reward statistics.

    A single ``ContextCorruptionEnv`` is created and reset at the start of
    every episode.  At each step the policy either flags one randomly chosen
    document (40% chance, and only while ``budget_remaining > 1`` and the
    observation has at least one document) or submits the fixed answer
    ``"unknown"`` with confidence 0.5.  Stepping continues until the
    observation reports ``done``; the reward on that final observation is
    recorded as the episode reward.

    Args:
        num_episodes: Number of episodes to run.  ``None`` (the default)
            uses the module-level ``NUM_EPISODES``.
        results_path: Destination of the JSON summary.  ``None`` (the
            default) uses the module-level ``RESULTS_PATH``.

    Returns:
        dict: ``num_episodes``, ``avg_reward``, ``min_reward``,
        ``max_reward`` (each rounded to 4 decimals) and ``all_rewards``
        (the unrounded per-episode rewards).

    Raises:
        ValueError: If ``num_episodes`` is less than 1, since no statistics
            can be computed from zero episodes.
    """
    if num_episodes is None:
        num_episodes = NUM_EPISODES
    if num_episodes < 1:
        raise ValueError(f"num_episodes must be at least 1, got {num_episodes}")
    results_path = Path(RESULTS_PATH if results_path is None else results_path)

    env = ContextCorruptionEnv()
    rewards = []
    for ep in range(num_episodes):
        obs = env.reset()
        done = False
        while not done:
            # Flag a random document with 40% probability while more than one
            # budget unit is left; otherwise submit a placeholder answer.
            # The document check avoids sampling from an empty range.
            if (
                random.random() < 0.4
                and obs.budget_remaining > 1
                and len(obs.documents) > 0
            ):
                action = ContextCorruptionAction(
                    action_type=ActionType.flag_suspicious,
                    doc_id=random.randrange(len(obs.documents)),
                )
            else:
                action = ContextCorruptionAction(
                    action_type=ActionType.submit_answer,
                    answer="unknown",
                    confidence=0.5,
                )
            obs = env.step(action)
            done = obs.done

        # Only the reward on the terminal observation counts for the episode.
        rewards.append(obs.reward)
        if (ep + 1) % 10 == 0:
            print(f"Episode {ep + 1}/{num_episodes} | reward: {obs.reward:.4f} | avg so far: {sum(rewards)/len(rewards):.4f}")

    avg = round(sum(rewards) / len(rewards), 4)
    minimum = round(min(rewards), 4)
    maximum = round(max(rewards), 4)
    results = {
        "num_episodes": num_episodes,
        "avg_reward": avg,
        "min_reward": minimum,
        "max_reward": maximum,
        "all_rewards": rewards,
    }
    results_path.write_text(json.dumps(results, indent=2), encoding="utf-8")
    print(f"\nBaseline results — avg: {avg} | min: {minimum} | max: {maximum}")
    print(f"Saved to {results_path}")
    return results
if __name__ == "__main__":
    # Run the baseline only when executed as a script, not on import.
    run_baseline()