import glob
import json
import logging
import os
import sys
import uuid
from contextlib import asynccontextmanager
from datetime import datetime
from functools import lru_cache
from pathlib import Path
from typing import Dict, List, Any, Union

import huggingface_hub
from fastapi import FastAPI, UploadFile, Form, File
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from starlette.middleware.cors import CORSMiddleware

from src.backend.evaluation import compute_tasks_ratings
from src.backend.submit_tools import unzip_predictions_from_zip
from src.dataset.datasets_data import preload_all_datasets
from src.backend.validation_tools import (
    validate_submission_tasks_name,
    validate_submission_json,
    validate_submission_template,
)
from src.task.task import Task
from src.task.task_factory import (
    tasks_factory,
)
BASE_DIR = Path(__file__).resolve().parents[2]
SRC_DIR = BASE_DIR / "src"
sys.path.insert(0, str(SRC_DIR))

RESULTS_DIR = BASE_DIR / "src" / "backend" / "results"
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
FRONTEND_DIR = BASE_DIR / "frontend"

@asynccontextmanager
async def lifespan(application: FastAPI = None):  # pylint: disable=unused-argument
    """Runs before the backend comes online; loads the datasets into memory."""
    try:
        # Log in to the Hugging Face Hub, then preload every dataset.
        token = os.environ.get("HF_TOKEN")
        huggingface_hub.login(token=token)
        preload_all_datasets()
    except Exception as e:
        error_message = f"The datasets could not be loaded: {e}"
        logging.critical(error_message)
    yield
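
# The Hub login in the lifespan above relies on the HF_TOKEN environment
# variable; set it before starting the server (hypothetical token value):
#
#   export HF_TOKEN=hf_xxxxxxxxxxxxxxxxxxxx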

app = FastAPI(lifespan=lifespan)
app.mount("/results", StaticFiles(directory=str(RESULTS_DIR)), name="results")

front_end_info_message = f"The Front-end directory is: {FRONTEND_DIR}"
logging.info(front_end_info_message)

app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# NOTE: the route path "/submit" is an assumption; use whatever path the
# front end actually calls.
@app.post("/submit")
async def submit(
    email: str = Form(...),
    predictions_zip: UploadFile = File(...),
    display_name: str = Form(...),
):
    """Route for making submissions with user-generated results.

    :param email: The email of the user's submission
    :param predictions_zip: The zip file of the user's predictions
    :param display_name: The display name associated with the user's submission
    """
    logging.info("Starting submission")
    info_message = f"Submission from {email!r} as {display_name!r}."
    logging.info(info_message)

    # Unpack and validate the uploaded predictions before scoring them.
    zip_bytes = await predictions_zip.read()
    submission_json = unzip_predictions_from_zip(zip_bytes)
    validate_submission_template(submission_json)
    validate_submission_tasks_name(submission_json)
    validate_submission_json(submission_json)

    tasks: List[Task] = tasks_factory(submission_json)

    logging.info("Computation started")
    start = datetime.now()
    submission_response = compute_tasks_ratings(tasks=tasks, submission=submission_json)
    computation_time = datetime.now() - start
    info_message = f"Computation ended in {computation_time}"
    logging.info(info_message)

    # Persist the scored submission under a fresh id so the leaderboard can find it.
    submission_id = str(uuid.uuid4())
    submission_response.update(
        {
            "display_name": display_name,
            "email": email,
            "submission_id": submission_id,
        }
    )
    out_path = RESULTS_DIR / f"{submission_id}.json"
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(submission_response, f, ensure_ascii=False, indent=2)

    # Invalidate the cached leaderboard so the new entry shows up immediately.
    get_leaderboard_entries.cache_clear()
    return JSONResponse(content=submission_response)
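
# A minimal client sketch for the submission route above, assuming the service
# runs on http://localhost:8000 with the (assumed) "/submit" path;
# "predictions.zip" is a hypothetical local file:
#
#   import requests
#   with open("predictions.zip", "rb") as zf:
#       resp = requests.post(
#           "http://localhost:8000/submit",
#           data={"email": "user@example.com", "display_name": "My Model"},
#           files={"predictions_zip": zf},
#       )
#   print(resp.json()["submission_id"])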

# Cached so the leaderboard is only rebuilt after a new submission clears it.
@lru_cache(maxsize=1)
def get_leaderboard_entries() -> List[Dict[str, Any]]:
    """Returns all entries currently in the leaderboard.

    Also supports JSON files that contain a LIST of entries, and normalizes
    'flat' metrics into nested groups for the front end.
    """

    def _wrap_flat_metrics(task_payload: Dict[str, Any]) -> Dict[str, Any]:
        """
        If task_payload is 'flat' (e.g. {"accuracy": 94.2}), wrap it as
        {"<group>": {...}} so the front end can aggregate it.

        Group-naming rules:
        - exact_match/f1 present -> "fquad"
        - else acc/accuracy present -> "accuracy"
        - else pearson/pearsonr/spearman present -> "correlation"
        - else -> "metrics"

        Values > 1 are left as-is (the front end already normalizes % -> [0, 1]).
        """
        if not isinstance(task_payload, dict):
            return task_payload
        # If it is already nested (some value is a dict), leave it untouched.
        if any(isinstance(v, dict) for v in task_payload.values()):
            return task_payload
        keys = set(k.lower() for k in task_payload.keys())
        if {"exact_match", "f1"} & keys:
            group = "fquad"
        elif {"accuracy", "acc"} & keys:
            group = "accuracy"
        elif {"pearson", "pearsonr", "spearman"} & keys:
            group = "correlation"
        else:
            group = "metrics"
        # Nothing special for warnings here: the front end treats them as
        # optional and expects "<group>_warning" inside the inner object if
        # we want to provide one.
        return {group: task_payload}
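
    # Example behavior of _wrap_flat_metrics (illustrative values):
    #   {"exact_match": 71.0, "f1": 80.5}  ->  {"fquad": {"exact_match": 71.0, "f1": 80.5}}
    #   {"accuracy": 94.2}                 ->  {"accuracy": {"accuracy": 94.2}}
    #   {"fquad": {"f1": 80.5}}            ->  unchanged (already nested)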

    entries: List[Dict[str, Any]] = []
    for filepath in glob.glob(str(RESULTS_DIR / "*.json")):
        try:
            with open(filepath, encoding="utf-8") as f:
                data = json.load(f)

            # Inner helper that converts ONE entry (dict) to the minimal expected format.
            def process_entry(entry: Dict[str, Any]) -> Union[Dict[str, Any], None]:
                if not isinstance(entry, dict):
                    return None
                if "model_name" not in entry or "tasks" not in entry:
                    return None
                # Rebuild "results" in the shape the front end expects.
                results = {}
                for task_obj in entry.get("tasks", []):
                    if not isinstance(task_obj, dict) or len(task_obj) != 1:
                        continue
                    task_name, payload = list(task_obj.items())[0]
                    normalized = _wrap_flat_metrics(payload)
                    results[task_name] = normalized
                if not results:
                    return None
                return {
                    "submission_id": entry.get("submission_id") or str(uuid.uuid4()),
                    "display_name": entry.get("display_name")
                    or entry.get("model_name")
                    or "Unnamed Model",
                    "email": entry.get("email", "N/A"),
                    "results": results,
                }

            # A file may contain ONE entry (dict) or SEVERAL (a list).
            if isinstance(data, list):
                for item in data:
                    processed = process_entry(item)
                    if processed:
                        entries.append(processed)
            else:
                processed = process_entry(data)
                if processed:
                    entries.append(processed)
        except Exception as e:
            logging_message = f"Error processing file '{filepath}': {e}"
            logging.error(logging_message)
            continue
    return entries
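
# Shape of a result file that get_leaderboard_entries accepts (a sketch with
# hypothetical values; a file may also hold a list of such entries):
#
#   {
#       "model_name": "my-model",
#       "display_name": "My Model",
#       "email": "user@example.com",
#       "tasks": [
#           {"fquad": {"exact_match": 71.0, "f1": 80.5}},
#           {"cls": {"accuracy": 94.2}}
#       ]
#   }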

# NOTE: the three route paths below are assumptions.
@app.get("/leaderboard")
async def leaderboard() -> List[Dict[str, Any]]:
    return get_leaderboard_entries()


@app.get("/health")
async def health_check():
    return {"status": "healthy", "message": "API is running."}


@app.get("/")
async def home():
    return {"status": "working"}