File size: 1,794 Bytes
5f54992
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a8a0f0
5f54992
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import json
import random
from pathlib import Path

from environment.env import ContextCorruptionEnv
from environment.actions import ContextCorruptionAction, ActionType

# Default number of episodes played by one baseline run.
NUM_EPISODES: int = 100
# The JSON summary is written alongside this script.
RESULTS_PATH: Path = Path(__file__).parent.joinpath("baseline_results.json")


def _choose_action(obs, rng):
    """Return one randomly chosen action for the current observation.

    With probability 0.4 the policy flags a uniformly random document. It only
    does so while ``obs.budget_remaining`` is above 1 and ``obs.documents`` is
    non-empty. Otherwise it submits the placeholder answer ``"unknown"`` with
    confidence 0.5. The number of flags per episode has no fixed cap: it is
    bounded only by the budget check and by when the environment reports done.

    Args:
        obs: Current observation; ``budget_remaining`` and ``documents`` are read.
        rng: Randomness source exposing ``random()`` and ``randint()`` (the
            ``random`` module itself or a ``random.Random`` instance).
    """
    # rng.random() is evaluated first, on every step (the `and` short-circuits
    # left to right), so the draw sequence matches the original policy.
    # `budget_remaining > 1` presumably reserves the last unit of budget for
    # the final answer -- confirm against the environment.
    # The length check avoids randint(0, -1) raising on an empty document list.
    if (
        rng.random() < 0.4
        and obs.budget_remaining > 1
        and len(obs.documents) > 0
    ):
        return ContextCorruptionAction(
            action_type=ActionType.flag_suspicious,
            doc_id=rng.randint(0, len(obs.documents) - 1),
        )
    return ContextCorruptionAction(
        action_type=ActionType.submit_answer,
        answer="unknown",
        confidence=0.5,
    )


def _run_episode(env, rng):
    """Play one episode with the random policy and return the final ``obs.reward``."""
    obs = env.reset()
    done = False
    while not done:
        obs = env.step(_choose_action(obs, rng))
        done = obs.done
    # NOTE(review): assumes the terminal observation carries a numeric reward;
    # run_baseline's formatting and statistics would fail on None -- confirm.
    return obs.reward


def run_baseline(*, num_episodes=None, seed=None, results_path=None):
    """Run the random-policy baseline and save summary statistics as JSON.

    Args:
        num_episodes: Number of episodes to play. Defaults to ``NUM_EPISODES``.
        seed: If given, the policy draws from a private ``random.Random(seed)``,
            so the run is reproducible. If ``None`` (the default), the global
            ``random`` module is used, as before, and ``random.seed`` still
            applies.
        results_path: Where to write the JSON summary (``str`` or ``Path``).
            Defaults to ``RESULTS_PATH``.

    Returns:
        A dict with ``num_episodes``, ``avg_reward``, ``min_reward`` and
        ``max_reward`` (each rounded to 4 decimal places), plus the unrounded
        per-episode ``all_rewards`` list.

    Raises:
        ValueError: If ``num_episodes`` is less than 1, because the statistics
            are undefined for an empty reward list.
    """
    if num_episodes is None:
        num_episodes = NUM_EPISODES
    results_path = RESULTS_PATH if results_path is None else Path(results_path)
    if num_episodes < 1:
        raise ValueError(f"num_episodes must be >= 1, got {num_episodes}")

    # The `random` module exposes random()/randint() just like a Random
    # instance, so either can drive the policy.
    rng = random if seed is None else random.Random(seed)

    env = ContextCorruptionEnv()
    rewards = []

    for ep in range(num_episodes):
        reward = _run_episode(env, rng)
        rewards.append(reward)
        if (ep + 1) % 10 == 0:
            print(f"Episode {ep + 1}/{num_episodes} | reward: {reward:.4f} | avg so far: {sum(rewards)/len(rewards):.4f}")

    avg = round(sum(rewards) / len(rewards), 4)
    minimum = round(min(rewards), 4)
    maximum = round(max(rewards), 4)

    results = {
        "num_episodes": num_episodes,
        "avg_reward": avg,
        "min_reward": minimum,
        "max_reward": maximum,
        "all_rewards": rewards,
    }

    results_path.write_text(json.dumps(results, indent=2), encoding="utf-8")
    print(f"\nBaseline results — avg: {avg} | min: {minimum} | max: {maximum}")
    print(f"Saved to {results_path}")
    return results


if __name__ == "__main__":
    # Script entry point: run the baseline with its default settings; the
    # returned summary dict is not needed here (it is also saved to disk).
    _ = run_baseline()