import glob
import json
import logging
import os
import sys
import uuid
from contextlib import asynccontextmanager
from datetime import datetime
from functools import lru_cache
from pathlib import Path
from typing import Dict, List, Any, Union

import huggingface_hub
from fastapi import FastAPI, UploadFile, Form, File
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles

from starlette.middleware.cors import CORSMiddleware

from src.backend.evaluation import compute_tasks_ratings
from src.backend.submit_tools import unzip_predictions_from_zip
from src.dataset.datasets_data import preload_all_datasets
from src.backend.validation_tools import (
    validate_submission_tasks_name,
    validate_submission_json,
    validate_submission_template,
)
from src.task.task import Task
from src.task.task_factory import (
    tasks_factory,
)


BASE_DIR = Path(__file__).resolve().parents[2]
SRC_DIR = BASE_DIR / "src"
sys.path.insert(0, str(SRC_DIR))

RESULTS_DIR = BASE_DIR / "src" / "backend" / "results"
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
FRONTEND_DIR = BASE_DIR / "frontend"


@asynccontextmanager
async def lifespan(application: FastAPI):  # pylint: disable=unused-argument
    """Called before the backend comes online; used to load the datasets into memory."""
    # Preload the evaluation datasets so requests don't pay the loading cost
    try:
        token = os.environ.get("HF_TOKEN")
        huggingface_hub.login(token=token)
        preload_all_datasets()
    except Exception as e:  # pylint: disable=broad-except
        error_message = f"The datasets could not be loaded: {e}"
        logging.critical(error_message)

    yield


app = FastAPI(lifespan=lifespan)
app.mount("/results", StaticFiles(directory=str(RESULTS_DIR)), name="results")
front_end_info_message = f"The Front-end directory is: {FRONTEND_DIR}"
logging.info(front_end_info_message)

app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
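# Note: allowing every origin together with credentials is permissive; convenient
# for a demo, but worth restricting allow_origins to the known front end in production.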


@app.post("/submit")
async def submit(
    email: str = Form(...),
    predictions_zip: UploadFile = File(...),
    display_name: str = Form(...),
):
    """Route for making submissions with user generated results.
    :param email : The email of the user's submission
    :param predictions_zip : The zip file of the user's predictions'
    :param display_name : The display name associated with the user's submission'
    """
    logging.info("Starting submission")
    info_message = f"Submission from {email!r} as {display_name!r}."
    logging.info(info_message)
    zip_bytes = await predictions_zip.read()
    submission_json = unzip_predictions_from_zip(zip_bytes)

    validate_submission_template(submission_json)
    validate_submission_tasks_name(submission_json)
    validate_submission_json(submission_json)

    tasks: List[Task] = tasks_factory(submission_json)
    logging.info("Computation started")
    start = datetime.now()
    submission_response = compute_tasks_ratings(tasks=tasks, submission=submission_json)
    computation_time = datetime.now() - start
    info_message = f"Computation ended in {computation_time}"
    logging.info(info_message)
    submission_id = str(uuid.uuid4())
    submission_response.update(
        {
            "display_name": display_name,
            "email": email,
            "submission_id": submission_id,
        }
    )

    out_path = RESULTS_DIR / f"{submission_id}.json"
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(submission_response, f, ensure_ascii=False, indent=2)

    get_leaderboard_entries.cache_clear()

    return JSONResponse(content=submission_response)


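# Cached so repeated /leaderboard requests don't re-read every results file;
# /submit invalidates it with get_leaderboard_entries.cache_clear() after each write.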
@lru_cache(maxsize=1)
def get_leaderboard_entries() -> List[Dict[str, Any]]:
    """Returns all entries currently in the leaderboard.
    Supporte aussi les fichiers JSON qui contiennent une LISTE d'entrées
    et normalise les métriques 'plates' en groupes imbriqués pour le front.
    """

    def _wrap_flat_metrics(task_payload: Dict[str, Any]) -> Dict[str, Any]:
        """
        Si task_payload est 'plat' (ex: {"accuracy": 94.2}),
        on le transforme en {"<group>": {...}} pour que le front puisse l'agréger.
        Règles de nommage du groupe :
          - présence de exact_match/f1 -> "fquad"
          - sinon présence de acc/accuracy -> "accuracy"
          - sinon présence de pearson/pearsonr/spearman -> "correlation"
          - sinon -> "metrics"
        Les valeurs >1 sont laissées telles quelles (le front normalise déjà % -> [0,1]).
        """
        if not isinstance(task_payload, dict):
            return task_payload

        # If it is already nested (some value is a dict), leave it untouched
        if any(isinstance(v, dict) for v in task_payload.values()):
            return task_payload

        keys = set(k.lower() for k in task_payload.keys())
        if {"exact_match", "f1"} & keys:
            group = "fquad"
        elif {"accuracy", "acc"} & keys:
            group = "accuracy"
        elif {"pearson", "pearsonr", "spearman"} & keys:
            group = "correlation"
        else:
            group = "metrics"

        # Nothing special for warnings here: the front end treats them as optional
        # and expects "<group>_warning" inside the nested object if we want to provide one.
        return {group: task_payload}

    entries: List[Dict[str, Any]] = []

    def process_entry(entry: Dict[str, Any]) -> Union[Dict[str, Any], None]:
        """Normalize ONE entry (dict) into the minimal format the front end expects."""
        if not isinstance(entry, dict):
            return None
        if "model_name" not in entry or "tasks" not in entry:
            return None

        # Rebuild "results" in the shape the front end expects
        results = {}
        for task_obj in entry.get("tasks", []):
            if not isinstance(task_obj, dict) or len(task_obj) != 1:
                continue
            task_name, payload = list(task_obj.items())[0]
            results[task_name] = _wrap_flat_metrics(payload)

        if not results:
            return None

        return {
            "submission_id": entry.get("submission_id") or str(uuid.uuid4()),
            "display_name": entry.get("display_name")
            or entry.get("model_name")
            or "Unnamed Model",
            "email": entry.get("email", "N/A"),
            "results": results,
        }

    for filepath in glob.glob(str(RESULTS_DIR / "*.json")):
        try:
            with open(filepath, encoding="utf-8") as f:
                data = json.load(f)

            # A file may contain ONE entry (dict) or SEVERAL (list)
            if isinstance(data, list):
                for item in data:
                    processed = process_entry(item)
                    if processed:
                        entries.append(processed)
            else:
                processed = process_entry(data)
                if processed:
                    entries.append(processed)

        except Exception as e:  # pylint: disable=broad-except
            logging_message = f"Error processing file '{filepath}': {e}"
            logging.error(logging_message)

    return entries


@app.get("/leaderboard")
async def leaderboard() -> List[Dict[str, Any]]:

    return get_leaderboard_entries()


@app.get("/health")
async def health_check():
    return {"status": "healthy", "message": "API is running."}


@app.get("/")
async def home():
    return {"status": "working"}