File size: 3,376 Bytes
8fa3acc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import logging
from typing import Dict, List

from fastapi import HTTPException

# Task names a submission is allowed to contain, in their canonical order.
tasks_name = """
    allocine fquad gqnli paws_x piaf
    qfrblimp qfrcola sickfr sts22 xnli
""".split()


def validate_submission_template(dictionary: Dict) -> None:
    """Ensure the submission dictionary follows the expected top-level format.

    The submission must provide non-``None`` values for ``model_name``,
    ``model_url`` and ``tasks``. ``tasks`` must be a list whose elements are
    dictionaries holding exactly one entry (task name -> task payload).

    :param dictionary: Dictionary to validate.
    :raises HTTPException: With status code 200 when any check fails; the
        error message is also logged.
    """
    # NOTE(review): errors are reported with HTTP status 200. A 4xx code would
    # be more conventional, but clients may rely on the current behaviour.
    required_fields = (
        ("model_name", "The submission is missing a model name."),
        ("model_url", "The submission is missing a model URL."),
        ("tasks", "The submission is missing a tasks keyword."),
    )
    for field, error in required_fields:
        if dictionary.get(field) is None:
            logging.error(error)
            raise HTTPException(200, error)

    tasks = dictionary["tasks"]
    if not isinstance(tasks, list):
        error = (
            "The tasks keyword value must be a list of dictionaries where the key is the task name "
            "and the value is a dictionary of predictions (in a list format). See our documentation "
            "for a template."
        )
        logging.error(error)
        raise HTTPException(200, error)

    for task in tasks:
        # Later validators read the key(s) of each task dict, so reject
        # non-dict entries and empty dicts here, not only multi-key dicts.
        if not isinstance(task, dict) or len(task) != 1:
            error = (
                "Each task must be a dictionary of one element where the key is "
                "the task name and the value is a dictionary of predictions."
            )
            logging.error(error)
            raise HTTPException(200, error)


def validate_submission_tasks_name(dictionary: Dict, valid_tasks: "List[str] | None" = None) -> None:
    """Check that every task key in the submission is a known task name.

    :param dictionary: Submission dictionary whose ``"tasks"`` entry is an
        iterable of task dictionaries (presumably already checked by
        ``validate_submission_template``).
    :param valid_tasks: Accepted task names. When ``None`` (the default), the
        module-level ``tasks_name`` list is used, looked up at call time.
    :raises HTTPException: With status code 200 on the first unknown task key;
        the error message is also logged.
    """
    if valid_tasks is None:
        valid_tasks = tasks_name
    for task in dictionary.get("tasks"):
        # Check every key instead of only the first one: indexing the first
        # key raised IndexError on an empty dict and ignored any extra keys.
        for key in task.keys():
            if key not in valid_tasks:
                error = f"Unknown key '{key}' in the submission JSON. The expected tasks are: {valid_tasks}."
                logging.error(error)
                raise HTTPException(200, error)


def validate_submission_json(dictionary: Dict) -> None:
    """Validate the per-task payloads of a submission.

    Each entry of ``dictionary["tasks"]`` maps a task name to a payload that
    must be a dictionary containing a ``"predictions"`` key (``"prediction"``
    is also accepted), with no other keys, whose value is a list.

    :param dictionary: Dictionary to validate.
    :raises HTTPException: With status code 200 on the first malformed
        payload; the error message is also logged.
    """
    accepted_keys = ("predictions", "prediction")

    for task in dictionary.get("tasks"):
        for task_name, payload in task.items():
            if not isinstance(payload, dict):
                error = (
                    "The tasks payload must be a dictionary in the format '{'prediction': [<predictions>]}' "
                    "for each task."
                )
                logging.error(error)
                raise HTTPException(200, error)

            missing_key_error = f"The task '{task_name}' payload does not have the expected key: 'predictions'."
            # An empty payload would otherwise pass every per-item check
            # below without providing any predictions.
            if not any(key in payload for key in accepted_keys):
                logging.error(missing_key_error)
                raise HTTPException(200, missing_key_error)

            for key, value in payload.items():
                if key not in accepted_keys:
                    logging.error(missing_key_error)
                    raise HTTPException(200, missing_key_error)
                if not isinstance(value, list):
                    error = (
                        f"The task '{task_name}' predictions payload is not in a list format. "
                        r"The expected format is: '{'prediction': [<predictions>]}'"
                    )
                    logging.error(error)
                    raise HTTPException(200, error)