File size: 866 Bytes
51d5ddb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
"""Shared evaluation constants and policies for Sakha."""

from sakha.graders import score_easy_task, score_medium_task, score_hard_task
from sakha.models import ActionType, SakhaAction

# Grader callable for each task level, keyed by the level's name.
# The graders themselves are imported from sakha.graders.
TASK_GRADERS = {
    "easy": score_easy_task,
    "medium": score_medium_task,
    "hard": score_hard_task,
}

# Patient count per task level; keys mirror those of the grader table.
PATIENT_COUNTS = {
    "easy": 5,
    "medium": 8,
    "hard": 18,
}


def noop_policy(obs, step, pc):
    """Return a NOOP action, whatever the inputs are.

    Args:
        obs: Ignored.
        step: Ignored.
        pc: Ignored.

    The three parameters are accepted only so this function has the same
    call signature as the other policies in this module.

    Returns:
        A ``SakhaAction`` with ``ActionType.NOOP`` and ``patient_id=None``.
    """
    idle_action = SakhaAction(action_type=ActionType.NOOP, patient_id=None)
    return idle_action


def greedy_policy(obs, step, pc):
    """Always administer medicine, rotating the target patient by step.

    The patient id is ``(step % pc) + 1``, so for a positive integer ``pc``
    the ids cycle through ``1..pc`` as ``step`` increases.

    Args:
        obs: Ignored; the choice depends only on ``step`` and ``pc``.
        step: Step counter used to pick the patient.
        pc: Modulus for the rotation (presumably the patient count --
            confirm against the caller). Must be non-zero, otherwise the
            modulo raises ``ZeroDivisionError``.

    Returns:
        A ``SakhaAction`` with ``ActionType.ADMINISTER_MEDICINE`` for the
        selected patient id.
    """
    medicine = ActionType.ADMINISTER_MEDICINE
    # +1 shifts the 0-based remainder to 1-based patient ids.
    target_patient = step % pc + 1
    return SakhaAction(action_type=medicine, patient_id=target_patient)


def priority_policy(obs, step, pc):
    """Act on the first pending task, or do nothing if none is pending.

    Args:
        obs: Observation exposing ``ward_state.pending_tasks``; each task
            provides ``required_action`` and ``patient_id``.
        step: Ignored.
        pc: Ignored.

    Returns:
        A ``SakhaAction`` built from the first pending task, or a NOOP
        action with ``patient_id=None`` when ``pending_tasks`` is empty.
    """
    pending = obs.ward_state.pending_tasks
    if not pending:
        return SakhaAction(action_type=ActionType.NOOP, patient_id=None)

    # NOTE(review): only the head of the list is consulted; this presumes the
    # list is already ordered by priority -- confirm against the environment.
    head = pending[0]
    return SakhaAction(action_type=head.required_action, patient_id=head.patient_id)