Spaces:
Sleeping
Sleeping
File size: 1,794 Bytes
import json
import random
from pathlib import Path
from environment.env import ContextCorruptionEnv
from environment.actions import ContextCorruptionAction, ActionType
# How many episodes a baseline run plays.
NUM_EPISODES = 100

# JSON summary file, placed in the same directory as this script.
RESULTS_PATH = Path(__file__).with_name("baseline_results.json")
def run_baseline(num_episodes=None, *, seed=None, results_path=None) -> dict:
    """Run a random-policy baseline and save its reward statistics as JSON.

    Each episode resets the environment, then repeatedly either flags a
    randomly chosen document or submits the fixed answer ``"unknown"`` until
    the observation reports ``done``. The reward carried by the final
    observation of each episode is recorded.

    Args:
        num_episodes: Number of episodes to play. Defaults to
            ``NUM_EPISODES`` when ``None``.
        seed: Optional seed for this policy's own random draws. When
            ``None`` the shared ``random`` module state is used. Any
            randomness inside the environment is not controlled here.
        results_path: Destination of the JSON summary (``str`` or ``Path``).
            Defaults to ``RESULTS_PATH`` when ``None``.

    Returns:
        A dict with the keys ``num_episodes``, ``avg_reward``,
        ``min_reward``, ``max_reward`` (each rounded to 4 decimal places)
        and ``all_rewards`` (the unrounded per-episode rewards).

    Raises:
        ValueError: If the episode count is less than 1; the summary
            statistics need at least one reward.
    """
    episodes = NUM_EPISODES if num_episodes is None else num_episodes
    if episodes < 1:
        raise ValueError(f"num_episodes must be at least 1, got {episodes}")
    out_path = RESULTS_PATH if results_path is None else Path(results_path)
    # The module and a Random instance expose the same random()/randrange()
    # API, so an unseeded run keeps drawing from the global generator.
    rng = random if seed is None else random.Random(seed)

    env = ContextCorruptionEnv()
    rewards = []
    for ep in range(episodes):
        obs = env.reset()
        done = False
        while not done:
            # With probability 0.4, and only while more than one unit of
            # budget remains, flag a random document; otherwise submit.
            # The emptiness check avoids sampling from an empty range.
            if (
                rng.random() < 0.4
                and obs.budget_remaining > 1
                and len(obs.documents) > 0
            ):
                doc_id = rng.randrange(len(obs.documents))
                action = ContextCorruptionAction(
                    action_type=ActionType.flag_suspicious,
                    doc_id=doc_id,
                )
            else:
                action = ContextCorruptionAction(
                    action_type=ActionType.submit_answer,
                    answer="unknown",
                    confidence=0.5,
                )
            obs = env.step(action)
            done = obs.done
        # NOTE(review): recorded once per episode, from the terminal
        # observation — confirm this matches the intended nesting.
        rewards.append(obs.reward)
        if (ep + 1) % 10 == 0:
            print(f"Episode {ep + 1}/{episodes} | reward: {obs.reward:.4f} | avg so far: {sum(rewards)/len(rewards):.4f}")

    avg = round(sum(rewards) / len(rewards), 4)
    minimum = round(min(rewards), 4)
    maximum = round(max(rewards), 4)
    results = {
        "num_episodes": episodes,
        "avg_reward": avg,
        "min_reward": minimum,
        "max_reward": maximum,
        "all_rewards": rewards,
    }
    out_path.write_text(json.dumps(results, indent=2), encoding="utf-8")
    print(f"\nBaseline results — avg: {avg} | min: {minimum} | max: {maximum}")
    print(f"Saved to {out_path}")
    return results
# Run the baseline only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    run_baseline()
|