diff --git a/.env.example b/.env.example new file mode 100644 index 0000000000000000000000000000000000000000..aa96e1820a29689a9163b6a202ac15bf5d81fd8e --- /dev/null +++ b/.env.example @@ -0,0 +1,3 @@ +ENVIRONMENT=development +HF_TOKEN=xxx +HF_HOME=.cache diff --git a/.gitignore b/.gitignore index 46141db48eda31c2c991a0b4e2f0ffdbc38b6283..6e3ad8b0d89930c18ec774d6d43be66f68c225a3 100644 --- a/.gitignore +++ b/.gitignore @@ -1,13 +1,32 @@ -auto_evals/ -venv/ -__pycache__/ +__pycache__ +.cache/ + +# dependencies +frontend/node_modules +/.pnp +.pnp.js + +# testing +/coverage + +# production +/build + +# misc +.DS_Store +.env.local +.env.development.local +.env.test.local +.env.production.local + +npm-debug.log* +yarn-debug.log* +yarn-error.log* + +yarn.lock +package-lock.json + +# Environment variables .env -.ipynb_checkpoints -*ipynb -.vscode/ - -eval-queue/ -eval-results/ -eval-queue-bk/ -eval-results-bk/ -logs/ +.env.* +!.env.example diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml deleted file mode 100644 index 0710dad252bda2ac9fd5b7e4e2e4dc0afeff43cf..0000000000000000000000000000000000000000 --- a/.pre-commit-config.yaml +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -default_language_version: - python: python3 - -ci: - autofix_prs: true - autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions' - autoupdate_schedule: quarterly - -repos: - - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.3.0 - hooks: - - id: check-yaml - - id: check-case-conflict - - id: detect-private-key - - id: check-added-large-files - args: ['--maxkb=1000'] - - id: requirements-txt-fixer - - id: end-of-file-fixer - - id: trailing-whitespace - - - repo: https://github.com/PyCQA/isort - rev: 5.12.0 - hooks: - - id: isort - name: Format imports - - - repo: https://github.com/psf/black - rev: 22.12.0 - hooks: - - id: black - name: Format code - additional_dependencies: ['click==8.0.2'] - - - repo: https://github.com/charliermarsh/ruff-pre-commit - # Ruff version. - rev: 'v0.0.267' - hooks: - - id: ruff diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..0179a990e3776616c71d25cac43206411ed1eb49 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,62 @@ +# Build frontend +FROM node:18 as frontend-build +WORKDIR /app +COPY frontend/package*.json ./ +RUN npm install +COPY frontend/ ./ + +RUN npm run build + +# Build backend +FROM python:3.12-slim +WORKDIR /app + +# Create non-root user +RUN useradd -m -u 1000 user + +# Install poetry +RUN pip install poetry + +# Create and configure cache directory +RUN mkdir -p /app/.cache && \ + chown -R user:user /app + +# Copy and install backend dependencies +COPY backend/pyproject.toml backend/poetry.lock* ./ +RUN poetry config virtualenvs.create false \ + && poetry install --no-interaction --no-ansi --no-root --only main + +# Copy backend code +COPY backend/ . 
+ +# Install Node.js and npm +RUN apt-get update && apt-get install -y \ + curl \ + netcat-openbsd \ + && curl -fsSL https://deb.nodesource.com/setup_18.x | bash - \ + && apt-get install -y nodejs \ + && rm -rf /var/lib/apt/lists/* + +# Copy frontend server and build +COPY --from=frontend-build /app/build ./frontend/build +COPY --from=frontend-build /app/package*.json ./frontend/ +COPY --from=frontend-build /app/server.js ./frontend/ + +# Install frontend production dependencies +WORKDIR /app/frontend +RUN npm install --production +WORKDIR /app + +# Environment variables +ENV HF_HOME=/app/.cache \ + HF_DATASETS_CACHE=/app/.cache \ + INTERNAL_API_PORT=7861 \ + PORT=7860 \ + NODE_ENV=production + +# Note: HF_TOKEN should be provided at runtime, not build time +USER user +EXPOSE 7860 + +# Start both servers with wait-for +CMD ["sh", "-c", "uvicorn app.asgi:app --host 0.0.0.0 --port 7861 & while ! nc -z localhost 7861; do sleep 1; done && cd frontend && npm run serve"] diff --git a/Makefile b/Makefile deleted file mode 100644 index b5685772804c8af4235a8504dc6752bfc9ae5d1d..0000000000000000000000000000000000000000 --- a/Makefile +++ /dev/null @@ -1,13 +0,0 @@ -.PHONY: style format - - -style: - python -m black --line-length 119 . - python -m isort . - ruff check --fix . - - -quality: - python -m black --check --line-length 119 . - python -m isort --check-only . - ruff check . diff --git a/README.md b/README.md index 8a8207655226be38d2bd6da0e847aa5c2f4bd07d..2a8355b2a544a73199f41ae10d4fe752059c8983 100644 --- a/README.md +++ b/README.md @@ -1,48 +1,16 @@ --- title: EEG Finetune Arena -emoji: 🥇 +emoji: 🧠 colorFrom: green colorTo: indigo -sdk: gradio -app_file: app.py +sdk: docker +hf_oauth: true pinned: true license: apache-2.0 -short_description: Duplicate this leaderboard to initialize your own! 
-sdk_version: 5.43.1 +short_description: Comparing EEG models in an open and reproducible way tags: - leaderboard +- eeg +- braindecode +- neuroscience --- - -# Start the configuration - -Most of the variables to change for a default leaderboard are in `src/env.py` (replace the path for your leaderboard) and `src/about.py` (for tasks). - -Results files should have the following format and be stored as json files: -```json -{ - "config": { - "model_dtype": "torch.float16", # or torch.bfloat16 or 8bit or 4bit - "model_name": "path of the model on the hub: org/model", - "model_sha": "revision on the hub", - }, - "results": { - "task_name": { - "metric_name": score, - }, - "task_name2": { - "metric_name": score, - } - } -} -``` - -Request files are created automatically by this tool. - -If you encounter problem on the space, don't hesitate to restart it to remove the create eval-queue, eval-queue-bk, eval-results and eval-results-bk created folder. - -# Code logic for more complex edits - -You'll find -- the main table' columns names and properties in `src/display/utils.py` -- the logic to read all results and request files, then convert them in dataframe lines, in `src/leaderboard/read_evals.py`, and `src/populate.py` -- the logic to allow or filter submissions in `src/submission/submit.py` and `src/submission/check_validity.py` \ No newline at end of file diff --git a/app.py b/app.py deleted file mode 100644 index 9d9ea9f3ec9648615a76ac2b9427089765d3b9ec..0000000000000000000000000000000000000000 --- a/app.py +++ /dev/null @@ -1,204 +0,0 @@ -import gradio as gr -from gradio_leaderboard import Leaderboard, ColumnFilter, SelectColumns -import pandas as pd -from apscheduler.schedulers.background import BackgroundScheduler -from huggingface_hub import snapshot_download - -from src.about import ( - CITATION_BUTTON_LABEL, - CITATION_BUTTON_TEXT, - EVALUATION_QUEUE_TEXT, - INTRODUCTION_TEXT, - LLM_BENCHMARKS_TEXT, - TITLE, -) -from src.display.css_html_js import custom_css 
-from src.display.utils import ( - BENCHMARK_COLS, - COLS, - EVAL_COLS, - EVAL_TYPES, - AutoEvalColumn, - ModelType, - fields, - WeightType, - Precision -) -from src.envs import API, EVAL_REQUESTS_PATH, EVAL_RESULTS_PATH, QUEUE_REPO, REPO_ID, RESULTS_REPO, TOKEN -from src.populate import get_evaluation_queue_df, get_leaderboard_df -from src.submission.submit import add_new_eval - - -def restart_space(): - API.restart_space(repo_id=REPO_ID) - -### Space initialisation -try: - print(EVAL_REQUESTS_PATH) - snapshot_download( - repo_id=QUEUE_REPO, local_dir=EVAL_REQUESTS_PATH, repo_type="dataset", tqdm_class=None, etag_timeout=30, token=TOKEN - ) -except Exception: - restart_space() -try: - print(EVAL_RESULTS_PATH) - snapshot_download( - repo_id=RESULTS_REPO, local_dir=EVAL_RESULTS_PATH, repo_type="dataset", tqdm_class=None, etag_timeout=30, token=TOKEN - ) -except Exception: - restart_space() - - -LEADERBOARD_DF = get_leaderboard_df(EVAL_RESULTS_PATH, EVAL_REQUESTS_PATH, COLS, BENCHMARK_COLS) - -( - finished_eval_queue_df, - running_eval_queue_df, - pending_eval_queue_df, -) = get_evaluation_queue_df(EVAL_REQUESTS_PATH, EVAL_COLS) - -def init_leaderboard(dataframe): - if dataframe is None or dataframe.empty: - raise ValueError("Leaderboard DataFrame is empty or None.") - return Leaderboard( - value=dataframe, - datatype=[c.type for c in fields(AutoEvalColumn)], - select_columns=SelectColumns( - default_selection=[c.name for c in fields(AutoEvalColumn) if c.displayed_by_default], - cant_deselect=[c.name for c in fields(AutoEvalColumn) if c.never_hidden], - label="Select Columns to Display:", - ), - search_columns=[AutoEvalColumn.model.name, AutoEvalColumn.license.name], - hide_columns=[c.name for c in fields(AutoEvalColumn) if c.hidden], - filter_columns=[ - ColumnFilter(AutoEvalColumn.model_type.name, type="checkboxgroup", label="Model types"), - ColumnFilter(AutoEvalColumn.precision.name, type="checkboxgroup", label="Precision"), - ColumnFilter( - 
AutoEvalColumn.params.name, - type="slider", - min=0.01, - max=150, - label="Select the number of parameters (B)", - ), - ColumnFilter( - AutoEvalColumn.still_on_hub.name, type="boolean", label="Deleted/incomplete", default=True - ), - ], - bool_checkboxgroup_label="Hide models", - interactive=False, - ) - - -demo = gr.Blocks(css=custom_css) -with demo: - gr.HTML(TITLE) - gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text") - - with gr.Tabs(elem_classes="tab-buttons") as tabs: - with gr.TabItem("🏅 LLM Benchmark", elem_id="llm-benchmark-tab-table", id=0): - leaderboard = init_leaderboard(LEADERBOARD_DF) - - with gr.TabItem("📝 About", elem_id="llm-benchmark-tab-table", id=2): - gr.Markdown(LLM_BENCHMARKS_TEXT, elem_classes="markdown-text") - - with gr.TabItem("🚀 Submit here! ", elem_id="llm-benchmark-tab-table", id=3): - with gr.Column(): - with gr.Row(): - gr.Markdown(EVALUATION_QUEUE_TEXT, elem_classes="markdown-text") - - with gr.Column(): - with gr.Accordion( - f"✅ Finished Evaluations ({len(finished_eval_queue_df)})", - open=False, - ): - with gr.Row(): - finished_eval_table = gr.components.Dataframe( - value=finished_eval_queue_df, - headers=EVAL_COLS, - datatype=EVAL_TYPES, - row_count=5, - ) - with gr.Accordion( - f"🔄 Running Evaluation Queue ({len(running_eval_queue_df)})", - open=False, - ): - with gr.Row(): - running_eval_table = gr.components.Dataframe( - value=running_eval_queue_df, - headers=EVAL_COLS, - datatype=EVAL_TYPES, - row_count=5, - ) - - with gr.Accordion( - f"⏳ Pending Evaluation Queue ({len(pending_eval_queue_df)})", - open=False, - ): - with gr.Row(): - pending_eval_table = gr.components.Dataframe( - value=pending_eval_queue_df, - headers=EVAL_COLS, - datatype=EVAL_TYPES, - row_count=5, - ) - with gr.Row(): - gr.Markdown("# ✉️✨ Submit your model here!", elem_classes="markdown-text") - - with gr.Row(): - with gr.Column(): - model_name_textbox = gr.Textbox(label="Model name") - revision_name_textbox = gr.Textbox(label="Revision 
commit", placeholder="main") - model_type = gr.Dropdown( - choices=[t.to_str(" : ") for t in ModelType if t != ModelType.Unknown], - label="Model type", - multiselect=False, - value=None, - interactive=True, - ) - - with gr.Column(): - precision = gr.Dropdown( - choices=[i.value.name for i in Precision if i != Precision.Unknown], - label="Precision", - multiselect=False, - value="float16", - interactive=True, - ) - weight_type = gr.Dropdown( - choices=[i.value.name for i in WeightType], - label="Weights type", - multiselect=False, - value="Original", - interactive=True, - ) - base_model_name_textbox = gr.Textbox(label="Base model (for delta or adapter weights)") - - submit_button = gr.Button("Submit Eval") - submission_result = gr.Markdown() - submit_button.click( - add_new_eval, - [ - model_name_textbox, - base_model_name_textbox, - revision_name_textbox, - precision, - weight_type, - model_type, - ], - submission_result, - ) - - with gr.Row(): - with gr.Accordion("📙 Citation", open=False): - citation_button = gr.Textbox( - value=CITATION_BUTTON_TEXT, - label=CITATION_BUTTON_LABEL, - lines=20, - elem_id="citation-button", - show_copy_button=True, - ) - -scheduler = BackgroundScheduler() -scheduler.add_job(restart_space, "interval", seconds=1800) -scheduler.start() -demo.queue(default_concurrency_limit=40).launch() \ No newline at end of file diff --git a/backend/Dockerfile.dev b/backend/Dockerfile.dev new file mode 100644 index 0000000000000000000000000000000000000000..f802c87f0d5d730c559b1f21ed715b48cc9ca42a --- /dev/null +++ b/backend/Dockerfile.dev @@ -0,0 +1,25 @@ +FROM python:3.12-slim + +WORKDIR /app + +# Install required system dependencies +RUN apt-get update && apt-get install -y \ + build-essential \ + && rm -rf /var/lib/apt/lists/* + +# Install poetry +RUN pip install poetry + +# Copy Poetry configuration files +COPY pyproject.toml poetry.lock* ./ + +# Install dependencies +RUN poetry config virtualenvs.create false && \ + poetry install 
--no-interaction --no-ansi --no-root + +# Environment variables configuration for logs +ENV PYTHONUNBUFFERED=1 +ENV LOG_LEVEL=INFO + +# In dev, mount volume directly +CMD ["uvicorn", "app.asgi:app", "--host", "0.0.0.0", "--port", "7860", "--reload", "--log-level", "warning", "--no-access-log"] \ No newline at end of file diff --git a/backend/app/api/__init__.py b/backend/app/api/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..41bd81293794127ec484666c9a9bf3b2cd0bbe3c --- /dev/null +++ b/backend/app/api/__init__.py @@ -0,0 +1,5 @@ +""" +API package initialization +""" + +__all__ = ["endpoints"] diff --git a/backend/app/api/dependencies.py b/backend/app/api/dependencies.py new file mode 100644 index 0000000000000000000000000000000000000000..7222423860513bbad0e167f0f012126feaecb463 --- /dev/null +++ b/backend/app/api/dependencies.py @@ -0,0 +1,34 @@ +from fastapi import Depends, HTTPException +import logging +from app.services.models import ModelService +from app.services.votes import VoteService +from app.core.formatting import LogFormatter + +logger = logging.getLogger(__name__) + +model_service = ModelService() +vote_service = VoteService() + +async def get_model_service() -> ModelService: + """Dependency to get ModelService instance""" + try: + logger.info(LogFormatter.info("Initializing model service dependency")) + await model_service.initialize() + logger.info(LogFormatter.success("Model service initialized")) + return model_service + except Exception as e: + error_msg = "Failed to initialize model service" + logger.error(LogFormatter.error(error_msg, e)) + raise HTTPException(status_code=500, detail=str(e)) + +async def get_vote_service() -> VoteService: + """Dependency to get VoteService instance""" + try: + logger.info(LogFormatter.info("Initializing vote service dependency")) + await vote_service.initialize() + logger.info(LogFormatter.success("Vote service initialized")) + return vote_service + except Exception as e: + 
error_msg = "Failed to initialize vote service" + logger.error(LogFormatter.error(error_msg, e)) + raise HTTPException(status_code=500, detail=str(e)) \ No newline at end of file diff --git a/backend/app/api/endpoints/leaderboard.py b/backend/app/api/endpoints/leaderboard.py new file mode 100644 index 0000000000000000000000000000000000000000..b92e4cce98c1a43d6c8eb561c53db7cc5ae7e302 --- /dev/null +++ b/backend/app/api/endpoints/leaderboard.py @@ -0,0 +1,49 @@ +from fastapi import APIRouter +from typing import List, Dict, Any +from app.services.leaderboard import LeaderboardService +from app.core.fastapi_cache import cached, build_cache_key +import logging +from app.core.formatting import LogFormatter + +logger = logging.getLogger(__name__) +router = APIRouter() +leaderboard_service = LeaderboardService() + +def leaderboard_key_builder(func, namespace: str = "leaderboard", **kwargs): + """Build cache key for leaderboard data""" + key_type = "raw" if func.__name__ == "get_leaderboard" else "formatted" + key = build_cache_key(namespace, key_type) + logger.debug(LogFormatter.info(f"Built leaderboard cache key: {key}")) + return key + +@router.get("") +@cached(expire=300, key_builder=leaderboard_key_builder) +async def get_leaderboard() -> List[Dict[str, Any]]: + """ + Get raw leaderboard data + Response will be automatically GZIP compressed if size > 500 bytes + """ + try: + logger.info(LogFormatter.info("Fetching raw leaderboard data")) + data = await leaderboard_service.fetch_raw_data() + logger.info(LogFormatter.success(f"Retrieved {len(data)} leaderboard entries")) + return data + except Exception as e: + logger.error(LogFormatter.error("Failed to fetch raw leaderboard data", e)) + raise + +@router.get("/formatted") +@cached(expire=300, key_builder=leaderboard_key_builder) +async def get_formatted_leaderboard() -> List[Dict[str, Any]]: + """ + Get formatted leaderboard data with restructured objects + Response will be automatically GZIP compressed if size > 500 
bytes + """ + try: + logger.info(LogFormatter.info("Fetching formatted leaderboard data")) + data = await leaderboard_service.get_formatted_data() + logger.info(LogFormatter.success(f"Retrieved {len(data)} formatted entries")) + return data + except Exception as e: + logger.error(LogFormatter.error("Failed to fetch formatted leaderboard data", e)) + raise \ No newline at end of file diff --git a/backend/app/api/endpoints/models.py b/backend/app/api/endpoints/models.py new file mode 100644 index 0000000000000000000000000000000000000000..f1872764bb8bbb37bce827e1c6fa50eda7cc4f8c --- /dev/null +++ b/backend/app/api/endpoints/models.py @@ -0,0 +1,116 @@ +from fastapi import APIRouter, HTTPException, Depends, Query +from typing import Dict, Any, List +import logging +from app.services.models import ModelService +from app.api.dependencies import get_model_service +from app.core.fastapi_cache import cached +from app.core.formatting import LogFormatter + +logger = logging.getLogger(__name__) +router = APIRouter(tags=["models"]) + +@router.get("/status") +@cached(expire=300) +async def get_models_status( + model_service: ModelService = Depends(get_model_service) +) -> Dict[str, List[Dict[str, Any]]]: + """Get all models grouped by status""" + try: + logger.info(LogFormatter.info("Fetching status for all models")) + result = await model_service.get_models() + stats = { + status: len(models) for status, models in result.items() + } + for line in LogFormatter.stats(stats, "Models by Status"): + logger.info(line) + return result + except Exception as e: + logger.error(LogFormatter.error("Failed to get models status", e)) + raise HTTPException(status_code=500, detail=str(e)) + +@router.get("/pending") +@cached(expire=60) +async def get_pending_models( + model_service: ModelService = Depends(get_model_service) +) -> List[Dict[str, Any]]: + """Get all models waiting for evaluation""" + try: + logger.info(LogFormatter.info("Fetching pending models")) + models = await 
model_service.get_models() + pending = models.get("pending", []) + logger.info(LogFormatter.success(f"Found {len(pending)} pending models")) + return pending + except Exception as e: + logger.error(LogFormatter.error("Failed to get pending models", e)) + raise HTTPException(status_code=500, detail=str(e)) + +@router.post("/submit") +async def submit_model( + model_data: Dict[str, Any], + model_service: ModelService = Depends(get_model_service) +) -> Dict[str, Any]: + try: + logger.info(LogFormatter.section("MODEL SUBMISSION")) + + user_id = model_data.pop('user_id', None) + if not user_id: + error_msg = "user_id is required" + logger.error(LogFormatter.error("Validation failed", error_msg)) + raise ValueError(error_msg) + + # Log submission details + submission_info = { + "Model_ID": model_data.get("model_id"), + "User": user_id, + "Base_Model": model_data.get("base_model"), + "Precision": model_data.get("precision"), + "Model_Type": model_data.get("model_type") + } + for line in LogFormatter.tree(submission_info, "Submission Details"): + logger.info(line) + + result = await model_service.submit_model(model_data, user_id) + logger.info(LogFormatter.success("Model submitted successfully")) + return result + + except ValueError as e: + logger.error(LogFormatter.error("Invalid submission data", e)) + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + logger.error(LogFormatter.error("Submission failed", e)) + raise HTTPException(status_code=500, detail=str(e)) + +@router.get("/organization/{organization}/submissions") +async def get_organization_submissions( + organization: str, + days: int = Query(default=7, ge=1, le=30), + model_service: ModelService = Depends(get_model_service) +) -> List[Dict[str, Any]]: + """Get all submissions from an organization in the last n days""" + try: + submissions = await model_service.get_organization_submissions(organization, days) + return submissions + except Exception as e: + raise 
HTTPException(status_code=500, detail=str(e)) + +@router.get("/{model_id}/status") +async def get_model_status( + model_id: str, + model_service: ModelService = Depends(get_model_service) +) -> Dict[str, Any]: + try: + logger.info(LogFormatter.info(f"Checking status for model: {model_id}")) + status = await model_service.get_model_status(model_id) + + if status["status"] != "not_found": + logger.info(LogFormatter.success("Status found")) + for line in LogFormatter.tree(status, "Model Status"): + logger.info(line) + else: + logger.warning(LogFormatter.warning(f"No status found for model: {model_id}")) + + return status + + except Exception as e: + logger.error(LogFormatter.error("Failed to get model status", e)) + raise HTTPException(status_code=500, detail=str(e)) diff --git a/backend/app/api/endpoints/votes.py b/backend/app/api/endpoints/votes.py new file mode 100644 index 0000000000000000000000000000000000000000..bd97cee524c85792b610e3d1d3751af4d61484c7 --- /dev/null +++ b/backend/app/api/endpoints/votes.py @@ -0,0 +1,126 @@ +from fastapi import APIRouter, HTTPException, Query, Depends, Response +from typing import Dict, Any, List +from app.services.votes import VoteService +from app.core.fastapi_cache import cached, build_cache_key, invalidate_cache_key +import logging +from app.core.formatting import LogFormatter +from datetime import datetime, timezone + +logger = logging.getLogger(__name__) +router = APIRouter() +vote_service = VoteService() + +CACHE_TTL = 30 # 30 seconds cache + +def model_votes_key_builder(func, namespace: str = "model_votes", **kwargs): + """Build cache key for model votes""" + provider = kwargs.get('provider') + model = kwargs.get('model') + key = build_cache_key(namespace, provider, model) + logger.debug(LogFormatter.info(f"Built model votes cache key: {key}")) + return key + +def user_votes_key_builder(func, namespace: str = "user_votes", **kwargs): + """Build cache key for user votes""" + user_id = kwargs.get('user_id') + key = 
build_cache_key(namespace, user_id) + logger.debug(LogFormatter.info(f"Built user votes cache key: {key}")) + return key + +@router.post("/{model_id:path}") +async def add_vote( + response: Response, + model_id: str, + vote_type: str = Query(..., description="Type of vote (up/down)"), + user_id: str = Query(..., description="HuggingFace username"), + vote_data: Dict[str, Any] = None +) -> Dict[str, Any]: + try: + logger.info(LogFormatter.section("ADDING VOTE")) + stats = { + "Model": model_id, + "User": user_id, + "Type": vote_type, + "Config": vote_data or {} + } + for line in LogFormatter.tree(stats, "Vote Details"): + logger.info(line) + + await vote_service.initialize() + result = await vote_service.add_vote(model_id, user_id, vote_type, vote_data) + + # Invalidate affected caches + try: + logger.info(LogFormatter.subsection("CACHE INVALIDATION")) + provider, model = model_id.split('/', 1) + + # Build and invalidate cache keys + model_cache_key = build_cache_key("model_votes", provider, model) + user_cache_key = build_cache_key("user_votes", user_id) + + await invalidate_cache_key(model_cache_key) + await invalidate_cache_key(user_cache_key) + + cache_stats = { + "Model_Cache": model_cache_key, + "User_Cache": user_cache_key + } + for line in LogFormatter.tree(cache_stats, "Invalidated Caches"): + logger.info(line) + + except Exception as e: + logger.error(LogFormatter.error("Failed to invalidate cache", e)) + + # Add cache control headers + response.headers["Cache-Control"] = "no-cache" + + return result + except Exception as e: + logger.error(LogFormatter.error("Failed to add vote", e)) + raise HTTPException(status_code=400, detail=str(e)) + +@router.get("/model/{provider}/{model}") +@cached(expire=CACHE_TTL, key_builder=model_votes_key_builder) +async def get_model_votes( + response: Response, + provider: str, + model: str +) -> Dict[str, Any]: + """Get all votes for a specific model""" + try: + logger.info(LogFormatter.info(f"Fetching votes for model: 
{provider}/{model}")) + await vote_service.initialize() + model_id = f"{provider}/{model}" + result = await vote_service.get_model_votes(model_id) + + # Add cache control headers + response.headers["Cache-Control"] = f"max-age={CACHE_TTL}" + response.headers["Last-Modified"] = vote_service._last_sync.strftime("%a, %d %b %Y %H:%M:%S GMT") + + logger.info(LogFormatter.success(f"Found {result.get('total_votes', 0)} votes")) + return result + except Exception as e: + logger.error(LogFormatter.error("Failed to get model votes", e)) + raise HTTPException(status_code=400, detail=str(e)) + +@router.get("/user/{user_id}") +@cached(expire=CACHE_TTL, key_builder=user_votes_key_builder) +async def get_user_votes( + response: Response, + user_id: str +) -> List[Dict[str, Any]]: + """Get all votes from a specific user""" + try: + logger.info(LogFormatter.info(f"Fetching votes for user: {user_id}")) + await vote_service.initialize() + votes = await vote_service.get_user_votes(user_id) + + # Add cache control headers + response.headers["Cache-Control"] = f"max-age={CACHE_TTL}" + response.headers["Last-Modified"] = vote_service._last_sync.strftime("%a, %d %b %Y %H:%M:%S GMT") + + logger.info(LogFormatter.success(f"Found {len(votes)} votes")) + return votes + except Exception as e: + logger.error(LogFormatter.error("Failed to get user votes", e)) + raise HTTPException(status_code=400, detail=str(e)) \ No newline at end of file diff --git a/backend/app/api/router.py b/backend/app/api/router.py new file mode 100644 index 0000000000000000000000000000000000000000..4a07a694bb47359eecf023faf1c69b7b162e99f7 --- /dev/null +++ b/backend/app/api/router.py @@ -0,0 +1,9 @@ +from fastapi import APIRouter + +from app.api.endpoints import leaderboard, votes, models + +router = APIRouter() + +router.include_router(leaderboard.router, prefix="/leaderboard", tags=["leaderboard"]) +router.include_router(votes.router, prefix="/votes", tags=["votes"]) +router.include_router(models.router, 
prefix="/models", tags=["models"]) \ No newline at end of file diff --git a/backend/app/asgi.py b/backend/app/asgi.py new file mode 100644 index 0000000000000000000000000000000000000000..7b2b2dce0986eea205c6a547cad1af7f8c0958a7 --- /dev/null +++ b/backend/app/asgi.py @@ -0,0 +1,105 @@ +""" +ASGI entry point for the EEG Finetune Arena API. +""" +import os +import uvicorn +import logging +import logging.config +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from fastapi.middleware.gzip import GZipMiddleware +import sys + +from app.api.router import router +from app.core.fastapi_cache import setup_cache +from app.core.formatting import LogFormatter +from app.config import hf_config + +# Configure logging before anything else +LOGGING_CONFIG = { + "version": 1, + "disable_existing_loggers": True, + "formatters": { + "default": { + "format": "%(name)s - %(levelname)s - %(message)s", + } + }, + "handlers": { + "default": { + "formatter": "default", + "class": "logging.StreamHandler", + "stream": "ext://sys.stdout", + } + }, + "loggers": { + "uvicorn": { + "handlers": ["default"], + "level": "WARNING", + "propagate": False, + }, + "uvicorn.error": { + "level": "WARNING", + "handlers": ["default"], + "propagate": False, + }, + "uvicorn.access": { + "handlers": ["default"], + "level": "WARNING", + "propagate": False, + }, + "app": { + "handlers": ["default"], + "level": "WARNING", + "propagate": False, + } + }, + "root": { + "handlers": ["default"], + "level": "WARNING", + } +} + +# Apply logging configuration +logging.config.dictConfig(LOGGING_CONFIG) +logger = logging.getLogger("app") + +# Create FastAPI application +app = FastAPI( + title="EEG Finetune Arena", + version="1.0.0", + docs_url="/docs", +) + +# Add CORS middleware +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Add GZIP compression +app.add_middleware(GZipMiddleware, 
minimum_size=500) + +# Include API router +app.include_router(router, prefix="/api") + +@app.on_event("startup") +async def startup_event(): + """Initialize services on startup""" + logger.info("\n") + logger.info(LogFormatter.section("APPLICATION STARTUP")) + + # Log HF configuration + logger.info(LogFormatter.section("HUGGING FACE CONFIGURATION")) + logger.info(LogFormatter.info(f"Organization: {hf_config.HF_ORGANIZATION}")) + logger.info(LogFormatter.info(f"Token Status: {'Present' if hf_config.HF_TOKEN else 'Missing'}")) + logger.info(LogFormatter.info(f"Using repositories:")) + logger.info(LogFormatter.info(f" - Queue: {hf_config.QUEUE_REPO}")) + logger.info(LogFormatter.info(f" - Aggregated: {hf_config.AGGREGATED_REPO}")) + logger.info(LogFormatter.info(f" - Votes: {hf_config.VOTES_REPO}")) + + # Setup cache + setup_cache() + logger.info(LogFormatter.success("FastAPI Cache initialized with in-memory backend")) diff --git a/backend/app/config/__init__.py b/backend/app/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f451c4b1a4ed15c6c08710efc09d980c42d86d93 --- /dev/null +++ b/backend/app/config/__init__.py @@ -0,0 +1,6 @@ +""" +Configuration module for the EEG Finetune Arena backend. +All configuration values are imported from base.py to avoid circular dependencies. 
+""" + +from .base import * diff --git a/backend/app/config/base.py b/backend/app/config/base.py new file mode 100644 index 0000000000000000000000000000000000000000..d541f7bdfcd7a2dc6c995f035e74bb52a56bff9a --- /dev/null +++ b/backend/app/config/base.py @@ -0,0 +1,38 @@ +import os +from pathlib import Path + +# Server configuration +HOST = "0.0.0.0" +PORT = 7860 +WORKERS = 4 +RELOAD = True if os.environ.get("ENVIRONMENT") == "development" else False + +# CORS configuration +ORIGINS = ["http://localhost:3000"] if os.getenv("ENVIRONMENT") == "development" else ["*"] + +# Cache configuration +CACHE_TTL = int(os.environ.get("CACHE_TTL", 300)) # 5 minutes default + +# Rate limiting +RATE_LIMIT_PERIOD = 7 # days +RATE_LIMIT_QUOTA = 5 +HAS_HIGHER_RATE_LIMIT = [] + +# HuggingFace configuration +HF_TOKEN = os.environ.get("HF_TOKEN") +HF_ORGANIZATION = "braindecode" +API = { + "INFERENCE": "https://api-inference.huggingface.co/models", + "HUB": "https://huggingface.co" +} + +# Cache paths +CACHE_ROOT = Path(os.environ.get("HF_HOME", ".cache")) +DATASETS_CACHE = CACHE_ROOT / "datasets" +MODELS_CACHE = CACHE_ROOT / "models" +VOTES_CACHE = CACHE_ROOT / "votes" +EVAL_CACHE = CACHE_ROOT / "eval-queue" + +# Repository configuration +QUEUE_REPO = f"{HF_ORGANIZATION}/requests" +EVAL_REQUESTS_PATH = EVAL_CACHE / "eval_requests.jsonl" diff --git a/backend/app/config/hf_config.py b/backend/app/config/hf_config.py new file mode 100644 index 0000000000000000000000000000000000000000..898e3131dff4bc4f9c0a15c7faed8ffc46a152f7 --- /dev/null +++ b/backend/app/config/hf_config.py @@ -0,0 +1,29 @@ +import os +import logging +from typing import Optional +from huggingface_hub import HfApi +from pathlib import Path +from app.core.cache import cache_config + +logger = logging.getLogger(__name__) + +# Organization or user who owns the datasets +HF_ORGANIZATION = "braindecode" + +# Get HF token directly from environment +HF_TOKEN = os.environ.get("HF_TOKEN") +if not HF_TOKEN: + 
logger.warning("HF_TOKEN not found in environment variables. Some features may be limited.") + +# Initialize HF API +API = HfApi(token=HF_TOKEN) + +# Repository configuration +QUEUE_REPO = f"{HF_ORGANIZATION}/requests" +AGGREGATED_REPO = f"{HF_ORGANIZATION}/contents" +VOTES_REPO = f"{HF_ORGANIZATION}/votes" + +# File paths from cache config +VOTES_PATH = cache_config.votes_file +EVAL_REQUESTS_PATH = cache_config.eval_requests_file +MODEL_CACHE_DIR = cache_config.models_cache diff --git a/backend/app/config/logging_config.py b/backend/app/config/logging_config.py new file mode 100644 index 0000000000000000000000000000000000000000..ba6ad395ba9d0e181640a9c062772e6e84965330 --- /dev/null +++ b/backend/app/config/logging_config.py @@ -0,0 +1,38 @@ +import logging +import sys +from tqdm import tqdm + +def get_tqdm_handler(): + """ + Creates a special handler for tqdm that doesn't interfere with other logs. + """ + class TqdmLoggingHandler(logging.Handler): + def emit(self, record): + try: + msg = self.format(record) + tqdm.write(msg) + self.flush() + except Exception: + self.handleError(record) + + return TqdmLoggingHandler() + +def setup_service_logger(service_name: str) -> logging.Logger: + """ + Configure a specific logger for a given service. 
+ """ + logger = logging.getLogger(f"app.services.{service_name}") + + # If the logger already has handlers, don't reconfigure it + if logger.handlers: + return logger + + # Add tqdm handler for this service + tqdm_handler = get_tqdm_handler() + tqdm_handler.setFormatter(logging.Formatter('%(name)s - %(levelname)s - %(message)s')) + logger.addHandler(tqdm_handler) + + # Don't propagate logs to parent loggers + logger.propagate = False + + return logger \ No newline at end of file diff --git a/backend/app/core/cache.py b/backend/app/core/cache.py new file mode 100644 index 0000000000000000000000000000000000000000..8bdec85cc78cbebe1bc40ce48bb323b0cc001ee7 --- /dev/null +++ b/backend/app/core/cache.py @@ -0,0 +1,109 @@ +import os +import shutil +from pathlib import Path +from datetime import timedelta +import logging +from app.core.formatting import LogFormatter +from app.config.base import ( + CACHE_ROOT, + DATASETS_CACHE, + MODELS_CACHE, + VOTES_CACHE, + EVAL_CACHE, + CACHE_TTL +) + +logger = logging.getLogger(__name__) + +class CacheConfig: + def __init__(self): + # Get cache paths from config + self.cache_root = CACHE_ROOT + self.datasets_cache = DATASETS_CACHE + self.models_cache = MODELS_CACHE + self.votes_cache = VOTES_CACHE + self.eval_cache = EVAL_CACHE + + # Specific files + self.votes_file = self.votes_cache / "votes_data.jsonl" + self.eval_requests_file = self.eval_cache / "eval_requests.jsonl" + + # Cache TTL + self.cache_ttl = timedelta(seconds=CACHE_TTL) + + self._initialize_cache_dirs() + self._setup_environment() + + def _initialize_cache_dirs(self): + """Initialize all necessary cache directories""" + try: + logger.info(LogFormatter.section("CACHE INITIALIZATION")) + + cache_dirs = { + "Root": self.cache_root, + "Datasets": self.datasets_cache, + "Models": self.models_cache, + "Votes": self.votes_cache, + "Eval": self.eval_cache + } + + for name, cache_dir in cache_dirs.items(): + cache_dir.mkdir(parents=True, exist_ok=True) + 
logger.info(LogFormatter.success(f"{name} cache directory: {cache_dir}")) + + except Exception as e: + logger.error(LogFormatter.error("Failed to create cache directories", e)) + raise + + def _setup_environment(self): + """Configure HuggingFace environment variables""" + logger.info(LogFormatter.subsection("ENVIRONMENT SETUP")) + + env_vars = { + "HF_HOME": str(self.cache_root), + "HF_DATASETS_CACHE": str(self.datasets_cache) + } + + for var, value in env_vars.items(): + os.environ[var] = value + logger.info(LogFormatter.info(f"Set {var}={value}")) + + + def get_cache_path(self, cache_type: str) -> Path: + """Returns the path for a specific cache type""" + cache_paths = { + "datasets": self.datasets_cache, + "models": self.models_cache, + "votes": self.votes_cache, + "eval": self.eval_cache + } + return cache_paths.get(cache_type, self.cache_root) + + def flush_cache(self, cache_type: str = None): + """Flush specified cache or all caches if no type is specified""" + try: + if cache_type: + logger.info(LogFormatter.section(f"FLUSHING {cache_type.upper()} CACHE")) + cache_dir = self.get_cache_path(cache_type) + if cache_dir.exists(): + stats = { + "Cache_Type": cache_type, + "Directory": str(cache_dir) + } + for line in LogFormatter.tree(stats, "Cache Details"): + logger.info(line) + shutil.rmtree(cache_dir) + cache_dir.mkdir(parents=True, exist_ok=True) + logger.info(LogFormatter.success("Cache cleared successfully")) + else: + logger.info(LogFormatter.section("FLUSHING ALL CACHES")) + for cache_type in ["datasets", "models", "votes", "eval"]: + self.flush_cache(cache_type) + logger.info(LogFormatter.success("All caches cleared successfully")) + + except Exception as e: + logger.error(LogFormatter.error("Failed to flush cache", e)) + raise + +# Singleton instance of cache configuration +cache_config = CacheConfig() \ No newline at end of file diff --git a/backend/app/core/fastapi_cache.py b/backend/app/core/fastapi_cache.py new file mode 100644 index 
class CustomInMemoryBackend(InMemoryBackend):
    """In-memory cache backend with per-key TTL support.

    Fix over the previous version: ``set`` used to ignore its ``expire``
    argument and ``get`` never evicted anything, so cached responses
    lived forever regardless of the configured CACHE_TTL. Entries are now
    stored as ``(value, deadline)`` pairs and lazily evicted on read.
    """

    def __init__(self):
        """Initialize the cache backend"""
        super().__init__()
        # key -> (value, absolute monotonic deadline, or None for no expiry)
        self.cache = {}

    async def delete(self, key: str) -> bool:
        """Delete a key from the cache. Returns True if the key was present."""
        return self.cache.pop(key, None) is not None

    async def get(self, key: str) -> Any:
        """Get a value from the cache, evicting it first if it has expired."""
        import time
        entry = self.cache.get(key)
        if entry is None:
            return None
        value, deadline = entry
        if deadline is not None and time.monotonic() >= deadline:
            # Lazy eviction: drop expired entries on access.
            self.cache.pop(key, None)
            return None
        return value

    async def set(self, key: str, value: Any, expire: Optional[int] = None) -> None:
        """Set a value in the cache, honouring the optional TTL in seconds."""
        import time
        deadline = time.monotonic() + expire if expire else None
        self.cache[key] = (value, deadline)
class LogFormatter:
    """Utility class for consistent log formatting across the application"""

    @staticmethod
    def section(title: str) -> str:
        """Create a section header"""
        return f"\n{'='*20} {title.upper()} {'='*20}"

    @staticmethod
    def subsection(title: str) -> str:
        """Create a subsection header"""
        return f"\n{'─'*20} {title} {'─'*20}"

    @staticmethod
    def tree(items: Dict[str, Any], title: str = None) -> List[str]:
        """Create a tree view of dictionary data.

        Fix: an empty ``items`` dict no longer raises ValueError
        (``max()`` on an empty sequence); it simply yields the title
        line, if any.
        """
        lines = []
        if title:
            lines.append(f"📊 {title}:")

        # Get the maximum key length for alignment (0 when items is empty)
        max_key_length = max((len(str(k)) for k in items.keys()), default=0)

        # Format each item
        for i, (key, value) in enumerate(items.items()):
            prefix = "└──" if i == len(items) - 1 else "├──"
            if isinstance(value, (int, float)):
                value = f"{value:,}"  # Add thousand separators
            lines.append(f"{prefix} {str(key):<{max_key_length}}: {value}")

        return lines

    @staticmethod
    def stats(stats: Dict[str, int], title: str = None) -> List[str]:
        """Format statistics with icons.

        Fix: an empty ``stats`` dict no longer raises ValueError.
        """
        lines = []
        if title:
            lines.append(f"📊 {title}:")

        # Get the maximum key length for alignment (0 when stats is empty)
        max_key_length = max((len(str(k)) for k in stats.keys()), default=0)

        # Icon per stat family, keyed on the first '_'-separated word
        icons = {
            "total": "📌",
            "success": "✅",
            "error": "❌",
            "pending": "⏳",
            "processing": "⚙️",
            "finished": "✨",
            "evaluating": "🔄",
            "downloads": "⬇️",
            "files": "📁",
            "cached": "💾",
            "size": "📏",
            "time": "⏱️",
            "rate": "🚀"
        }

        # Format each item
        for i, (key, value) in enumerate(stats.items()):
            prefix = "└──" if i == len(stats) - 1 else "├──"
            icon = icons.get(key.lower().split('_')[0], "•")
            if isinstance(value, (int, float)):
                value = f"{value:,}"  # Add thousand separators
            lines.append(f"{prefix} {icon} {str(key):<{max_key_length}}: {value}")

        return lines

    @staticmethod
    def progress_bar(current: int, total: int, width: int = 20) -> str:
        """Create a progress bar.

        Fix: guards against ZeroDivisionError when total is 0 (reports 0%).
        """
        percentage = (current * 100) // total if total else 0
        filled = "█" * (percentage * width // 100)
        empty = "░" * (width - len(filled))
        return f"{filled}{empty} {percentage:3d}%"

    @staticmethod
    def error(message: str, error: Optional[Exception] = None) -> str:
        """Format error message"""
        error_msg = f"\n❌ Error: {message}"
        if error:
            error_msg += f"\n └── Details: {str(error)}"
        return error_msg

    @staticmethod
    def success(message: str) -> str:
        """Format success message"""
        return f"✅ {message}"

    @staticmethod
    def warning(message: str) -> str:
        """Format warning message"""
        return f"⚠️ {message}"

    @staticmethod
    def info(message: str) -> str:
        """Format info message"""
        return f"ℹ️ {message}"
class HuggingFaceService:
    """Thin wrapper around the shared HfApi client with formatted logging."""

    def __init__(self):
        self.api = API
        self.token = HF_TOKEN
        self.cache_dir = cache_config.models_cache

    async def check_authentication(self) -> bool:
        """Check if the HF token is valid"""
        # No token configured means authentication cannot succeed.
        if not self.token:
            return False
        try:
            logger.info(LogFormatter.info("Checking HF token validity..."))
            self.api.get_token_permission()
            logger.info(LogFormatter.success("HF token is valid"))
            return True
        except Exception as exc:
            logger.error(LogFormatter.error("HF token validation failed", exc))
            return False

    async def get_user_info(self) -> Optional[dict]:
        """Get information about the authenticated user"""
        try:
            logger.info(LogFormatter.info("Fetching user information..."))
            permissions = self.api.get_token_permission()
            # NOTE(review): the result is treated as a mapping here; confirm
            # the installed huggingface_hub version returns a dict (recent
            # versions return a plain string from get_token_permission).
            logger.info(LogFormatter.success(f"User info retrieved for: {permissions.get('user', 'Unknown')}"))
            return permissions
        except Exception as exc:
            logger.error(LogFormatter.error("Failed to get user info", exc))
            return None

    def _log_repo_operation(self, operation: str, repo: str, details: str = None):
        """Helper to log repository operations"""
        logger.info(LogFormatter.section(f"HF REPOSITORY OPERATION - {operation.upper()}"))
        summary = {
            "Operation": operation,
            "Repository": repo,
        }
        if details:
            summary["Details"] = details
        for line in LogFormatter.tree(summary):
            logger.info(line)
"Dataset Statistics"): + logger.info(line) + + return data + + except Exception as e: + logger.error(LogFormatter.error("Failed to fetch leaderboard data", e)) + raise HTTPException(status_code=500, detail=str(e)) + + async def get_formatted_data(self) -> List[Dict[str, Any]]: + """Get formatted leaderboard data""" + try: + logger.info(LogFormatter.section("FORMATTING LEADERBOARD DATA")) + + raw_data = await self.fetch_raw_data() + formatted_data = [] + type_counts = {} + error_count = 0 + + # Initialize progress tracking + total_items = len(raw_data) + logger.info(LogFormatter.info(f"Processing {total_items:,} entries...")) + + for i, item in enumerate(raw_data, 1): + try: + formatted_item = await self.transform_data(item) + formatted_data.append(formatted_item) + + # Count model types + model_type = formatted_item["model"]["type"] + type_counts[model_type] = type_counts.get(model_type, 0) + 1 + + except Exception as e: + error_count += 1 + logger.error(LogFormatter.error(f"Failed to format entry {i}/{total_items}", e)) + continue + + # Log progress every 10% + if i % max(1, total_items // 10) == 0: + progress = (i / total_items) * 100 + logger.info(LogFormatter.info(f"Progress: {LogFormatter.progress_bar(i, total_items)}")) + + # Log final statistics + stats = { + "Total_Processed": total_items, + "Successful": len(formatted_data), + "Failed": error_count + } + logger.info(LogFormatter.section("PROCESSING SUMMARY")) + for line in LogFormatter.stats(stats, "Processing Statistics"): + logger.info(line) + + # Log model type distribution + type_stats = {f"Type_{k}": v for k, v in type_counts.items()} + logger.info(LogFormatter.subsection("MODEL TYPE DISTRIBUTION")) + for line in LogFormatter.stats(type_stats): + logger.info(line) + + return formatted_data + + except Exception as e: + logger.error(LogFormatter.error("Failed to format leaderboard data", e)) + raise HTTPException(status_code=500, detail=str(e)) + + async def transform_data(self, data: Dict[str, Any]) -> 
Dict[str, Any]: + """Transform raw data into the format expected by the frontend""" + try: + # Extract model name for logging + model_name = data.get("fullname", "Unknown") + logger.debug(LogFormatter.info(f"Transforming data for model: {model_name}")) + + # Create unique ID combining model name, precision and sha + unique_id = f"{data.get('fullname', 'Unknown')}_{data.get('Precision', 'Unknown')}_{data.get('Model sha', 'Unknown')}" + + # EEG benchmark evaluations (placeholders matching current dataset schema) + evaluations = { + "anli": { + "name": "ANLI", + "value": data.get("ANLI Raw", 0), + "normalized_score": data.get("ANLI", 0) + }, + "logiqa": { + "name": "LogiQA", + "value": data.get("LogiQA Raw", 0), + "normalized_score": data.get("LogiQA", 0) + } + } + + features = { + "is_not_available_on_hub": data.get("Available on the hub", False), + } + + metadata = { + "upload_date": data.get("Upload To Hub Date"), + "submission_date": data.get("Submission Date"), + "generation": data.get("Generation"), + "base_model": data.get("Base Model"), + "hub_license": data.get("Hub License"), + "hub_hearts": data.get("Hub ❤️"), + "params_millions": data.get("#Params (M)"), + } + + # Clean model type by removing emojis if present + original_type = data.get("Type", "") + model_type = original_type.lower().strip() + + # Remove emojis and parentheses + if "(" in model_type: + model_type = model_type.split("(")[0].strip() + model_type = ''.join(c for c in model_type if not c in '🟢🔶🧪🏗️ ') + + # Map model types for EEG domain + model_type_mapping = { + "fine-tuned": "fine-tuned", + "fine tuned": "fine-tuned", + "finetuned": "fine-tuned", + "fine_tuned": "fine-tuned", + "ft": "fine-tuned", + "pretrained": "pretrained", + "pre-trained": "pretrained", + "task-specific": "task-specific", + "foundation": "foundation", + } + + mapped_type = model_type_mapping.get(model_type.lower().strip(), model_type) + + if mapped_type != model_type: + logger.debug(LogFormatter.info(f"Model type 
# Context manager to temporarily disable stdout and stderr
@contextlib.contextmanager
def suppress_output():
    """Silence sys.stdout and sys.stderr for the duration of the block.

    Used around noisy third-party calls (e.g. dataset snapshot downloads)
    whose console output would corrupt the structured logs. Both streams
    are restored and the devnull handle closed even if the body raises.
    """
    sink = open(os.devnull, 'w')
    try:
        with contextlib.redirect_stdout(sink), contextlib.redirect_stderr(sink):
            yield
    finally:
        sink.close()
logger.info("="*50) + sys.stdout.flush() + +class ModelService(HuggingFaceService): + _instance: Optional['ModelService'] = None + _initialized = False + + def __new__(cls): + if cls._instance is None: + logger.info(LogFormatter.info("Creating new ModelService instance")) + cls._instance = super(ModelService, cls).__new__(cls) + return cls._instance + + def __init__(self): + if not hasattr(self, '_init_done'): + logger.info(LogFormatter.section("MODEL SERVICE INITIALIZATION")) + super().__init__() + self.validator = ModelValidator() + self.vote_service = VoteService() + self.eval_requests_path = cache_config.eval_requests_file + logger.info(LogFormatter.info(f"Using eval requests path: {self.eval_requests_path}")) + + self.eval_requests_path.parent.mkdir(parents=True, exist_ok=True) + self.hf_api = HfApi(token=HF_TOKEN) + self.cached_models = None + self.last_cache_update = 0 + self.cache_ttl = cache_config.cache_ttl.total_seconds() + self._init_done = True + logger.info(LogFormatter.success("Initialization complete")) + + async def _refresh_models_cache(self): + """Refresh the models cache""" + try: + logger.info(LogFormatter.section("CACHE REFRESH")) + self._log_repo_operation("read", f"{HF_ORGANIZATION}/requests", "Refreshing models cache") + + # Initialize models dictionary + models = { + "finished": [], + "evaluating": [], + "pending": [] + } + + try: + logger.info(LogFormatter.subsection("DATASET LOADING")) + logger.info(LogFormatter.info("Loading dataset...")) + + # Download entire dataset snapshot + with suppress_output(): + local_dir = self.hf_api.snapshot_download( + repo_id=QUEUE_REPO, + repo_type="dataset", + token=self.token + ) + + # List JSON files in local directory + local_path = Path(local_dir) + json_files = list(local_path.glob("**/*.json")) + total_files = len(json_files) + + # Log repository stats + stats = { + "Total_Files": total_files, + "Local_Dir": str(local_path), + } + for line in LogFormatter.stats(stats, "Repository Statistics"): + 
logger.info(line) + + if not json_files: + raise Exception("No JSON files found in repository") + + # Initialize progress tracker + progress = ProgressTracker(total_files, "PROCESSING FILES") + + # Process local files + model_submissions = {} + for file_path in json_files: + try: + with open(file_path, 'r') as f: + content = json.load(f) + + # Get status and determine target status + status = content.get("status", "PENDING").upper() + target_status = None + status_map = { + "PENDING": ["PENDING"], + "EVALUATING": ["RUNNING"], + "FINISHED": ["FINISHED"] + } + + for target, source_statuses in status_map.items(): + if status in source_statuses: + target_status = target + break + + if not target_status: + progress.update() + continue + + # Calculate wait time + try: + submit_time = datetime.fromisoformat(content["submitted_time"].replace("Z", "+00:00")) + if submit_time.tzinfo is None: + submit_time = submit_time.replace(tzinfo=timezone.utc) + current_time = datetime.now(timezone.utc) + wait_time = current_time - submit_time + + model_info = { + "name": content["model"], + "submitter": content.get("sender", "Unknown"), + "revision": content["revision"], + "wait_time": f"{wait_time.total_seconds():.1f}s", + "submission_time": content["submitted_time"], + "status": target_status, + "precision": content.get("precision", "Unknown") + } + + # Use (model_id, revision, precision) as key to track latest submission + key = (content["model"], content["revision"], content.get("precision", "Unknown")) + if key not in model_submissions or submit_time > datetime.fromisoformat(model_submissions[key]["submission_time"].replace("Z", "+00:00")): + model_submissions[key] = model_info + + except (ValueError, TypeError) as e: + logger.error(LogFormatter.error(f"Failed to process {file_path.name}", e)) + + except Exception as e: + logger.error(LogFormatter.error(f"Failed to load {file_path.name}", e)) + finally: + progress.update() + + # Populate models dict with deduplicated submissions + 
for model_info in model_submissions.values(): + models[model_info["status"].lower()].append(model_info) + + progress.close() + + # Final summary with fancy formatting + logger.info(LogFormatter.section("CACHE SUMMARY")) + stats = { + "Finished": len(models["finished"]), + "Evaluating": len(models["evaluating"]), + "Pending": len(models["pending"]) + } + for line in LogFormatter.stats(stats, "Models by Status"): + logger.info(line) + logger.info("="*50) + + except Exception as e: + logger.error(LogFormatter.error("Error processing files", e)) + raise + + # Update cache + self.cached_models = models + self.last_cache_update = time.time() + logger.info(LogFormatter.success("Cache updated successfully")) + + return models + + except Exception as e: + logger.error(LogFormatter.error("Cache refresh failed", e)) + raise + + async def initialize(self): + """Initialize the model service""" + if self._initialized: + logger.info(LogFormatter.info("Service already initialized, using cached data")) + return + + try: + logger.info(LogFormatter.section("MODEL SERVICE INITIALIZATION")) + + # Check if cache already exists + cache_path = cache_config.get_cache_path("datasets") + if not cache_path.exists() or not any(cache_path.iterdir()): + logger.info(LogFormatter.info("No existing cache found, initializing datasets cache...")) + cache_config.flush_cache("datasets") + else: + logger.info(LogFormatter.info("Using existing datasets cache")) + + # Ensure eval requests directory exists + self.eval_requests_path.parent.mkdir(parents=True, exist_ok=True) + logger.info(LogFormatter.info(f"Eval requests directory: {self.eval_requests_path}")) + + # List existing files + if self.eval_requests_path.exists(): + files = list(self.eval_requests_path.glob("**/*.json")) + stats = { + "Total_Files": len(files), + "Directory": str(self.eval_requests_path) + } + for line in LogFormatter.stats(stats, "Eval Requests"): + logger.info(line) + + # Load initial cache + await self._refresh_models_cache() + 
+ self._initialized = True + logger.info(LogFormatter.success("Model service initialization complete")) + + except Exception as e: + logger.error(LogFormatter.error("Initialization failed", e)) + raise + + async def get_models(self) -> Dict[str, List[Dict[str, Any]]]: + """Get all models with their status""" + if not self._initialized: + logger.info(LogFormatter.info("Service not initialized, initializing now...")) + await self.initialize() + + current_time = time.time() + cache_age = current_time - self.last_cache_update + + # Check if cache needs refresh + if not self.cached_models: + logger.info(LogFormatter.info("No cached data available, refreshing cache...")) + return await self._refresh_models_cache() + elif cache_age > self.cache_ttl: + logger.info(LogFormatter.info(f"Cache expired ({cache_age:.1f}s old, TTL: {self.cache_ttl}s)")) + return await self._refresh_models_cache() + else: + logger.info(LogFormatter.info(f"Using cached data ({cache_age:.1f}s old)")) + return self.cached_models + + async def submit_model( + self, + model_data: Dict[str, Any], + user_id: str + ) -> Dict[str, Any]: + logger.info(LogFormatter.section("MODEL SUBMISSION")) + self._log_repo_operation("write", f"{HF_ORGANIZATION}/requests", f"Submitting model {model_data['model_id']} by {user_id}") + stats = { + "Model": model_data["model_id"], + "User": user_id, + "Revision": model_data["revision"], + "Precision": model_data["precision"], + "Type": model_data["model_type"] + } + for line in LogFormatter.tree(stats, "Submission Details"): + logger.info(line) + + # Validate required fields + required_fields = [ + "model_id", "base_model", "revision", "precision", + "weight_type", "model_type" + ] + for field in required_fields: + if field not in model_data: + raise ValueError(f"Missing required field: {field}") + + # Get model info and validate it exists on HuggingFace + try: + logger.info(LogFormatter.subsection("MODEL VALIDATION")) + + # Get the model info to check if it exists + 
model_info = self.hf_api.model_info( + model_data["model_id"], + revision=model_data["revision"], + token=self.token + ) + + if not model_info: + raise Exception(f"Model {model_data['model_id']} not found on HuggingFace Hub") + + logger.info(LogFormatter.success("Model exists on HuggingFace Hub")) + + except Exception as e: + logger.error(LogFormatter.error("Model validation failed", e)) + raise + + # Update model revision with commit sha + model_data["revision"] = model_info.sha + + # Check if model already exists in the system + try: + logger.info(LogFormatter.subsection("CHECKING EXISTING SUBMISSIONS")) + existing_models = await self.get_models() + + # Check in all statuses (pending, evaluating, finished) + for status, models in existing_models.items(): + for model in models: + if model["name"] == model_data["model_id"] and model["revision"] == model_data["revision"]: + error_msg = f"Model {model_data['model_id']} revision {model_data['revision']} is already in the system with status: {status}" + logger.error(LogFormatter.error("Submission rejected", error_msg)) + raise ValueError(error_msg) + + logger.info(LogFormatter.success("No existing submission found")) + except ValueError: + raise + except Exception as e: + logger.error(LogFormatter.error("Failed to check existing submissions", e)) + raise + + # Check that model on hub and valid + valid, error, model_config = await self.validator.is_model_on_hub( + model_data["model_id"], + model_data["revision"], + test_tokenizer=False + ) + if not valid: + logger.error(LogFormatter.error("Model on hub validation failed", error)) + raise Exception(error) + logger.info(LogFormatter.success("Model on hub validation passed")) + + # Validate model card + valid, error, model_card = await self.validator.check_model_card( + model_data["model_id"] + ) + if not valid: + logger.error(LogFormatter.error("Model card validation failed", error)) + raise Exception(error) + logger.info(LogFormatter.success("Model card validation 
passed")) + + # Check size limits + model_size, error = await self.validator.get_model_size( + model_info, + model_data["precision"], + model_data["base_model"], + revision=model_data["revision"] + ) + if model_size is None: + logger.error(LogFormatter.error("Model size validation failed", error)) + raise Exception(error) + logger.info(LogFormatter.success(f"Model size validation passed: {model_size:.1f}B")) + + # Size limits for EEG models (much smaller than LLMs - limit at 10B params) + if model_data["precision"] in ["float16", "bfloat16"] and model_size > 10: + error_msg = f"Model too large for {model_data['precision']} (limit: 10B)" + logger.error(LogFormatter.error("Size limit exceeded", error_msg)) + raise Exception(error_msg) + + architectures = model_info.config.get("architectures", "") + if architectures: + architectures = ";".join(architectures) + + # Create eval entry + eval_entry = { + "model": model_data["model_id"], + "base_model": model_data["base_model"], + "revision": model_info.sha, + "precision": model_data["precision"], + "params": model_size, + "architectures": architectures, + "weight_type": model_data["weight_type"], + "status": "PENDING", + "submitted_time": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"), + "model_type": model_data["model_type"], + "job_id": -1, + "job_start_time": None, + "sender": user_id + } + + logger.info(LogFormatter.subsection("EVALUATION ENTRY")) + for line in LogFormatter.tree(eval_entry): + logger.info(line) + + # Upload to HF dataset + try: + logger.info(LogFormatter.subsection("UPLOADING TO HUGGINGFACE")) + logger.info(LogFormatter.info(f"Uploading to {HF_ORGANIZATION}/requests...")) + + # Construct the path in the dataset + org_or_user = model_data["model_id"].split("/")[0] if "/" in model_data["model_id"] else "" + model_path = model_data["model_id"].split("/")[-1] + relative_path = f"{org_or_user}/{model_path}_eval_request_False_{model_data['precision']}_{model_data['weight_type']}.json" + + # 
Create a temporary file with the request + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as temp_file: + json.dump(eval_entry, temp_file, indent=2) + temp_file.flush() + temp_path = temp_file.name + + # Upload file directly + self.hf_api.upload_file( + path_or_fileobj=temp_path, + path_in_repo=relative_path, + repo_id=f"{HF_ORGANIZATION}/requests", + repo_type="dataset", + commit_message=f"Add {model_data['model_id']} to eval queue", + token=self.token + ) + + # Clean up temp file + os.unlink(temp_path) + + logger.info(LogFormatter.success("Upload successful")) + + except Exception as e: + logger.error(LogFormatter.error("Upload failed", e)) + raise + + # Add automatic vote + try: + logger.info(LogFormatter.subsection("AUTOMATIC VOTE")) + logger.info(LogFormatter.info(f"Adding upvote for {model_data['model_id']} by {user_id}")) + await self.vote_service.add_vote( + model_data["model_id"], + user_id, + "up", + { + "precision": model_data["precision"], + "revision": model_data["revision"] + } + ) + logger.info(LogFormatter.success("Vote recorded successfully")) + except Exception as e: + logger.error(LogFormatter.error("Failed to record vote", e)) + # Don't raise here as the main submission was successful + + return { + "status": "success", + "message": "The model was submitted successfully, and the vote has been recorded" + } + + async def get_model_status(self, model_id: str) -> Dict[str, Any]: + """Get evaluation status of a model""" + logger.info(LogFormatter.info(f"Checking status for model: {model_id}")) + eval_path = self.eval_requests_path + + for user_folder in eval_path.iterdir(): + if user_folder.is_dir(): + for file in user_folder.glob("*.json"): + with open(file, "r") as f: + data = json.load(f) + if data["model"] == model_id: + status = { + "status": data["status"], + "submitted_time": data["submitted_time"], + "job_id": data.get("job_id", -1) + } + logger.info(LogFormatter.success("Status found")) + for line in 
LogFormatter.tree(status, "Model Status"): + logger.info(line) + return status + + logger.warning(LogFormatter.warning(f"No status found for model: {model_id}")) + return {"status": "not_found"} + + async def get_organization_submissions(self, organization: str, days: int = 7) -> List[Dict[str, Any]]: + """Get all submissions from a user in the last n days""" + try: + # Get all models + all_models = await self.get_models() + current_time = datetime.now(timezone.utc) + cutoff_time = current_time - timedelta(days=days) + + # Filter models by submitter and submission time + user_submissions = [] + for status, models in all_models.items(): + for model in models: + # Check if model was submitted by the user + if model["submitter"] == organization: + # Parse submission time + submit_time = datetime.fromisoformat( + model["submission_time"].replace("Z", "+00:00") + ) + # Check if within time window + if submit_time > cutoff_time: + user_submissions.append({ + "name": model["name"], + "status": status, + "submission_time": model["submission_time"], + "precision": model["precision"] + }) + + return sorted( + user_submissions, + key=lambda x: x["submission_time"], + reverse=True + ) + + except Exception as e: + logger.error(LogFormatter.error(f"Failed to get submissions for {organization}", e)) + raise diff --git a/backend/app/services/rate_limiter.py b/backend/app/services/rate_limiter.py new file mode 100644 index 0000000000000000000000000000000000000000..a07c9c22e2e7bd0887dee8492cf4df6ea39b7cff --- /dev/null +++ b/backend/app/services/rate_limiter.py @@ -0,0 +1,72 @@ +""" +import logging +from datetime import datetime, timedelta, timezone +from typing import Tuple, Dict, List + +logger = logging.getLogger(__name__) + +class RateLimiter: + def __init__(self, period_days: int = 7, quota: int = 5): + self.period_days = period_days + self.quota = quota + self.submission_history: Dict[str, List[datetime]] = {} + self.higher_quota_users = set() # Users with higher quotas + 
self.unlimited_users = set() # Users with no quota limits + + def add_unlimited_user(self, user_id: str): + """Add a user to the unlimited users list""" + self.unlimited_users.add(user_id) + + def add_higher_quota_user(self, user_id: str): + """Add a user to the higher quota users list""" + self.higher_quota_users.add(user_id) + + def record_submission(self, user_id: str): + """Record a new submission for a user""" + current_time = datetime.now(timezone.utc) + if user_id not in self.submission_history: + self.submission_history[user_id] = [] + self.submission_history[user_id].append(current_time) + + def clean_old_submissions(self, user_id: str): + """Remove submissions older than the period""" + if user_id not in self.submission_history: + return + + current_time = datetime.now(timezone.utc) + cutoff_time = current_time - timedelta(days=self.period_days) + + self.submission_history[user_id] = [ + time for time in self.submission_history[user_id] + if time > cutoff_time + ] + + async def check_rate_limit(self, user_id: str) -> Tuple[bool, str]: + """Check if a user has exceeded their rate limit + + Returns: + Tuple[bool, str]: (is_allowed, error_message) + """ + # Unlimited users bypass all checks + if user_id in self.unlimited_users: + return True, "" + + # Clean old submissions + self.clean_old_submissions(user_id) + + # Get current submission count + submission_count = len(self.submission_history.get(user_id, [])) + + # Calculate user's quota + user_quota = self.quota * 2 if user_id in self.higher_quota_users else self.quota + + # Check if user has exceeded their quota + if submission_count >= user_quota: + error_msg = ( + f"User '{user_id}' has reached the limit of {user_quota} submissions " + f"in the last {self.period_days} days. Please wait before submitting again." 
+ ) + return False, error_msg + + return True, "" +""" \ No newline at end of file diff --git a/backend/app/services/votes.py b/backend/app/services/votes.py new file mode 100644 index 0000000000000000000000000000000000000000..87d5642fa1f1be50e89166b1b27e7a0aa91e2aaa --- /dev/null +++ b/backend/app/services/votes.py @@ -0,0 +1,441 @@ +from datetime import datetime, timezone +from typing import Dict, Any, List, Set, Tuple, Optional +import json +import logging +import asyncio +from pathlib import Path +import aiohttp +from huggingface_hub import HfApi +import tempfile +import os + +from app.services.hf_service import HuggingFaceService +from app.config import HF_TOKEN +from app.config.hf_config import HF_ORGANIZATION +from app.core.cache import cache_config +from app.core.formatting import LogFormatter + +logger = logging.getLogger(__name__) + +class VoteService(HuggingFaceService): + _instance: Optional['VoteService'] = None + _initialized = False + + def __new__(cls): + if cls._instance is None: + cls._instance = super(VoteService, cls).__new__(cls) + return cls._instance + + def __init__(self): + if not hasattr(self, '_init_done'): + super().__init__() + self.votes_file = cache_config.votes_file + self.votes_to_upload: List[Dict[str, Any]] = [] + self.vote_check_set: Set[Tuple[str, str, str, str]] = set() + self._votes_by_model: Dict[str, List[Dict[str, Any]]] = {} + self._votes_by_user: Dict[str, List[Dict[str, Any]]] = {} + self._last_sync = None + self._sync_interval = 300 # 5 minutes + self._total_votes = 0 + self._last_vote_timestamp = None + self._max_retries = 3 + self._retry_delay = 1 # seconds + self.hf_api = HfApi(token=HF_TOKEN) + self._init_done = True + + async def initialize(self): + """Initialize the vote service""" + if self._initialized: + await self._check_for_new_votes() + return + + try: + logger.info(LogFormatter.section("VOTE SERVICE INITIALIZATION")) + + # Ensure votes directory exists + self.votes_file.parent.mkdir(parents=True, 
exist_ok=True) + + # Load remote votes + remote_votes = await self._fetch_remote_votes() + if remote_votes: + logger.info(LogFormatter.info(f"Loaded {len(remote_votes)} votes from hub")) + + # Save to local file + with open(self.votes_file, 'w') as f: + for vote in remote_votes: + json.dump(vote, f) + f.write('\n') + + # Load into memory + await self._load_existing_votes() + else: + logger.warning(LogFormatter.warning("No votes found on hub")) + + self._initialized = True + self._last_sync = datetime.now(timezone.utc) + + # Final summary + stats = { + "Total_Votes": self._total_votes, + "Last_Sync": self._last_sync.strftime("%Y-%m-%d %H:%M:%S UTC") + } + logger.info(LogFormatter.section("INITIALIZATION COMPLETE")) + for line in LogFormatter.stats(stats): + logger.info(line) + + except Exception as e: + logger.error(LogFormatter.error("Initialization failed", e)) + raise + + async def _fetch_remote_votes(self) -> List[Dict[str, Any]]: + """Fetch votes from HF hub""" + url = f"https://huggingface.co/datasets/{HF_ORGANIZATION}/votes/raw/main/votes_data.jsonl" + headers = {"Authorization": f"Bearer {self.token}"} if self.token else {} + + try: + async with aiohttp.ClientSession() as session: + async with session.get(url, headers=headers) as response: + if response.status == 200: + votes = [] + async for line in response.content: + if line.strip(): + try: + vote = json.loads(line.decode()) + votes.append(vote) + except json.JSONDecodeError: + continue + return votes + else: + logger.error(f"Failed to get remote votes: HTTP {response.status}") + return [] + except Exception as e: + logger.error(f"Error fetching remote votes: {str(e)}") + return [] + + async def _check_for_new_votes(self): + """Check for new votes on the hub and sync if needed""" + try: + remote_votes = await self._fetch_remote_votes() + if len(remote_votes) != self._total_votes: + logger.info(f"Vote count changed: Local ({self._total_votes}) ≠ Remote ({len(remote_votes)})") + # Save to local file + with 
open(self.votes_file, 'w') as f: + for vote in remote_votes: + json.dump(vote, f) + f.write('\n') + + # Reload into memory + await self._load_existing_votes() + else: + logger.info("Votes are in sync") + + except Exception as e: + logger.error(f"Error checking for new votes: {str(e)}") + + async def _sync_with_hub(self): + """Sync votes with HuggingFace hub""" + try: + logger.info(LogFormatter.section("VOTE SYNC")) + + # Get current remote votes + remote_votes = await self._fetch_remote_votes() + logger.info(LogFormatter.info(f"Loaded {len(remote_votes)} votes from hub")) + + # If we have pending votes to upload + if self.votes_to_upload: + logger.info(LogFormatter.info(f"Adding {len(self.votes_to_upload)} pending votes...")) + + # Add new votes to remote votes + remote_votes.extend(self.votes_to_upload) + + # Create temporary file with all votes + with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as temp_file: + for vote in remote_votes: + json.dump(vote, temp_file) + temp_file.write('\n') + temp_path = temp_file.name + + try: + # Upload JSONL file directly + self.hf_api.upload_file( + path_or_fileobj=temp_path, + path_in_repo="votes_data.jsonl", + repo_id=f"{HF_ORGANIZATION}/votes", + repo_type="dataset", + commit_message=f"Update votes: +{len(self.votes_to_upload)} new votes", + token=self.token + ) + + # Clear pending votes only if upload succeeded + self.votes_to_upload.clear() + logger.info(LogFormatter.success("Pending votes uploaded successfully")) + + except Exception as e: + logger.error(LogFormatter.error("Failed to upload votes to hub", e)) + raise + finally: + # Clean up temp file + os.unlink(temp_path) + + # Update local state + with open(self.votes_file, 'w') as f: + for vote in remote_votes: + json.dump(vote, f) + f.write('\n') + + # Reload votes in memory + await self._load_existing_votes() + logger.info(LogFormatter.success("Sync completed successfully")) + + self._last_sync = datetime.now(timezone.utc) + + except Exception 
as e: + logger.error(LogFormatter.error("Sync failed", e)) + raise + + async def _load_existing_votes(self): + """Load existing votes from file""" + if not self.votes_file.exists(): + logger.warning(LogFormatter.warning("No votes file found")) + return + + try: + logger.info(LogFormatter.section("LOADING VOTES")) + + # Clear existing data structures + self.vote_check_set.clear() + self._votes_by_model.clear() + self._votes_by_user.clear() + + vote_count = 0 + latest_timestamp = None + + with open(self.votes_file, "r") as f: + for line in f: + try: + vote = json.loads(line.strip()) + vote_count += 1 + + # Track latest timestamp + try: + vote_timestamp = datetime.fromisoformat(vote["timestamp"].replace("Z", "+00:00")) + if not latest_timestamp or vote_timestamp > latest_timestamp: + latest_timestamp = vote_timestamp + vote["timestamp"] = vote_timestamp.strftime("%Y-%m-%dT%H:%M:%SZ") + except (KeyError, ValueError) as e: + logger.warning(LogFormatter.warning(f"Invalid timestamp in vote: {str(e)}")) + continue + + if vote_count % 1000 == 0: + logger.info(LogFormatter.info(f"Processed {vote_count:,} votes...")) + + self._add_vote_to_memory(vote) + + except json.JSONDecodeError as e: + logger.error(LogFormatter.error("Vote parsing failed", e)) + continue + except Exception as e: + logger.error(LogFormatter.error("Vote processing failed", e)) + continue + + self._total_votes = vote_count + self._last_vote_timestamp = latest_timestamp + + # Final summary + stats = { + "Total_Votes": vote_count, + "Latest_Vote": latest_timestamp.strftime("%Y-%m-%d %H:%M:%S UTC") if latest_timestamp else "None", + "Unique_Models": len(self._votes_by_model), + "Unique_Users": len(self._votes_by_user) + } + + logger.info(LogFormatter.section("VOTE SUMMARY")) + for line in LogFormatter.stats(stats): + logger.info(line) + + except Exception as e: + logger.error(LogFormatter.error("Failed to load votes", e)) + raise + + def _add_vote_to_memory(self, vote: Dict[str, Any]): + """Add vote to memory 
structures""" + try: + # Create a unique identifier tuple that includes precision + check_tuple = ( + vote["model"], + vote.get("revision", "main"), + vote["username"], + vote.get("precision", "unknown") + ) + + # Skip if we already have this vote + if check_tuple in self.vote_check_set: + return + + self.vote_check_set.add(check_tuple) + + # Update model votes + if vote["model"] not in self._votes_by_model: + self._votes_by_model[vote["model"]] = [] + self._votes_by_model[vote["model"]].append(vote) + + # Update user votes + if vote["username"] not in self._votes_by_user: + self._votes_by_user[vote["username"]] = [] + self._votes_by_user[vote["username"]].append(vote) + + except KeyError as e: + logger.error(LogFormatter.error("Malformed vote data, missing key", str(e))) + except Exception as e: + logger.error(LogFormatter.error("Error adding vote to memory", str(e))) + + async def get_user_votes(self, user_id: str) -> List[Dict[str, Any]]: + """Get all votes from a specific user""" + logger.info(LogFormatter.info(f"Fetching votes for user: {user_id}")) + + # Check if we need to refresh votes + if (datetime.now(timezone.utc) - self._last_sync).total_seconds() > self._sync_interval: + logger.info(LogFormatter.info("Cache expired, refreshing votes...")) + await self._check_for_new_votes() + + votes = self._votes_by_user.get(user_id, []) + logger.info(LogFormatter.success(f"Found {len(votes):,} votes")) + return votes + + async def get_model_votes(self, model_id: str) -> Dict[str, Any]: + """Get all votes for a specific model""" + logger.info(LogFormatter.info(f"Fetching votes for model: {model_id}")) + + # Check if we need to refresh votes + if (datetime.now(timezone.utc) - self._last_sync).total_seconds() > self._sync_interval: + logger.info(LogFormatter.info("Cache expired, refreshing votes...")) + await self._check_for_new_votes() + + votes = self._votes_by_model.get(model_id, []) + + # Group votes by revision and precision + votes_by_config = {} + for vote in 
votes: + revision = vote.get("revision", "main") + precision = vote.get("precision", "unknown") + config_key = f"{revision}_{precision}" + if config_key not in votes_by_config: + votes_by_config[config_key] = { + "revision": revision, + "precision": precision, + "count": 0 + } + votes_by_config[config_key]["count"] += 1 + + stats = { + "Total_Votes": len(votes), + **{f"Config_{k}": v["count"] for k, v in votes_by_config.items()} + } + + logger.info(LogFormatter.section("VOTE STATISTICS")) + for line in LogFormatter.stats(stats): + logger.info(line) + + return { + "total_votes": len(votes), + "votes_by_config": votes_by_config, + "votes": votes + } + + async def _get_model_revision(self, model_id: str) -> str: + """Get current revision of a model with retries""" + logger.info(f"Getting revision for model: {model_id}") + for attempt in range(self._max_retries): + try: + model_info = await asyncio.to_thread(self.hf_api.model_info, model_id) + logger.info(f"Successfully got revision {model_info.sha} for model {model_id}") + return model_info.sha + except Exception as e: + logger.error(f"Error getting model revision for {model_id} (attempt {attempt + 1}): {str(e)}") + if attempt < self._max_retries - 1: + retry_delay = self._retry_delay * (attempt + 1) + logger.info(f"Retrying in {retry_delay} seconds...") + await asyncio.sleep(retry_delay) + else: + logger.warning(f"Using 'main' as fallback revision for {model_id} after {self._max_retries} failed attempts") + return "main" + + async def add_vote(self, model_id: str, user_id: str, vote_type: str, vote_data: Dict[str, Any] = None) -> Dict[str, Any]: + """Add a vote for a model""" + try: + self._log_repo_operation("add", f"{HF_ORGANIZATION}/votes", f"Adding {vote_type} vote for {model_id} by {user_id}") + logger.info(LogFormatter.section("NEW VOTE")) + stats = { + "Model": model_id, + "User": user_id, + "Type": vote_type, + "Config": vote_data or {} + } + for line in LogFormatter.tree(stats, "Vote Details"): + 
logger.info(line) + + # Use provided configuration or fallback to model info + precision = None + revision = None + + if vote_data: + precision = vote_data.get("precision") + revision = vote_data.get("revision") + + # If any info is missing, try to get it from model info + if not all([precision, revision]): + try: + model_info = await asyncio.to_thread(self.hf_api.model_info, model_id) + model_card_data = model_info.cardData if hasattr(model_info, 'cardData') else {} + + if not precision: + precision = model_card_data.get("precision", "unknown") + if not revision: + revision = model_info.sha + except Exception as e: + logger.warning(LogFormatter.warning(f"Failed to get model info: {str(e)}. Using default values.")) + precision = precision or "unknown" + revision = revision or "main" + + # Check if vote already exists with this configuration + check_tuple = (model_id, revision, user_id, precision) + + if check_tuple in self.vote_check_set: + raise ValueError(f"Vote already recorded for this model configuration (precision: {precision}, revision: {revision[:7] if revision else 'unknown'})") + + vote = { + "model": model_id, + "revision": revision, + "username": user_id, + "timestamp": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"), + "vote_type": vote_type, + "precision": precision + } + + # Update local storage + with open(self.votes_file, "a") as f: + f.write(json.dumps(vote) + "\n") + + self._add_vote_to_memory(vote) + self.votes_to_upload.append(vote) + + stats = { + "Status": "Success", + "Queue_Size": len(self.votes_to_upload), + "Model_Config": { + "Precision": precision, + "Revision": revision[:7] if revision else "unknown" + } + } + for line in LogFormatter.stats(stats): + logger.info(line) + + # Force immediate sync + logger.info(LogFormatter.info("Forcing immediate sync with hub")) + await self._sync_with_hub() + + return {"status": "success", "message": "Vote added successfully"} + + except Exception as e: + 
logger.error(LogFormatter.error("Failed to add vote", e)) + raise \ No newline at end of file diff --git a/backend/app/utils/__init__.py b/backend/app/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..69a93acb760828c13400cfcd19da2822dfd83e5e --- /dev/null +++ b/backend/app/utils/__init__.py @@ -0,0 +1,3 @@ +from . import model_validation + +__all__ = ["model_validation"] diff --git a/backend/app/utils/logging.py b/backend/app/utils/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..3a720f0c226faa0d0390a0c561be75db0194ca7f --- /dev/null +++ b/backend/app/utils/logging.py @@ -0,0 +1,3 @@ +from app.core.formatting import LogFormatter + +__all__ = ['LogFormatter'] \ No newline at end of file diff --git a/backend/app/utils/model_validation.py b/backend/app/utils/model_validation.py new file mode 100644 index 0000000000000000000000000000000000000000..e7cbb05e31143d4507377a0969032085b8c696ce --- /dev/null +++ b/backend/app/utils/model_validation.py @@ -0,0 +1,168 @@ +import json +import logging +import asyncio +from typing import Tuple, Optional, Dict, Any +from datasets import load_dataset +from huggingface_hub import HfApi, ModelCard, hf_hub_download +from huggingface_hub import hf_api +from transformers import AutoConfig, AutoTokenizer +from app.config.base import HF_TOKEN +from app.core.formatting import LogFormatter + +logger = logging.getLogger(__name__) + +class ModelValidator: + def __init__(self): + self.token = HF_TOKEN + self.api = HfApi(token=self.token) + self.headers = {"Authorization": f"Bearer {self.token}"} if self.token else {} + + async def check_model_card(self, model_id: str) -> Tuple[bool, str, Optional[Dict[str, Any]]]: + """Check if model has a valid model card""" + try: + logger.info(LogFormatter.info(f"Checking model card for {model_id}")) + + # Get model card content using ModelCard.load + try: + model_card = await asyncio.to_thread( + ModelCard.load, + model_id + ) + 
logger.info(LogFormatter.success("Model card found")) + except Exception as e: + error_msg = "Please add a model card to your model to explain how you trained/fine-tuned it." + logger.error(LogFormatter.error(error_msg, e)) + return False, error_msg, None + + # Check license in model card data + if model_card.data.license is None and not ("license_name" in model_card.data and "license_link" in model_card.data): + error_msg = "License not found. Please add a license to your model card using the `license` metadata or a `license_name`/`license_link` pair." + logger.warning(LogFormatter.warning(error_msg)) + return False, error_msg, None + + # Enforce card content length + if len(model_card.text) < 200: + error_msg = "Please add a description to your model card, it is too short." + logger.warning(LogFormatter.warning(error_msg)) + return False, error_msg, None + + logger.info(LogFormatter.success("Model card validation passed")) + return True, "", model_card + + except Exception as e: + error_msg = "Failed to validate model card" + logger.error(LogFormatter.error(error_msg, e)) + return False, str(e), None + + async def get_safetensors_metadata(self, model_id: str, is_adapter: bool = False, revision: str = "main") -> Optional[Dict]: + """Get metadata from a safetensors file""" + try: + if is_adapter: + metadata = await asyncio.to_thread( + hf_api.parse_safetensors_file_metadata, + model_id, + "adapter_model.safetensors", + token=self.token, + revision=revision, + ) + else: + metadata = await asyncio.to_thread( + hf_api.get_safetensors_metadata, + repo_id=model_id, + token=self.token, + revision=revision, + ) + return metadata + + except Exception as e: + logger.error(f"Failed to get safetensors metadata: {str(e)}") + return None + + async def get_model_size( + self, + model_info: Any, + precision: str, + base_model: str, + revision: str + ) -> Tuple[Optional[float], Optional[str]]: + """Get model size in billions of parameters""" + try: + 
logger.info(LogFormatter.info(f"Checking model size for {model_info.modelId}")) + + # Check if model is adapter + is_adapter = any(s.rfilename == "adapter_config.json" for s in model_info.siblings if hasattr(s, 'rfilename')) + + # Try to get size from safetensors first + model_size = None + + if is_adapter and base_model: + # For adapters, we need both adapter and base model sizes + adapter_meta = await self.get_safetensors_metadata(model_info.id, is_adapter=True, revision=revision) + base_meta = await self.get_safetensors_metadata(base_model, revision="main") + + if adapter_meta and base_meta: + adapter_size = sum(adapter_meta.parameter_count.values()) + base_size = sum(base_meta.parameter_count.values()) + model_size = adapter_size + base_size + else: + # For regular models, just get the model size + meta = await self.get_safetensors_metadata(model_info.id, revision=revision) + if meta: + model_size = sum(meta.parameter_count.values()) # total params + + if model_size is None: + # If model size could not be determined, return an error + return None, "Model size could not be determined" + + # Adjust size for GPTQ models + size_factor = 8 if (precision == "GPTQ" or "gptq" in model_info.id.lower()) else 1 + model_size = model_size / 1e9 # Convert to billions, assuming float16 + model_size = round(size_factor * model_size, 3) + + logger.info(LogFormatter.success(f"Model size: {model_size}B parameters")) + return model_size, None + + except Exception as e: + logger.error(LogFormatter.error(f"Error while determining model size: {e}")) + return None, str(e) + + async def is_model_on_hub( + self, + model_name: str, + revision: str, + test_tokenizer: bool = False, + trust_remote_code: bool = False + ) -> Tuple[bool, Optional[str], Optional[Any]]: + """Check if model exists and is properly configured on the Hub""" + try: + config = await asyncio.to_thread( + AutoConfig.from_pretrained, + model_name, + revision=revision, + trust_remote_code=trust_remote_code, + 
token=self.token, + force_download=True + ) + + if test_tokenizer: + try: + await asyncio.to_thread( + AutoTokenizer.from_pretrained, + model_name, + revision=revision, + trust_remote_code=trust_remote_code, + token=self.token + ) + except ValueError as e: + return False, f"The tokenizer is not available in an official Transformers release: {e}", None + except Exception: + return False, "The tokenizer cannot be loaded. Ensure the tokenizer class is part of a stable Transformers release and correctly configured.", None + + return True, None, config + + except ValueError: + return False, "The model requires `trust_remote_code=True` to launch, and for safety reasons, we don't accept such models automatically.", None + except Exception as e: + if "You are trying to access a gated repo." in str(e): + return True, "The model is gated and requires special access permissions.", None + return False, f"The model was not found or is misconfigured on the Hub. Error: {e.args[0]}", None diff --git a/backend/pyproject.toml b/backend/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..8ce38744fc266b83b0abec4d98c5a415eeeaec46 --- /dev/null +++ b/backend/pyproject.toml @@ -0,0 +1,31 @@ +[tool.poetry] +name = "eeg-leaderboard-backend" +version = "0.1.0" +description = "Backend for the EEG Finetune Arena" +authors = ["Braindecode Team"] + +[tool.poetry.dependencies] +python = "^3.12" +fastapi = "^0.115.6" +uvicorn = {extras = ["standard"], version = "^0.34.0"} +numpy = "^2.2.0" +pandas = "^2.2.3" +datasets = "^3.3.2" +pyarrow = "^18.1.0" +python-multipart = "^0.0.20" +huggingface-hub = "0.29.1" +transformers = "4.49.0" +safetensors = "^0.5.3" +aiofiles = "^24.1.0" +fastapi-cache2 = "^0.2.1" +python-dotenv = "^1.0.1" + +[tool.poetry.group.dev.dependencies] +pytest = "^8.3.4" +black = "^24.10.0" +isort = "^5.13.2" +flake8 = "^6.1.0" + +[build-system] +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" diff --git 
a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000000000000000000000000000000000000..88ba92218b59234736e2912f4de0ff10c3ec3a04 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,33 @@ +services: + backend: + build: + context: ./backend + dockerfile: Dockerfile.dev + args: + - HF_TOKEN=${HF_TOKEN} + ports: + - "${BACKEND_PORT:-8000}:8000" + volumes: + - ./backend:/app + environment: + - ENVIRONMENT=${ENVIRONMENT:-development} + - HF_TOKEN=${HF_TOKEN} + - HF_HOME=${HF_HOME:-/.cache} + command: uvicorn app.asgi:app --host 0.0.0.0 --port 8000 --reload + + frontend: + build: + context: ./frontend + dockerfile: Dockerfile.dev + ports: + - "${FRONTEND_PORT:-7860}:7860" + volumes: + - ./frontend:/app + - /app/node_modules + environment: + - NODE_ENV=${ENVIRONMENT:-development} + - CHOKIDAR_USEPOLLING=true + - PORT=${FRONTEND_PORT:-7860} + command: npm start + stdin_open: true + tty: true diff --git a/frontend/Dockerfile.dev b/frontend/Dockerfile.dev new file mode 100644 index 0000000000000000000000000000000000000000..259f7c9d8746db26bee8ee531d96cbe0d619321e --- /dev/null +++ b/frontend/Dockerfile.dev @@ -0,0 +1,15 @@ +FROM node:18 + +WORKDIR /app + +# Install required global dependencies +RUN npm install -g react-scripts + +# Copy package.json and package-lock.json +COPY package*.json ./ + +# Install project dependencies +RUN npm install + +# Volume will be mounted here, no need for COPY +CMD ["npm", "start"] \ No newline at end of file diff --git a/frontend/package.json b/frontend/package.json new file mode 100644 index 0000000000000000000000000000000000000000..f49fc4062d122dfb556ffe3631381d82b08b8a8a --- /dev/null +++ b/frontend/package.json @@ -0,0 +1,55 @@ +{ + "name": "eeg-finetune-arena", + "version": "0.1.0", + "private": true, + "dependencies": { + "@emotion/react": "^11.13.3", + "@emotion/styled": "^11.13.0", + "@huggingface/hub": "^0.14.0", + "@mui/icons-material": "^6.1.7", + "@mui/lab": "^6.0.0-beta.16", + "@mui/material": "^6.1.6", + 
"@mui/x-data-grid": "^7.22.2", + "@tanstack/react-query": "^5.62.2", + "@tanstack/react-table": "^8.20.5", + "@tanstack/react-virtual": "^3.10.9", + "@testing-library/jest-dom": "^5.17.0", + "@testing-library/react": "^13.4.0", + "@testing-library/user-event": "^13.5.0", + "compression": "^1.7.4", + "cors": "^2.8.5", + "express": "^4.18.2", + "react": "^18.3.1", + "react-dom": "^18.3.1", + "react-router-dom": "^6.28.0", + "react-scripts": "5.0.1", + "serve-static": "^1.15.0", + "web-vitals": "^2.1.4" + }, + "scripts": { + "start": "react-scripts start", + "build": "react-scripts build", + "test": "react-scripts test", + "eject": "react-scripts eject", + "serve": "node server.js" + }, + "eslintConfig": { + "extends": [ + "react-app", + "react-app/jest" + ] + }, + "browserslist": { + "production": [ + ">0.2%", + "not dead", + "not op_mini all" + ], + "development": [ + "last 1 chrome version", + "last 1 firefox version", + "last 1 safari version" + ] + }, + "proxy": "http://backend:8000" +} diff --git a/frontend/public/index.html b/frontend/public/index.html new file mode 100644 index 0000000000000000000000000000000000000000..1818e92bf7929d4b0d0e3ef5d25ebcf8c71ff0f4 --- /dev/null +++ b/frontend/public/index.html @@ -0,0 +1,27 @@ + + + + + + + + + + + + + + + EEG Finetune Arena + + + +
+ + diff --git a/frontend/public/logo256.png b/frontend/public/logo256.png new file mode 100644 index 0000000000000000000000000000000000000000..58547e134af0ac1200a4608fb1c800b3e8e9ddf1 Binary files /dev/null and b/frontend/public/logo256.png differ diff --git a/frontend/public/logo32.png b/frontend/public/logo32.png new file mode 100644 index 0000000000000000000000000000000000000000..1b6e8fbd42dd1bcc599649bf6f230fde89a6908a Binary files /dev/null and b/frontend/public/logo32.png differ diff --git a/frontend/public/robots.txt b/frontend/public/robots.txt new file mode 100644 index 0000000000000000000000000000000000000000..e9e57dc4d41b9b46e05112e9f45b7ea6ac0ba15e --- /dev/null +++ b/frontend/public/robots.txt @@ -0,0 +1,3 @@ +# https://www.robotstxt.org/robotstxt.html +User-agent: * +Disallow: diff --git a/frontend/server.js b/frontend/server.js new file mode 100644 index 0000000000000000000000000000000000000000..653befea69419568b117ce809871639d86d65581 --- /dev/null +++ b/frontend/server.js @@ -0,0 +1,85 @@ +const express = require("express"); +const cors = require("cors"); +const compression = require("compression"); +const path = require("path"); +const serveStatic = require("serve-static"); +const { createProxyMiddleware } = require("http-proxy-middleware"); + +const app = express(); +const port = process.env.PORT || 7860; +const apiPort = process.env.INTERNAL_API_PORT || 7861; + +// Enable CORS for all routes +app.use(cors()); + +// Enable GZIP compression +app.use(compression()); + +// Proxy all API requests to the Python backend +app.use( + "/api", + createProxyMiddleware({ + target: `http://127.0.0.1:${apiPort}`, + changeOrigin: true, + onError: (err, req, res) => { + console.error("Proxy Error:", err); + res.status(500).json({ error: "Proxy Error", details: err.message }); + }, + }) +); + +// Serve static files from the build directory +app.use( + express.static(path.join(__dirname, "build"), { + // Don't cache HTML files + setHeaders: (res, path) => { + if 
(path.endsWith(".html")) { + res.setHeader("Cache-Control", "no-cache, no-store, must-revalidate"); + res.setHeader("Pragma", "no-cache"); + res.setHeader("Expires", "0"); + } else { + // Cache other static resources for 1 year + res.setHeader("Cache-Control", "public, max-age=31536000"); + } + }, + }) +); + +// Middleware to preserve URL parameters +app.use((req, res, next) => { + // Don't interfere with API requests + if (req.url.startsWith("/api")) { + return next(); + } + + // Preserve original URL parameters + req.originalUrl = req.url; + next(); +}); + +// Handle all other routes by serving index.html +app.get("*", (req, res) => { + // Don't interfere with API requests + if (req.url.startsWith("/api")) { + return next(); + } + + // Headers for client-side routing + res.set({ + "Cache-Control": "no-cache, no-store, must-revalidate", + Pragma: "no-cache", + Expires: "0", + }); + + // Send index.html for all other routes + res.sendFile(path.join(__dirname, "build", "index.html")); +}); + +app.listen(port, "0.0.0.0", () => { + console.log( + `Frontend server is running on port ${port} in ${ + process.env.NODE_ENV || "development" + } mode` + ); + console.log(`API proxy target: http://127.0.0.1:${apiPort}`); +}); diff --git a/frontend/src/App.js b/frontend/src/App.js new file mode 100644 index 0000000000000000000000000000000000000000..cf53b89670ac078b14a6a03297c5672b4beeaa66 --- /dev/null +++ b/frontend/src/App.js @@ -0,0 +1,127 @@ +import React, { useEffect } from "react"; +import { + HashRouter as Router, + Routes, + Route, + useSearchParams, + useLocation, +} from "react-router-dom"; +import { ThemeProvider } from "@mui/material/styles"; +import { Box, CssBaseline } from "@mui/material"; +import Navigation from "./components/Navigation/Navigation"; +import LeaderboardPage from "./pages/LeaderboardPage/LeaderboardPage"; +import AddModelPage from "./pages/AddModelPage/AddModelPage"; +import VoteModelPage from "./pages/VoteModelPage/VoteModelPage"; +import 
AboutPage from "./pages/AboutPage/AboutPage"; +import QuotePage from "./pages/QuotePage/QuotePage"; +import Footer from "./components/Footer/Footer"; +import getTheme from "./config/theme"; +import { useThemeMode } from "./hooks/useThemeMode"; +import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; +import LeaderboardProvider from "./pages/LeaderboardPage/components/Leaderboard/context/LeaderboardContext"; + +const queryClient = new QueryClient({ + defaultOptions: { + queries: { + retry: 1, + refetchOnWindowFocus: false, + }, + }, +}); + +function UrlHandler() { + const location = useLocation(); + const [searchParams] = useSearchParams(); + + // Synchronize URL with parent HF page + useEffect(() => { + // Check if we're in an HF Space iframe + const isHFSpace = window.location !== window.parent.location; + if (!isHFSpace) return; + + // Sync query and hash from this embedded app to the parent page URL + const queryString = window.location.search; + const hash = window.location.hash; + + // HF Spaces' special message type to update the query string and the hash in the parent page URL + window.parent.postMessage( + { + queryString, + hash, + }, + "https://huggingface.co" + ); + }, [location, searchParams]); + + // Read the updated hash reactively + useEffect(() => { + const handleHashChange = (event) => { + console.log("hash change event", event); + }; + + window.addEventListener("hashchange", handleHashChange); + return () => window.removeEventListener("hashchange", handleHashChange); + }, []); + + return null; +} + +function App() { + const { mode, toggleTheme } = useThemeMode(); + const theme = getTheme(mode); + + return ( +
+ + + + + + + + + + + } /> + } /> + } /> + } /> + } /> + + +
+ ); +} + +export default App; diff --git a/frontend/src/components/Footer/Footer.js b/frontend/src/components/Footer/Footer.js new file mode 100644 index 0000000000000000000000000000000000000000..ae94e451a8f512ff24ce6efd246c22adb3b77d75 --- /dev/null +++ b/frontend/src/components/Footer/Footer.js @@ -0,0 +1,29 @@ +import React from "react"; +import { Box, Typography, Link } from "@mui/material"; + +const Footer = () => { + return ( + + + Braindecode - EEG Finetune Arena -{" "} + + braindecode.org + + + + ); +}; + +export default Footer; diff --git a/frontend/src/components/Logo/HFLogo.js b/frontend/src/components/Logo/HFLogo.js new file mode 100644 index 0000000000000000000000000000000000000000..e49263da5f52e62f50db806f6f295d94e75be47f --- /dev/null +++ b/frontend/src/components/Logo/HFLogo.js @@ -0,0 +1,19 @@ +import React from 'react'; + +const HFLogo = () => ( + + hg-logo + + +); + +export default HFLogo; \ No newline at end of file diff --git a/frontend/src/components/Logo/Logo.js b/frontend/src/components/Logo/Logo.js new file mode 100644 index 0000000000000000000000000000000000000000..5216331a837a48b571bc56727eb3f39aa45ae3b8 --- /dev/null +++ b/frontend/src/components/Logo/Logo.js @@ -0,0 +1,56 @@ +import React from "react"; +import { useNavigate, useSearchParams, useLocation } from "react-router-dom"; +import { Box } from "@mui/material"; +import HFLogo from "./HFLogo"; +import { useLeaderboard } from "../../pages/LeaderboardPage/components/Leaderboard/context/LeaderboardContext"; + +const Logo = ({ height = "40px" }) => { + const navigate = useNavigate(); + const [searchParams, setSearchParams] = useSearchParams(); + const location = useLocation(); + const { actions } = useLeaderboard(); + + const handleReset = () => { + // Reset all leaderboard state first + actions.resetAll(); + + // Then clean URL in one go + if ( + location.pathname !== "/" || + searchParams.toString() !== "" || + location.hash !== "" + ) { + window.history.replaceState(null, "", "/"); 
+ navigate("/", { replace: true, state: { skipUrlSync: true } }); + setSearchParams({}, { replace: true, state: { skipUrlSync: true } }); + } + }; + + return ( + + + + + + ); +}; + +export default Logo; diff --git a/frontend/src/components/Navigation/Navigation.js b/frontend/src/components/Navigation/Navigation.js new file mode 100644 index 0000000000000000000000000000000000000000..453ef46aeb1a577a2845e839af209375c55687e1 --- /dev/null +++ b/frontend/src/components/Navigation/Navigation.js @@ -0,0 +1,379 @@ +import React, { useState } from "react"; +import { + AppBar, + Toolbar, + Box, + IconButton, + Tooltip, + ButtonBase, + Typography, +} from "@mui/material"; +import { useLocation, useNavigate, useSearchParams } from "react-router-dom"; +import LightModeOutlinedIcon from "@mui/icons-material/LightModeOutlined"; +import DarkModeOutlinedIcon from "@mui/icons-material/DarkModeOutlined"; +import { alpha } from "@mui/material/styles"; +import MenuIcon from "@mui/icons-material/Menu"; +import { Menu, MenuItem, useMediaQuery, useTheme } from "@mui/material"; + +const Navigation = ({ onToggleTheme, mode }) => { + const location = useLocation(); + const navigate = useNavigate(); + const [searchParams] = useSearchParams(); + const [anchorEl, setAnchorEl] = useState(null); + const theme = useTheme(); + const isMobile = useMediaQuery(theme.breakpoints.down("md")); + const [hasChanged, setHasChanged] = useState(false); + + const handleThemeToggle = () => { + setHasChanged(true); + onToggleTheme(); + }; + + const iconStyle = { + fontSize: "1.125rem", + ...(hasChanged && { + animation: "rotateIn 0.3s cubic-bezier(0.4, 0, 0.2, 1)", + "@keyframes rotateIn": { + "0%": { + opacity: 0, + transform: + mode === "light" + ? 
"rotate(-90deg) scale(0.8)" + : "rotate(90deg) scale(0.8)", + }, + "100%": { + opacity: 1, + transform: "rotate(0) scale(1)", + }, + }, + }), + }; + + // Function to sync URL with parent HF page + const syncUrlWithParent = (queryString, hash) => { + // Check if we're in an HF Space iframe + const isHFSpace = window.location !== window.parent.location; + if (isHFSpace) { + try { + // Build complete URL with hash + const fullPath = `${queryString}${hash ? "#" + hash : ""}`; + window.parent.postMessage( + { + type: "urlUpdate", + path: fullPath, + }, + "https://huggingface.co" + ); + } catch (e) { + console.warn("Unable to sync URL with parent:", e); + } + } + }; + + const linkStyle = (isActive = false) => ({ + textDecoration: "none", + color: isActive ? "text.primary" : "text.secondary", + fontSize: "0.8125rem", + opacity: isActive ? 1 : 0.8, + display: "flex", + alignItems: "center", + gap: 0.5, + paddingBottom: "2px", + cursor: "pointer", + position: "relative", + "&:hover": { + opacity: 1, + color: "text.primary", + }, + "&::after": isActive + ? { + content: '""', + position: "absolute", + bottom: "-4px", + left: "0", + width: "100%", + height: "2px", + backgroundColor: (theme) => + alpha( + theme.palette.text.primary, + theme.palette.mode === "dark" ? 0.3 : 0.2 + ), + borderRadius: "2px", + } + : {}, + }); + + const Separator = () => ( + ({ + width: "4px", + height: "4px", + borderRadius: "100%", + backgroundColor: alpha( + theme.palette.text.primary, + theme.palette.mode === "dark" ? 0.2 : 0.15 + ), + })} + /> + ); + + const handleNavigation = (path) => (e) => { + e.preventDefault(); + const searchString = searchParams.toString(); + const queryString = searchString ? 
`?${searchString}` : ""; + const newPath = `${path}${queryString}`; + + // Local navigation via React Router + navigate(newPath); + + // If in HF Space, sync with parent + if (window.location !== window.parent.location) { + syncUrlWithParent(queryString, newPath); + } + }; + + const handleMenuOpen = (event) => { + setAnchorEl(event.currentTarget); + }; + + const handleMenuClose = () => { + setAnchorEl(null); + }; + + const navItems = [ + { path: "/", label: "Leaderboard" }, + { path: "/add-model", label: "Submit" }, + { path: "/vote", label: "Vote" }, + { path: "/about", label: "About" }, + { path: "/quote", label: "Citations" }, + ]; + + return ( + + + {isMobile ? ( + + + + + + + `1px solid ${alpha(theme.palette.divider, 0.1)}`, + backgroundColor: (theme) => + theme.palette.mode === "dark" + ? alpha(theme.palette.background.paper, 0.8) + : theme.palette.background.paper, + backdropFilter: "blur(20px)", + "& .MuiList-root": { + py: 1, + }, + "& .MuiMenuItem-root": { + px: 2, + py: 1, + fontSize: "0.8125rem", + color: "text.secondary", + transition: "all 0.2s ease-in-out", + position: "relative", + "&:hover": { + backgroundColor: (theme) => + alpha( + theme.palette.text.primary, + theme.palette.mode === "dark" ? 0.1 : 0.06 + ), + color: "text.primary", + }, + "&.Mui-selected": { + backgroundColor: "transparent", + color: "text.primary", + "&::after": { + content: '""', + position: "absolute", + left: "8px", + width: "4px", + height: "100%", + top: "0", + backgroundColor: (theme) => + alpha( + theme.palette.text.primary, + theme.palette.mode === "dark" ? 0.3 : 0.2 + ), + borderRadius: "2px", + }, + "&:hover": { + backgroundColor: (theme) => + alpha( + theme.palette.text.primary, + theme.palette.mode === "dark" ? 
0.1 : 0.06 + ), + }, + }, + }, + }, + }} + transformOrigin={{ horizontal: "left", vertical: "top" }} + anchorOrigin={{ horizontal: "left", vertical: "bottom" }} + > + {/* Navigation Section */} + + + Navigation + + + {navItems.map((item) => ( + { + handleNavigation(item.path)(e); + handleMenuClose(); + }} + selected={location.pathname === item.path} + > + {item.label} + + ))} + + + + ({ + color: "text.secondary", + borderRadius: "100%", + padding: 0, + width: "36px", + height: "36px", + display: "flex", + alignItems: "center", + justifyContent: "center", + transition: "all 0.2s ease-in-out", + "&:hover": { + color: "text.primary", + backgroundColor: alpha( + theme.palette.text.primary, + theme.palette.mode === "dark" ? 0.1 : 0.06 + ), + }, + "&.MuiButtonBase-root": { + overflow: "hidden", + }, + "& .MuiTouchRipple-root": { + color: alpha(theme.palette.text.primary, 0.3), + }, + })} + > + {mode === "light" ? ( + + ) : ( + + )} + + + + ) : ( + // Desktop version + + {/* Internal navigation */} + + {navItems.map((item) => ( + + {item.label} + + ))} + + + + + {/* Dark mode toggle */} + + ({ + color: "text.secondary", + borderRadius: "100%", + padding: 0, + width: "36px", + height: "36px", + display: "flex", + alignItems: "center", + justifyContent: "center", + transition: "all 0.2s ease-in-out", + "&:hover": { + color: "text.primary", + backgroundColor: alpha( + theme.palette.text.primary, + theme.palette.mode === "dark" ? 0.1 : 0.06 + ), + }, + "&.MuiButtonBase-root": { + overflow: "hidden", + }, + "& .MuiTouchRipple-root": { + color: alpha(theme.palette.text.primary, 0.3), + }, + })} + > + {mode === "light" ? 
( + + ) : ( + + )} + + + + )} + + + ); +}; + +export default Navigation; diff --git a/frontend/src/components/shared/AuthContainer.js b/frontend/src/components/shared/AuthContainer.js new file mode 100644 index 0000000000000000000000000000000000000000..ca79ed8645929ab583964e33be5c1810eef620ab --- /dev/null +++ b/frontend/src/components/shared/AuthContainer.js @@ -0,0 +1,168 @@ +import React from "react"; +import { + Box, + Typography, + Button, + Chip, + Stack, + Paper, + CircularProgress, + useTheme, + useMediaQuery, +} from "@mui/material"; +import HFLogo from "../Logo/HFLogo"; +import { useAuth } from "../../hooks/useAuth"; +import LogoutIcon from "@mui/icons-material/Logout"; +import { useNavigate } from "react-router-dom"; + +function AuthContainer({ actionText = "DO_ACTION" }) { + const { isAuthenticated, user, login, logout, loading } = useAuth(); + const navigate = useNavigate(); + const theme = useTheme(); + const isMobile = useMediaQuery(theme.breakpoints.down("sm")); + + const handleLogout = () => { + if (isAuthenticated && logout) { + logout(); + navigate("/", { replace: true }); + window.location.reload(); + } + }; + + if (loading) { + return ( + + + + ); + } + + if (!isAuthenticated) { + return ( + + + Login to {actionText} + + + You need to be logged in with your Hugging Face account to{" "} + {actionText.toLowerCase()} + + + + ); + } + + return ( + + + + + Connected as {user?.username} + + + + + + + ); +} + +export default AuthContainer; diff --git a/frontend/src/components/shared/CodeBlock.js b/frontend/src/components/shared/CodeBlock.js new file mode 100644 index 0000000000000000000000000000000000000000..6f06f6eed1f6a17dd70334d3a7bb4d0ab897355c --- /dev/null +++ b/frontend/src/components/shared/CodeBlock.js @@ -0,0 +1,37 @@ +import React from 'react'; +import { Box, IconButton } from '@mui/material'; +import ContentCopyIcon from '@mui/icons-material/ContentCopy'; + +const CodeBlock = ({ code }) => ( + + navigator.clipboard.writeText(code)} + sx={{ 
+ position: 'absolute', + top: 8, + right: 8, + color: 'grey.500', + '&:hover': { color: 'grey.300' }, + }} + > + + + + {code} + + +); + +export default CodeBlock; \ No newline at end of file diff --git a/frontend/src/components/shared/FilterTag.js b/frontend/src/components/shared/FilterTag.js new file mode 100644 index 0000000000000000000000000000000000000000..3cd154cb61a699bf94a2af0ba78286e3588aa754 --- /dev/null +++ b/frontend/src/components/shared/FilterTag.js @@ -0,0 +1,139 @@ +import React from "react"; +import { Chip } from "@mui/material"; +import { useTheme } from "@mui/material/styles"; +import { alpha } from "@mui/material/styles"; +import CheckBoxOutlineBlankIcon from "@mui/icons-material/CheckBoxOutlineBlank"; +import CheckBoxOutlinedIcon from "@mui/icons-material/CheckBoxOutlined"; + +const FilterTag = ({ + label, + checked, + onChange, + count, + isHideFilter = false, + totalCount = 0, + variant = "tag", + showCheckbox = false, + stacked = false, + sx = {}, +}) => { + const theme = useTheme(); + + const formatCount = (count) => { + if (count === undefined) return ""; + return `${count}`; + }; + + const mainLabel = label; + const countLabel = count !== undefined ? formatCount(count) : ""; + + return ( + + ) : ( + + ) + ) : null + } + label={ + + {mainLabel} + {countLabel && ( + <> + + {countLabel} + + )} + + } + onClick={onChange} + variant="outlined" + color={ + checked + ? variant === "secondary" + ? "secondary" + : "primary" + : "default" + } + size="small" + data-checked={checked} + sx={{ + height: "32px", + fontWeight: 600, + opacity: checked ? 1 : 0.8, + borderRadius: "5px", + borderWidth: "1px", + borderStyle: "solid", + cursor: "pointer", + pl: showCheckbox ? 0.5 : 0, + mr: 0.5, + mb: 0.5, + transition: "opacity 0.2s ease, border-color 0.2s ease", + "& .MuiChip-label": { + px: 0.75, + pl: showCheckbox ? 0.6 : 0.75, + }, + "& .MuiChip-icon": { + mr: 0.5, + pl: 0.2, + }, + "&:hover": { + opacity: 1, + backgroundColor: checked + ? 
alpha( + theme.palette[variant === "secondary" ? "secondary" : "primary"] + .main, + theme.palette.mode === "light" ? 0.08 : 0.16 + ) + : "action.hover", + borderWidth: "1px", + }, + backgroundColor: checked + ? alpha( + theme.palette[variant === "secondary" ? "secondary" : "primary"] + .main, + theme.palette.mode === "light" ? 0.08 : 0.16 + ) + : "background.paper", + borderColor: checked + ? variant === "secondary" + ? "secondary.main" + : "primary.main" + : "divider", + ...sx, + }} + /> + ); +}; + +export default FilterTag; diff --git a/frontend/src/components/shared/InfoIconWithTooltip.js b/frontend/src/components/shared/InfoIconWithTooltip.js new file mode 100644 index 0000000000000000000000000000000000000000..2b307ccaf8d7bebb91c81b2ff7cf746a4fbac05e --- /dev/null +++ b/frontend/src/components/shared/InfoIconWithTooltip.js @@ -0,0 +1,87 @@ +import React from "react"; +import { Box, Tooltip, Portal, Backdrop } from "@mui/material"; +import InfoOutlinedIcon from "@mui/icons-material/InfoOutlined"; + +const InfoIconWithTooltip = ({ tooltip, iconProps = {}, sx = {} }) => { + const [open, setOpen] = React.useState(false); + + return ( + <> + setOpen(true)} + onClose={() => setOpen(false)} + componentsProps={{ + tooltip: { + sx: { + bgcolor: "rgba(33, 33, 33, 0.95)", + padding: "12px 16px", + maxWidth: "none !important", + width: "auto", + minWidth: "200px", + fontSize: "0.875rem", + lineHeight: 1.5, + position: "relative", + zIndex: 1501, + "& .MuiTooltip-arrow": { + color: "rgba(33, 33, 33, 0.95)", + }, + }, + }, + popper: { + sx: { + zIndex: 1501, + maxWidth: "min(600px, 90vw) !important", + '&[data-popper-placement*="bottom"] .MuiTooltip-tooltip': { + marginTop: "10px", + }, + '&[data-popper-placement*="top"] .MuiTooltip-tooltip': { + marginBottom: "10px", + }, + }, + }, + }} + > + + + + + {open && ( + + + + )} + + ); +}; + +export default InfoIconWithTooltip; diff --git a/frontend/src/components/shared/PageHeader.js 
b/frontend/src/components/shared/PageHeader.js new file mode 100644 index 0000000000000000000000000000000000000000..4e3e255933e84a6c4e2354eff643277ee0256017 --- /dev/null +++ b/frontend/src/components/shared/PageHeader.js @@ -0,0 +1,29 @@ +import React from "react"; +import { Box, Typography } from "@mui/material"; + +const PageHeader = ({ title, subtitle }) => { + return ( + + + {title} + + {subtitle && ( + + {subtitle} + + )} + + ); +}; + +export default PageHeader; diff --git a/frontend/src/config/auth.js b/frontend/src/config/auth.js new file mode 100644 index 0000000000000000000000000000000000000000..f2df1da3a017040cefd75c8ef6c002903f93787b --- /dev/null +++ b/frontend/src/config/auth.js @@ -0,0 +1,7 @@ +export const HF_CONFIG = { + CLIENT_ID: "", + STORAGE_KEY: "hf_oauth", + SCOPE: "openid profile", + PROD_URL: "https://braindecode-eeg-finetune-arena.hf.space", + DEV_URL: "http://localhost:7860" +}; diff --git a/frontend/src/config/theme.js b/frontend/src/config/theme.js new file mode 100644 index 0000000000000000000000000000000000000000..4bd6e4ae0ac0810a89f7aafb480b3b12fbe0f524 --- /dev/null +++ b/frontend/src/config/theme.js @@ -0,0 +1,390 @@ +import { createTheme, alpha } from "@mui/material/styles"; + +const getDesignTokens = (mode) => ({ + typography: { + fontFamily: [ + "-apple-system", + "BlinkMacSystemFont", + '"Segoe UI"', + "Roboto", + '"Helvetica Neue"', + "Arial", + "sans-serif", + ].join(","), + h1: { + fontFamily: '"Source Sans Pro", sans-serif', + }, + h2: { + fontFamily: '"Source Sans Pro", sans-serif', + }, + h3: { + fontFamily: '"Source Sans Pro", sans-serif', + }, + h4: { + fontFamily: '"Source Sans Pro", sans-serif', + }, + h5: { + fontFamily: '"Source Sans Pro", sans-serif', + }, + h6: { + fontFamily: '"Source Sans Pro", sans-serif', + }, + subtitle1: { + fontFamily: '"Source Sans Pro", sans-serif', + }, + subtitle2: { + fontFamily: '"Source Sans Pro", sans-serif', + }, + }, + palette: { + mode, + primary: { + main: "#4F86C6", + light: 
mode === "light" ? "#7BA7D7" : "#6B97D7", + dark: mode === "light" ? "#2B5C94" : "#3B6CA4", + 50: mode === "light" ? alpha("#4F86C6", 0.05) : alpha("#4F86C6", 0.15), + 100: mode === "light" ? alpha("#4F86C6", 0.1) : alpha("#4F86C6", 0.2), + 200: mode === "light" ? alpha("#4F86C6", 0.2) : alpha("#4F86C6", 0.3), + contrastText: "#fff", + }, + background: { + default: mode === "light" ? "#f8f9fa" : "#0a0a0a", + paper: mode === "light" ? "#fff" : "#1a1a1a", + subtle: mode === "light" ? "grey.100" : "grey.900", + hover: mode === "light" ? "action.hover" : alpha("#fff", 0.08), + tooltip: mode === "light" ? alpha("#212121", 0.9) : alpha("#fff", 0.9), + }, + text: { + primary: mode === "light" ? "rgba(0, 0, 0, 0.87)" : "#fff", + secondary: + mode === "light" ? "rgba(0, 0, 0, 0.6)" : "rgba(255, 255, 255, 0.7)", + disabled: + mode === "light" ? "rgba(0, 0, 0, 0.38)" : "rgba(255, 255, 255, 0.5)", + hint: + mode === "light" ? "rgba(0, 0, 0, 0.38)" : "rgba(255, 255, 255, 0.5)", + }, + divider: + mode === "light" ? "rgba(0, 0, 0, 0.12)" : "rgba(255, 255, 255, 0.12)", + action: { + active: + mode === "light" ? "rgba(0, 0, 0, 0.54)" : "rgba(255, 255, 255, 0.7)", + hover: + mode === "light" ? "rgba(0, 0, 0, 0.04)" : "rgba(255, 255, 255, 0.08)", + selected: + mode === "light" ? "rgba(0, 0, 0, 0.08)" : "rgba(255, 255, 255, 0.16)", + disabled: + mode === "light" ? "rgba(0, 0, 0, 0.26)" : "rgba(255, 255, 255, 0.3)", + disabledBackground: + mode === "light" ? "rgba(0, 0, 0, 0.12)" : "rgba(255, 255, 255, 0.12)", + }, + }, + shape: { + borderRadius: 8, + }, + components: { + MuiCssBaseline: { + styleOverrides: { + "html, body": { + backgroundColor: "background.default", + color: mode === "dark" ? "#fff" : "#000", + }, + body: { + "& *::-webkit-scrollbar": { + width: 8, + height: 8, + backgroundColor: "transparent", + }, + "& *::-webkit-scrollbar-thumb": { + borderRadius: 8, + backgroundColor: + mode === "light" ? 
alpha("#000", 0.2) : alpha("#fff", 0.1), + "&:hover": { + backgroundColor: + mode === "light" ? alpha("#000", 0.3) : alpha("#fff", 0.15), + }, + }, + }, + }, + }, + MuiButton: { + styleOverrides: { + root: { + borderRadius: 8, + }, + }, + }, + MuiPaper: { + defaultProps: { + elevation: 0, + }, + styleOverrides: { + root: { + backgroundImage: "none", + boxShadow: "none", + border: "1px solid", + borderColor: + mode === "light" + ? "rgba(0, 0, 0, 0.12)!important" + : "rgba(255, 255, 255, 0.25)!important", + }, + rounded: { + borderRadius: 12, + }, + }, + }, + + MuiTableCell: { + styleOverrides: { + root: { + borderColor: (theme) => + alpha( + theme.palette.divider, + theme.palette.mode === "dark" ? 0.1 : 0.2 + ), + }, + head: { + backgroundColor: mode === "light" ? "grey.50" : "grey.900", + color: "text.primary", + fontWeight: 600, + }, + }, + }, + MuiTableRow: { + styleOverrides: { + root: { + backgroundColor: "transparent", + }, + }, + }, + MuiTableContainer: { + styleOverrides: { + root: { + backgroundColor: "background.paper", + borderRadius: 8, + border: "none", + boxShadow: "none", + }, + }, + }, + MuiSlider: { + styleOverrides: { + root: { + "& .MuiSlider-valueLabel": { + backgroundColor: "background.paper", + color: "text.primary", + border: "1px solid", + borderColor: "divider", + boxShadow: + mode === "light" + ? "0px 2px 4px rgba(0, 0, 0, 0.1)" + : "0px 2px 4px rgba(0, 0, 0, 0.3)", + }, + }, + thumb: { + "&:hover": { + boxShadow: (theme) => + `0px 0px 0px 8px ${alpha( + theme.palette.primary.main, + mode === "light" ? 0.08 : 0.16 + )}`, + }, + "&.Mui-active": { + boxShadow: (theme) => + `0px 0px 0px 12px ${alpha( + theme.palette.primary.main, + mode === "light" ? 0.08 : 0.16 + )}`, + }, + }, + track: { + border: "none", + }, + rail: { + opacity: mode === "light" ? 0.38 : 0.3, + }, + mark: { + backgroundColor: mode === "light" ? 
"grey.400" : "grey.600", + }, + markLabel: { + color: "text.secondary", + }, + }, + }, + MuiTextField: { + styleOverrides: { + root: { + "& .MuiOutlinedInput-root": { + borderRadius: 8, + }, + }, + }, + }, + MuiChip: { + styleOverrides: { + root: { + borderRadius: 8, + }, + outlinedInfo: { + borderWidth: 2, + fontWeight: 600, + bgcolor: "info.100", + borderColor: "info.400", + color: "info.700", + "& .MuiChip-label": { + px: 1.2, + }, + "&:hover": { + bgcolor: "info.200", + }, + }, + outlinedWarning: { + borderWidth: 2, + fontWeight: 600, + bgcolor: "warning.100", + borderColor: "warning.400", + color: "warning.700", + "& .MuiChip-label": { + px: 1.2, + }, + "&:hover": { + bgcolor: "warning.200", + }, + }, + outlinedSuccess: { + borderWidth: 2, + fontWeight: 600, + bgcolor: "success.100", + borderColor: "success.400", + color: "success.700", + "& .MuiChip-label": { + px: 1.2, + }, + "&:hover": { + bgcolor: "success.200", + }, + }, + outlinedError: { + borderWidth: 2, + fontWeight: 600, + bgcolor: "error.100", + borderColor: "error.400", + color: "error.700", + "& .MuiChip-label": { + px: 1.2, + }, + "&:hover": { + bgcolor: "error.200", + }, + }, + outlinedPrimary: { + borderWidth: 2, + fontWeight: 600, + bgcolor: "primary.100", + borderColor: "primary.400", + color: "primary.700", + "& .MuiChip-label": { + px: 1.2, + }, + "&:hover": { + bgcolor: "primary.200", + }, + }, + outlinedSecondary: { + borderWidth: 2, + fontWeight: 600, + bgcolor: "secondary.100", + borderColor: "secondary.400", + color: "secondary.700", + "& .MuiChip-label": { + px: 1.2, + }, + "&:hover": { + bgcolor: "secondary.200", + }, + }, + }, + }, + MuiIconButton: { + styleOverrides: { + root: { + borderRadius: 8, + padding: "8px", + "&.MuiIconButton-sizeSmall": { + padding: "4px", + borderRadius: 6, + }, + }, + }, + }, + MuiTooltip: { + styleOverrides: { + tooltip: { + backgroundColor: + mode === "light" ? 
alpha("#212121", 0.9) : alpha("#424242", 0.9), + color: "#fff", + fontSize: "0.875rem", + padding: "8px 12px", + maxWidth: 400, + borderRadius: 8, + lineHeight: 1.4, + border: "1px solid", + borderColor: + mode === "light" ? alpha("#fff", 0.1) : alpha("#fff", 0.05), + boxShadow: + mode === "light" + ? "0 2px 8px rgba(0, 0, 0, 0.15)" + : "0 2px 8px rgba(0, 0, 0, 0.5)", + "& b": { + fontWeight: 600, + color: "inherit", + }, + "& a": { + color: mode === "light" ? "#90caf9" : "#64b5f6", + textDecoration: "none", + "&:hover": { + textDecoration: "underline", + }, + }, + }, + arrow: { + color: + mode === "light" ? alpha("#212121", 0.9) : alpha("#424242", 0.9), + "&:before": { + border: "1px solid", + borderColor: + mode === "light" ? alpha("#fff", 0.1) : alpha("#fff", 0.05), + }, + }, + }, + defaultProps: { + arrow: true, + enterDelay: 400, + leaveDelay: 200, + }, + }, + MuiAppBar: { + styleOverrides: { + root: { + border: "none", + borderBottom: "none", + }, + }, + }, + }, + breakpoints: { + values: { + xs: 0, + sm: 600, + md: 900, + lg: 1240, + xl: 1536, + }, + }, +}); + +const getTheme = (mode) => { + const tokens = getDesignTokens(mode); + return createTheme(tokens); +}; + +export default getTheme; diff --git a/frontend/src/hooks/useAuth.js b/frontend/src/hooks/useAuth.js new file mode 100644 index 0000000000000000000000000000000000000000..166d61aaaea425b8ec6e0c1d6bcf16311a94f369 --- /dev/null +++ b/frontend/src/hooks/useAuth.js @@ -0,0 +1,173 @@ +import { useState, useEffect } from "react"; +import { useLocation, useNavigate } from "react-router-dom"; +import { oauthLoginUrl, oauthHandleRedirectIfPresent } from "@huggingface/hub"; +import { HF_CONFIG } from "../config/auth"; + +async function fetchUserInfo(token) { + const response = await fetch("https://huggingface.co/api/whoami-v2", { + headers: { + Authorization: `Bearer ${token}`, + }, + }); + if (!response.ok) { + throw new Error("Failed to fetch user info"); + } + return response.json(); +} + +export function 
/**
 * React hook wrapping the Hugging Face OAuth flow.
 *
 * Exposes `{ isAuthenticated, user, loading, error, login, logout }`.
 * On mount it first checks for an OAuth redirect result, then falls back
 * to a token persisted in localStorage under HF_CONFIG.STORAGE_KEY, and
 * validates either token against the whoami endpoint via fetchUserInfo.
 */
export function useAuth() {
  const [isAuthenticated, setIsAuthenticated] = useState(false);
  const [user, setUser] = useState(null);
  const [loading, setLoading] = useState(true);
  const [error, setError] = useState(null);
  const location = useLocation();
  const navigate = useNavigate();

  // Authentication initialization: runs on mount and whenever the route
  // changes (the OAuth provider redirects back into the app).
  useEffect(() => {
    // Guards state updates after unmount (initAuth is async).
    let mounted = true;
    const initAuth = async () => {
      try {
        console.group("Auth Initialization");
        setLoading(true);

        // First, check whether we are returning from an OAuth redirect.
        let oauthResult = await oauthHandleRedirectIfPresent();

        // No redirect in flight: try to restore a session from localStorage.
        if (!oauthResult) {
          const storedAuth = localStorage.getItem(HF_CONFIG.STORAGE_KEY);
          if (storedAuth) {
            try {
              oauthResult = JSON.parse(storedAuth);
              console.log("Found existing auth");
              // Validate the stored token by fetching the profile; a stale
              // token makes fetchUserInfo throw and lands in the catch below.
              const userInfo = await fetchUserInfo(oauthResult.access_token);
              if (mounted) {
                setIsAuthenticated(true);
                setUser({
                  username: userInfo.name,
                  token: oauthResult.access_token,
                });
              }
            } catch (err) {
              // Corrupt or expired stored auth: wipe it and reset state.
              console.log("Invalid stored auth data, clearing...", err);
              localStorage.removeItem(HF_CONFIG.STORAGE_KEY);
              if (mounted) {
                setIsAuthenticated(false);
                setUser(null);
              }
            }
          }
        } else {
          console.log("Processing OAuth redirect");
          // NOTE: the redirect result uses camelCase (accessToken) while the
          // persisted shape uses snake_case (access_token).
          const token = oauthResult.accessToken;
          const userInfo = await fetchUserInfo(token);

          const authData = {
            access_token: token,
            username: userInfo.name,
          };

          // Persist the session so reloads skip the OAuth round-trip.
          localStorage.setItem(HF_CONFIG.STORAGE_KEY, JSON.stringify(authData));

          if (mounted) {
            setIsAuthenticated(true);
            setUser({
              username: userInfo.name,
              token: token,
            });
          }

          // Redirect back to the page the user was on before login()
          // stashed it under "auth_return_to".
          const returnTo = localStorage.getItem("auth_return_to");
          if (returnTo) {
            navigate(returnTo);
            localStorage.removeItem("auth_return_to");
          }
        }
      } catch (err) {
        console.error("Auth initialization error:", err);
        if (mounted) {
          setError(err.message);
          setIsAuthenticated(false);
          setUser(null);
        }
      } finally {
        if (mounted) {
          setLoading(false);
        }
        console.groupEnd();
      }
    };

    initAuth();

    return () => {
      mounted = false;
    };
  }, [navigate, location.pathname]);

  /**
   * Start the OAuth login flow: remember the current route, build the
   * provider login URL and navigate the browser to it (full page leave).
   */
  const login = async () => {
    try {
      console.group("Login Process");
      setLoading(true);

      // Save the current route for the post-auth redirect.
      const currentRoute = window.location.hash.replace("#", "") || "/";
      localStorage.setItem("auth_return_to", currentRoute);

      // Pick the redirect URL depending on the environment (dev vs prod).
      const redirectUrl =
        window.location.hostname === "localhost" ||
        window.location.hostname === "127.0.0.1"
          ? HF_CONFIG.DEV_URL
          : HF_CONFIG.PROD_URL;

      console.log("Using redirect URL:", redirectUrl);

      // Build the provider login URL and redirect.
      const loginUrl = await oauthLoginUrl({
        clientId: HF_CONFIG.CLIENT_ID,
        redirectUrl,
        scope: HF_CONFIG.SCOPE,
      });

      // prompt=consent forces the authorization screen every time.
      window.location.href = loginUrl + "&prompt=consent";

      console.groupEnd();
    } catch (err) {
      console.error("Login error:", err);
      setError(err.message);
      setLoading(false);
      console.groupEnd();
    }
  };

  /** Clear the persisted session and reset all auth state. */
  const logout = () => {
    console.group("Logout Process");
    setLoading(true);
    try {
      console.log("Clearing auth data...");
      localStorage.removeItem(HF_CONFIG.STORAGE_KEY);
      localStorage.removeItem("auth_return_to");
      setIsAuthenticated(false);
      setUser(null);
      console.log("Logged out successfully");
    } catch (err) {
      console.error("Logout error:", err);
      setError(err.message);
    } finally {
      setLoading(false);
      console.groupEnd();
    }
  };

  return {
    isAuthenticated,
    user,
    loading,
    error,
    login,
    logout,
  };
}
/**
 * Hook returning the current theme mode ('light' | 'dark') plus a manual
 * toggle. The initial value mirrors the OS color-scheme preference and
 * tracks live changes to it.
 */
export const useThemeMode = () => {
  // Read the OS-level color-scheme preference at call time.
  const readOsPreference = () =>
    window.matchMedia('(prefers-color-scheme: dark)').matches ? 'dark' : 'light';

  // Lazy initializer: evaluated once on first render.
  const [mode, setMode] = useState(readOsPreference);

  // Follow live changes to the system preference.
  useEffect(() => {
    const query = window.matchMedia('(prefers-color-scheme: dark)');
    const onPreferenceChange = (event) =>
      setMode(event.matches ? 'dark' : 'light');

    query.addEventListener('change', onPreferenceChange);
    return () => query.removeEventListener('change', onPreferenceChange);
  }, []);

  // Manual light/dark switch, independent of the OS preference.
  const toggleTheme = () =>
    setMode((current) => (current === 'light' ? 'dark' : 'light'));

  return { mode, toggleTheme };
};
We provide a + standardized evaluation pipeline to assess how well different models + perform on EEG-related tasks, enabling fair and reproducible + comparisons across the community. + + + Built on top of the{" "} + + braindecode + {" "} + library, the arena supports a variety of model architectures including + pretrained models, fine-tuned models, task-specific models, and + foundation models. + + + + + + Benchmarks + + + Models are currently evaluated on the following benchmarks: + + +
  • + + ANLI (Adversarial Natural Language Inference): + Tests the model's ability to perform natural language inference on + adversarially constructed examples, evaluating robustness and + reasoning capabilities. + +
  • +
  • + + LogiQA (Logical Reasoning QA): Evaluates logical + reasoning abilities through multiple-choice questions covering + categorical, conditional, disjunctive, and conjunctive reasoning. + +
  • +
    + + Additional EEG-specific benchmarks will be added as the arena evolves. + +
    + + + + Model Types + + +
  • + + {"\u{1F7E2}"} Pretrained: Base EEG models trained + with self-supervised learning on raw EEG data. + +
  • +
  • + + {"\u{1F536}"} Fine-tuned: Models fine-tuned on + specific EEG datasets for particular downstream tasks. + +
  • +
  • + + {"\u{1F9EA}"} Task-specific: Models designed for + specific EEG tasks such as sleep staging, motor imagery, or seizure + detection. + +
  • +
  • + + {"\u{1F3D7}\u{FE0F}"} Foundation: Large-scale EEG + foundation models trained on diverse EEG datasets. + +
  • +
    +
    + + + + Resources + + +
  • + + + Braindecode Documentation + {" "} + - The deep learning toolbox for EEG decoding + +
  • +
  • + + + Braindecode GitHub + {" "} + - Source code and contributions + +
  • +
    +
    +
    + ); +} + +export default AboutPage; diff --git a/frontend/src/pages/AddModelPage/AddModelPage.js b/frontend/src/pages/AddModelPage/AddModelPage.js new file mode 100644 index 0000000000000000000000000000000000000000..60f5d3130a71f25ed112bfe8a762c961b3dbd422 --- /dev/null +++ b/frontend/src/pages/AddModelPage/AddModelPage.js @@ -0,0 +1,51 @@ +import React from "react"; +import { Box, CircularProgress } from "@mui/material"; +import { useAuth } from "../../hooks/useAuth"; +import PageHeader from "../../components/shared/PageHeader"; +import EvaluationQueues from "./components/EvaluationQueues/EvaluationQueues"; +import ModelSubmissionForm from "./components/ModelSubmissionForm/ModelSubmissionForm"; +import SubmissionGuide from "./components/SubmissionGuide/SubmissionGuide"; +import SubmissionLimitChecker from "./components/SubmissionLimitChecker/SubmissionLimitChecker"; + +function AddModelPage() { + const { isAuthenticated, loading, user } = useAuth(); + + if (loading) { + return ( + + + + ); + } + + return ( + + + Add your model to the EEG + Finetune Arena + + } + /> + + + + + + + + + + ); +} + +export default AddModelPage; diff --git a/frontend/src/pages/AddModelPage/components/EvaluationQueues/EvaluationQueues.js b/frontend/src/pages/AddModelPage/components/EvaluationQueues/EvaluationQueues.js new file mode 100644 index 0000000000000000000000000000000000000000..c0b071d814c4fb7d8d1567221a086cf1396ff4ca --- /dev/null +++ b/frontend/src/pages/AddModelPage/components/EvaluationQueues/EvaluationQueues.js @@ -0,0 +1,787 @@ +import React, { useState, useEffect, useRef } from "react"; +import { + Box, + Typography, + Table, + TableBody, + TableCell, + TableContainer, + TableHead, + TableRow, + Chip, + Link, + CircularProgress, + Alert, + Accordion, + AccordionSummary, + AccordionDetails, + Stack, + Tooltip, + useTheme, + useMediaQuery, +} from "@mui/material"; +import AccessTimeIcon from "@mui/icons-material/AccessTime"; +import CheckCircleIcon from 
"@mui/icons-material/CheckCircle"; +import PendingIcon from "@mui/icons-material/Pending"; +import AutorenewIcon from "@mui/icons-material/Autorenew"; +import ExpandMoreIcon from "@mui/icons-material/ExpandMore"; +import OpenInNewIcon from "@mui/icons-material/OpenInNew"; +import { useVirtualizer } from "@tanstack/react-virtual"; + +// Function to format wait time +const formatWaitTime = (waitTimeStr) => { + const seconds = parseFloat(waitTimeStr.replace("s", "")); + + if (seconds < 60) { + return "just now"; + } + + const minutes = Math.floor(seconds / 60); + if (minutes < 60) { + return `${minutes}m ago`; + } + + const hours = Math.floor(minutes / 60); + if (hours < 24) { + return `${hours}h ago`; + } + + const days = Math.floor(hours / 24); + return `${days}d ago`; +}; + +// Column definitions with their properties +const columns = [ + { + id: "model", + label: "Model", + width: "35%", + align: "left", + }, + { + id: "submitter", + label: "Submitted by", + width: "15%", + align: "left", + }, + { + id: "wait_time", + label: "Submitted", + width: "12%", + align: "center", + }, + { + id: "precision", + label: "Precision", + width: "13%", + align: "center", + }, + { + id: "revision", + label: "Revision", + width: "12%", + align: "center", + }, + { + id: "status", + label: "Status", + width: "13%", + align: "center", + }, +]; + +const StatusChip = ({ status }) => { + const statusConfig = { + finished: { + icon: , + label: "Completed", + color: "success", + }, + evaluating: { + icon: , + label: "Evaluating", + color: "warning", + }, + pending: { icon: , label: "Pending", color: "info" }, + }; + + const config = statusConfig[status] || statusConfig.pending; + + return ( + + ); +}; + +const ModelTable = ({ models, emptyMessage, status }) => { + const parentRef = useRef(null); + const rowVirtualizer = useVirtualizer({ + count: models.length, + getScrollElement: () => parentRef.current, + estimateSize: () => 53, + overscan: 5, + }); + + if (models.length === 0) { + return 
( + + {emptyMessage} + + ); + } + + return ( + + + + {columns.map((column) => ( + + ))} + + + + {columns.map((column, index) => ( + + {column.label} + + ))} + + + + + + <> + {rowVirtualizer.getVirtualItems().map((virtualRow) => { + const model = models[virtualRow.index]; + const waitTime = formatWaitTime(model.wait_time); + + return ( + + + + {model.name} + + + + + {model.submitter} + + + + + + {waitTime} + + + + + + {model.precision} + + + + {model.revision.substring(0, 7)} + + + + + + ); + })} + + + + +
    +
    + ); +}; + +const QueueAccordion = ({ + title, + models, + status, + emptyMessage, + expanded, + onChange, + loading, +}) => { + const theme = useTheme(); + const isMobile = useMediaQuery(theme.breakpoints.down("sm")); + + return ( + + } + sx={{ + px: { xs: 2, sm: 3 }, + py: { xs: 1.5, sm: 2 }, + alignItems: { xs: "flex-start", sm: "center" }, + "& .MuiAccordionSummary-expandIconWrapper": { + marginTop: { xs: "4px", sm: 0 }, + }, + }} + > + + + {title} + + + ({ + borderWidth: 2, + fontWeight: 600, + fontSize: { xs: "0.75rem", sm: "0.875rem" }, + height: { xs: "24px", sm: "32px" }, + width: { xs: "100%", sm: "auto" }, + bgcolor: + status === "finished" + ? theme.palette.success[100] + : status === "evaluating" + ? theme.palette.warning[100] + : theme.palette.info[100], + borderColor: + status === "finished" + ? theme.palette.success[400] + : status === "evaluating" + ? theme.palette.warning[400] + : theme.palette.info[400], + color: + status === "finished" + ? theme.palette.success[700] + : status === "evaluating" + ? theme.palette.warning[700] + : theme.palette.info[700], + "& .MuiChip-label": { + px: { xs: 1, sm: 1.2 }, + width: "100%", + }, + "&:hover": { + bgcolor: + status === "finished" + ? theme.palette.success[200] + : status === "evaluating" + ? 
theme.palette.warning[200] + : theme.palette.info[200], + }, + })} + /> + {loading && ( + + )} + + + + + + + + + + ); +}; + +const EvaluationQueues = ({ defaultExpanded = true }) => { + const [expanded, setExpanded] = useState(defaultExpanded); + const [expandedQueues, setExpandedQueues] = useState(new Set()); + const [models, setModels] = useState({ + pending: [], + evaluating: [], + finished: [], + }); + const [loading, setLoading] = useState(true); + const [error, setError] = useState(null); + const theme = useTheme(); + const isMobile = useMediaQuery(theme.breakpoints.down("sm")); + + useEffect(() => { + const fetchModels = async () => { + try { + const response = await fetch("/api/models/status"); + if (!response.ok) { + throw new Error("Failed to fetch models"); + } + const data = await response.json(); + + // Sort models by submission date (most recent first) + const sortByDate = (models) => { + return [...models].sort((a, b) => { + const dateA = new Date(a.submission_time); + const dateB = new Date(b.submission_time); + return dateB - dateA; + }); + }; + + setModels({ + finished: sortByDate(data.finished), + evaluating: sortByDate(data.evaluating), + pending: sortByDate(data.pending), + }); + } catch (err) { + setError(err.message); + } finally { + setLoading(false); + } + }; + + fetchModels(); + const interval = setInterval(fetchModels, 30000); + return () => clearInterval(interval); + }, []); + + const handleMainAccordionChange = (panel) => (event, isExpanded) => { + setExpanded(isExpanded ? 
panel : false); + }; + + const handleQueueAccordionChange = (queueName) => (event, isExpanded) => { + setExpandedQueues((prev) => { + const newSet = new Set(prev); + if (isExpanded) { + newSet.add(queueName); + } else { + newSet.delete(queueName); + } + return newSet; + }); + }; + + if (error) { + return ( + + {error} + + ); + } + + return ( + + } + sx={{ + px: { xs: 2, sm: 3 }, + "& .MuiAccordionSummary-expandIconWrapper": { + color: "text.secondary", + transform: "rotate(0deg)", + transition: "transform 150ms", + marginTop: { xs: "4px", sm: 0 }, + "&.Mui-expanded": { + transform: "rotate(180deg)", + }, + }, + }} + > + + + Evaluation Status + + {!loading && ( + + + + + + )} + {loading && ( + + )} + + + + {loading ? ( + + + + ) : ( + <> + + + + + + + )} + + + ); +}; + +export default EvaluationQueues; diff --git a/frontend/src/pages/AddModelPage/components/ModelSubmissionForm/ModelSubmissionForm.js b/frontend/src/pages/AddModelPage/components/ModelSubmissionForm/ModelSubmissionForm.js new file mode 100644 index 0000000000000000000000000000000000000000..07e077d569f5894945197293075ebdd65b75e2f7 --- /dev/null +++ b/frontend/src/pages/AddModelPage/components/ModelSubmissionForm/ModelSubmissionForm.js @@ -0,0 +1,545 @@ +import React, { useState } from "react"; +import { + Box, + Paper, + Typography, + TextField, + Button, + FormControl, + InputLabel, + Select, + MenuItem, + Stack, + Grid, + CircularProgress, + Alert, +} from "@mui/material"; +import RocketLaunchIcon from "@mui/icons-material/RocketLaunch"; +import CheckCircleOutlineIcon from "@mui/icons-material/CheckCircleOutline"; +import { alpha } from "@mui/material/styles"; +import InfoIconWithTooltip from "../../../../components/shared/InfoIconWithTooltip"; +import { MODEL_TYPES } from "../../../../pages/LeaderboardPage/components/Leaderboard/constants/modelTypes"; +import { SUBMISSION_PRECISIONS } from "../../../../pages/LeaderboardPage/components/Leaderboard/constants/defaults"; +import AuthContainer from 
"../../../../components/shared/AuthContainer"; + +const WEIGHT_TYPES = [ + { value: "Original", label: "Original" }, + { value: "Delta", label: "Delta" }, + { value: "Adapter", label: "Adapter" }, +]; + +const HELP_TEXTS = { + modelName: ( + + + Model Name on Hugging Face Hub + + + Your model must be public and loadable with AutoClasses without + trust_remote_code. The model should be in Safetensors format for better + safety and loading performance. Example: braindecode/EEGNetv4 + + + ), + revision: ( + + + Model Revision + + + Git branch, tag or commit hash. The evaluation will be strictly tied to + this specific commit to ensure consistency. Make sure this version is + stable and contains all necessary files. + + + ), + modelType: ( + + + Model Category + + + {"\u{1F7E2}"} Pretrained: Base EEG models trained with self-supervised learning{" "} + {"\u{1F536}"} Fine-tuned: Models fine-tuned on specific EEG datasets{" "} + {"\u{1F9EA}"} Task-specific: Models designed for specific EEG tasks{" "} + {"\u{1F3D7}\u{FE0F}"} Foundation: Large-scale EEG foundation models + + + ), + baseModel: ( + + + Base Model Reference + + + Required for delta weights or adapters. This information is used to + identify the original model and calculate the total parameter count by + combining base model and adapter/delta parameters. + + + ), + precision: ( + + + Model Precision + + + Size limits vary by precision: FP16/BF16: up to 500M parameters. + 8-bit: up to 1B parameters. 4-bit: up to 2B parameters. + Choose carefully as incorrect precision can cause evaluation errors. 
+ + + ), + weightsType: ( + + + Weights Format + + + Original: Complete model weights in safetensors format Delta: Weight + differences from base model (requires base model for size calculation) + Adapter: Lightweight fine-tuning layers (requires base model for size + calculation) + + + ), +}; + +// Convert MODEL_TYPES to format expected by Select component +const modelTypeOptions = Object.entries(MODEL_TYPES).map( + ([value, { icon, label }]) => ({ + value, + label: `${icon} ${label}`, + }) +); + +function ModelSubmissionForm({ user, isAuthenticated }) { + const [formData, setFormData] = useState({ + modelName: "", + revision: "main", + modelType: "fine-tuned", + precision: "float16", + weightsType: "Original", + baseModel: "", + }); + const [error, setError] = useState(null); + const [submitting, setSubmitting] = useState(false); + const [success, setSuccess] = useState(false); + const [submittedData, setSubmittedData] = useState(null); + + const handleChange = (event) => { + const { name, value, checked } = event.target; + setFormData((prev) => ({ + ...prev, + [name]: event.target.type === "checkbox" ? 
checked : value, + })); + }; + + const handleSubmit = async (e) => { + e.preventDefault(); + setError(null); + setSubmitting(true); + + try { + const response = await fetch("/api/models/submit", { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + model_id: formData.modelName, + revision: formData.revision, + model_type: formData.modelType, + precision: formData.precision, + weight_type: formData.weightsType, + base_model: formData.baseModel, + user_id: user.username, + }), + }); + + if (!response.ok) { + const error = await response.json(); + throw new Error(error.detail || "Failed to submit model"); + } + + setSubmittedData(formData); + setSuccess(true); + } catch (error) { + setError(error.message); + } finally { + setSubmitting(false); + } + }; + + if (success && submittedData) { + return ( + ({ + p: 6, + mb: 3, + bgcolor: alpha(theme.palette.success.main, 0.05), + borderColor: alpha(theme.palette.success.main, 0.2), + })} + > + + + + + Model submitted successfully! + + + + + Your model {submittedData.modelName} has been added + to the evaluation queue with the following parameters: + + + + + + + Model: + + + {submittedData.modelName} + + + + + Type: + + + {submittedData.modelType} + + + + + Revision: + + + {submittedData.revision} + + + + + Precision: + + + {submittedData.precision} + + + + + Weight type: + + + {submittedData.weightsType} + + + {submittedData.baseModel && ( + + + Base model: + + + {submittedData.baseModel} + + + )} + + + + + An automatic upvote has been added to your model to help with + prioritization. + + + + + + + + ); + } + + return ( + <> + {error && ( + + {error} + + )} + + {isAuthenticated && ( + + {/* Header */} + + theme.palette.mode === "dark" + ? alpha(theme.palette.divider, 0.1) + : "grey.200", + bgcolor: (theme) => + theme.palette.mode === "dark" + ? 
alpha(theme.palette.background.paper, 0.5) + : "grey.50", + }} + > + + Model Submission Form + + + + {/* Form Content */} + + + {/* Model Information */} + + + Model Information + + + + + + + ), + }} + /> + + + + + ), + }} + /> + + + {/* Model Configuration */} + + + Model Configuration + + + + + + Model Type + + + + + + + Precision + + + + + + + Weights Type + + + + + {formData.weightsType !== "Original" && ( + + + ), + }} + /> + + )} + + {/* Submit Button */} + + + + All fields marked with * are required + + + + + + + + )} + + ); +} + +export default ModelSubmissionForm; diff --git a/frontend/src/pages/AddModelPage/components/SubmissionGuide/SubmissionGuide.js b/frontend/src/pages/AddModelPage/components/SubmissionGuide/SubmissionGuide.js new file mode 100644 index 0000000000000000000000000000000000000000..5be7e4c01b63ab8a9110911c413ed2447e23639d --- /dev/null +++ b/frontend/src/pages/AddModelPage/components/SubmissionGuide/SubmissionGuide.js @@ -0,0 +1,281 @@ +import React, { useState, useEffect } from "react"; +import { useLocation, useNavigate } from "react-router-dom"; +import { Box, Paper, Typography, Button, Stack, Collapse } from "@mui/material"; +import ExpandMoreIcon from "@mui/icons-material/ExpandMore"; + +const DocLink = ({ href, children }) => ( + +); + +const StepNumber = ({ number }) => ( + + {number} + +); + +const TUTORIAL_STEPS = [ + { + title: "Model Information", + content: ( + + + Your model should be public on the Hub and follow the{" "} + username/model-id format (e.g. + braindecode/EEGNetv4). Specify the revision{" "} + (commit hash or branch) and model type. + + + Model uploading guide + + + ), + }, + { + title: "Technical Details", + content: ( + + + Make sure your model can be loaded locally before + submitting: + + + theme.palette.mode === "dark" ? "grey.50" : "grey.900", + borderRadius: 1, + "& pre": { + m: 0, + p: 0, + fontFamily: "monospace", + fontSize: "0.875rem", + color: (theme) => + theme.palette.mode === "dark" ? 
"grey.900" : "grey.50", + }, + }} + > +
    +            {`import braindecode
    +from braindecode.models import EEGNetv4
    +
    +# Load your model
    +model = EEGNetv4(
    +    n_chans=22,
    +    n_outputs=4,
    +    n_times=1000,
    +)
    +# Or from Hugging Face Hub
    +# model = braindecode.models.load("your-username/your-model", revision="main")`}
    +          
    +
    + + Braindecode documentation + +
    + ), + }, + { + title: "License Requirements", + content: ( + + + A license tag is required.{" "} + Open licenses (Apache, MIT, etc) are strongly + recommended. + + + About model licenses + + + ), + }, + { + title: "Model Card Requirements", + content: ( + + + Your model card must include: architecture,{" "} + training details,{" "} + dataset information (EEG paradigm, number of channels, + sampling rate), intended use, limitations, and{" "} + performance metrics. + + + Model cards guide + + + ), + }, + { + title: "Final Checklist", + content: ( + + + Ensure your model is public, uses{" "} + safetensors format, has a{" "} + license tag, and loads correctly{" "} + with the provided code. + + + Sharing best practices + + + ), + }, +]; + +function SubmissionGuide() { + const location = useLocation(); + const navigate = useNavigate(); + + // Initialize state directly with URL value + const initialExpanded = !new URLSearchParams(location.search).get("guide"); + const [expanded, setExpanded] = useState(initialExpanded); + + // Sync expanded state with URL changes after initial render + useEffect(() => { + const guideOpen = !new URLSearchParams(location.search).get("guide"); + if (guideOpen !== expanded) { + setExpanded(guideOpen); + } + }, [location.search, expanded]); + + const handleAccordionChange = () => { + const newExpanded = !expanded; + setExpanded(newExpanded); + const params = new URLSearchParams(location.search); + if (newExpanded) { + params.delete("guide"); + } else { + params.set("guide", "closed"); + } + navigate({ search: params.toString() }, { replace: true }); + }; + + return ( + + theme.palette.mode === "dark" ? "grey.800" : "grey.200", + overflow: "hidden", + }} + > + + theme.palette.mode === "dark" ? "grey.900" : "grey.50", + borderBottom: "1px solid", + borderColor: (theme) => + expanded + ? theme.palette.mode === "dark" + ? 
"grey.800" + : "grey.200" + : "transparent", + }} + > + + Submission Guide + + + + + + + {TUTORIAL_STEPS.map((step, index) => ( + + + + + + {step.title} + + + {step.content} + + {index < TUTORIAL_STEPS.length - 1 && ( + + theme.palette.mode === "dark" ? "grey.800" : "grey.100", + }} + /> + )} + + ))} + + + + + ); +} + +export default SubmissionGuide; diff --git a/frontend/src/pages/AddModelPage/components/SubmissionLimitChecker/SubmissionLimitChecker.js b/frontend/src/pages/AddModelPage/components/SubmissionLimitChecker/SubmissionLimitChecker.js new file mode 100644 index 0000000000000000000000000000000000000000..97f4a72884c5874e68a169ad3c9d6c1541c8852a --- /dev/null +++ b/frontend/src/pages/AddModelPage/components/SubmissionLimitChecker/SubmissionLimitChecker.js @@ -0,0 +1,85 @@ +import React, { useState, useEffect } from "react"; +import { Alert, Box, CircularProgress } from "@mui/material"; + +const MAX_SUBMISSIONS_PER_WEEK = 10; + +function SubmissionLimitChecker({ user, children }) { + const [loading, setLoading] = useState(true); + const [reachedLimit, setReachedLimit] = useState(false); + const [error, setError] = useState(false); + + useEffect(() => { + const checkSubmissionLimit = async () => { + if (!user?.username) { + setLoading(false); + return; + } + + try { + const response = await fetch( + `/api/models/organization/${user.username}/submissions?days=7` + ); + if (!response.ok) { + throw new Error("Failed to fetch submission data"); + } + + const submissions = await response.json(); + console.log(`Recent submissions for ${user.username}:`, submissions); + setReachedLimit(submissions.length >= MAX_SUBMISSIONS_PER_WEEK); + setError(false); + } catch (error) { + console.error("Error checking submission limit:", error); + setError(true); + } finally { + setLoading(false); + } + }; + + checkSubmissionLimit(); + }, [user?.username]); + + if (loading) { + return ( + + + + ); + } + + if (error) { + return ( + + Unable to verify submission limits. 
Please try again in a few minutes. + + ); + } + + if (reachedLimit) { + return ( + + For fairness reasons, you cannot submit more than{" "} + {MAX_SUBMISSIONS_PER_WEEK} models per week. Please try again later. + + ); + } + + return children; +} + +export default SubmissionLimitChecker; diff --git a/frontend/src/pages/LeaderboardPage/LeaderboardPage.js b/frontend/src/pages/LeaderboardPage/LeaderboardPage.js new file mode 100644 index 0000000000000000000000000000000000000000..f2cf76fbe617e3170ee52c3b303311aea8a5f725 --- /dev/null +++ b/frontend/src/pages/LeaderboardPage/LeaderboardPage.js @@ -0,0 +1,50 @@ +import { useEffect } from "react"; +import Leaderboard from "./components/Leaderboard/Leaderboard"; +import { Box } from "@mui/material"; +import PageHeader from "../../components/shared/PageHeader"; +import Logo from "../../components/Logo/Logo"; +import { useLeaderboardData } from "../../pages/LeaderboardPage/components/Leaderboard/hooks/useLeaderboardData"; +import { useLeaderboard } from "../../pages/LeaderboardPage/components/Leaderboard/context/LeaderboardContext"; + +function LeaderboardPage() { + const { data, isLoading, error } = useLeaderboardData(); + const { actions } = useLeaderboard(); + + useEffect(() => { + if (data) { + actions.setModels(data); + } + actions.setLoading(isLoading); + actions.setError(error); + }, [data, isLoading, error, actions]); + + return ( + + + + + + Comparing EEG models in an{" "} + open and{" "} + reproducible way + + } + /> + + + ); +} + +export default LeaderboardPage; diff --git a/frontend/src/pages/LeaderboardPage/components/Leaderboard/Leaderboard.js b/frontend/src/pages/LeaderboardPage/components/Leaderboard/Leaderboard.js new file mode 100644 index 0000000000000000000000000000000000000000..5c41ce7fa5eeeb9b00bc657c174c9653c5d31503 --- /dev/null +++ b/frontend/src/pages/LeaderboardPage/components/Leaderboard/Leaderboard.js @@ -0,0 +1,449 @@ +import React, { useMemo, useEffect, useCallback } from "react"; +import { Box, 
Typography } from "@mui/material"; +import { useSearchParams } from "react-router-dom"; + +import { TABLE_DEFAULTS } from "./constants/defaults"; +import { useLeaderboard } from "./context/LeaderboardContext"; +import { useLeaderboardProcessing } from "./hooks/useLeaderboardData"; +import { useLeaderboardData } from "./hooks/useLeaderboardData"; + +import LeaderboardFilters from "./components/Filters/Filters"; +import LeaderboardTable from "./components/Table/Table"; +import SearchBar, { SearchBarSkeleton } from "./components/Filters/SearchBar"; +import PerformanceMonitor from "./components/PerformanceMonitor"; +import QuickFilters, { + QuickFiltersSkeleton, +} from "./components/Filters/QuickFilters"; + +const FilterAccordion = ({ expanded, quickFilters, advancedFilters }) => { + const advancedFiltersRef = React.useRef(null); + const quickFiltersRef = React.useRef(null); + const [height, setHeight] = React.useState("auto"); + const resizeTimeoutRef = React.useRef(null); + + const updateHeight = React.useCallback(() => { + if (expanded && advancedFiltersRef.current) { + setHeight(`${advancedFiltersRef.current.scrollHeight}px`); + } else if (!expanded && quickFiltersRef.current) { + setHeight(`${quickFiltersRef.current.scrollHeight}px`); + } + }, [expanded]); + + React.useEffect(() => { + // Initial height calculation + const timer = setTimeout(updateHeight, 100); + + // Resize handler with debounce + const handleResize = () => { + if (resizeTimeoutRef.current) { + clearTimeout(resizeTimeoutRef.current); + } + resizeTimeoutRef.current = setTimeout(updateHeight, 150); + }; + + window.addEventListener("resize", handleResize); + + return () => { + clearTimeout(timer); + window.removeEventListener("resize", handleResize); + if (resizeTimeoutRef.current) { + clearTimeout(resizeTimeoutRef.current); + } + }; + }, [updateHeight]); + + // Update height when expanded state changes + React.useEffect(() => { + updateHeight(); + }, [expanded, updateHeight]); + + return ( + + + 
{quickFilters} + + + {advancedFilters} + + + ); +}; + +const Leaderboard = () => { + const { state, actions } = useLeaderboard(); + const [searchParams, setSearchParams] = useSearchParams(); + const { + data, + isLoading: dataLoading, + error: dataError, + } = useLeaderboardData(); + const { + table, + filteredData, + error: processingError, + } = useLeaderboardProcessing(); + + // Memoize filtered data + const memoizedFilteredData = useMemo(() => filteredData, [filteredData]); + const memoizedTable = useMemo(() => table, [table]); + + // Memoize table options + const hasTableOptionsChanges = useMemo(() => { + return ( + state.display.rowSize !== TABLE_DEFAULTS.ROW_SIZE || + JSON.stringify(state.display.scoreDisplay) !== + JSON.stringify(TABLE_DEFAULTS.SCORE_DISPLAY) || + state.display.averageMode !== TABLE_DEFAULTS.AVERAGE_MODE || + state.display.rankingMode !== TABLE_DEFAULTS.RANKING_MODE + ); + }, [state.display]); + + const hasColumnFilterChanges = useMemo(() => { + return ( + JSON.stringify([...state.display.visibleColumns].sort()) !== + JSON.stringify([...TABLE_DEFAULTS.COLUMNS.DEFAULT_VISIBLE].sort()) + ); + }, [state.display.visibleColumns]); + + // Memoize callbacks + const onToggleFilters = useCallback(() => { + actions.toggleFiltersExpanded(); + }, [actions]); + + const onColumnVisibilityChange = useCallback( + (newVisibility) => { + actions.setDisplayOption( + "visibleColumns", + Object.keys(newVisibility).filter((key) => newVisibility[key]) + ); + }, + [actions] + ); + + const onRowSizeChange = useCallback( + (size) => { + actions.setDisplayOption("rowSize", size); + }, + [actions] + ); + + const onScoreDisplayChange = useCallback( + (display) => { + actions.setDisplayOption("scoreDisplay", display); + }, + [actions] + ); + + const onAverageModeChange = useCallback( + (mode) => { + actions.setDisplayOption("averageMode", mode); + }, + [actions] + ); + + const onRankingModeChange = useCallback( + (mode) => { + actions.setDisplayOption("rankingMode", 
mode); + }, + [actions] + ); + + const onPrecisionsChange = useCallback( + (precisions) => { + actions.setFilter("precisions", precisions); + }, + [actions] + ); + + const onTypesChange = useCallback( + (types) => { + actions.setFilter("types", types); + }, + [actions] + ); + + const onParamsRangeChange = useCallback( + (range) => { + actions.setFilter("paramsRange", range); + }, + [actions] + ); + + const onBooleanFiltersChange = useCallback( + (filters) => { + actions.setFilter("booleanFilters", filters); + }, + [actions] + ); + + const onReset = useCallback(() => { + actions.resetFilters(); + }, [actions]); + + // Memoize loading states + const loadingStates = useMemo(() => { + const isInitialLoading = dataLoading || !data; + const isProcessingData = !memoizedTable || !memoizedFilteredData; + const isApplyingFilters = state.models.length > 0 && !memoizedFilteredData; + const hasValidFilterCounts = + state.countsReady && + state.filterCounts && + state.filterCounts.normal && + state.filterCounts.officialOnly; + + return { + isInitialLoading, + isProcessingData, + isApplyingFilters, + showSearchSkeleton: isInitialLoading || !hasValidFilterCounts, + showFiltersSkeleton: isInitialLoading || !hasValidFilterCounts, + showTableSkeleton: + isInitialLoading || + isProcessingData || + isApplyingFilters || + !hasValidFilterCounts, + }; + }, [ + dataLoading, + data, + memoizedTable, + memoizedFilteredData, + state.models.length, + state.filterCounts, + state.countsReady, + ]); + + // Memoize child components + const memoizedSearchBar = useMemo( + () => ( + + ), + [ + onToggleFilters, + state.filtersExpanded, + loadingStates.showTableSkeleton, + memoizedFilteredData, + table, + ] + ); + + const memoizedQuickFilters = useMemo( + () => ( + + ), + [state.models.length, memoizedFilteredData, memoizedTable] + ); + + const memoizedLeaderboardFilters = useMemo( + () => ( + + ), + [ + memoizedFilteredData, + loadingStates.showFiltersSkeleton, + state.filters.precisions, + 
state.filters.types, + state.filters.paramsRange, + state.filters.booleanFilters, + onPrecisionsChange, + onTypesChange, + onParamsRangeChange, + onBooleanFiltersChange, + onReset, + ] + ); + + // No need to memoize LeaderboardTable as it handles its own sorting state + const tableComponent = ( + + ); + + // Update context with loaded data + useEffect(() => { + if (data) { + actions.setModels(data); + } + }, [data, actions]); + + // Log to understand loading state + useEffect(() => { + if (process.env.NODE_ENV === "development") { + console.log("Loading state:", { + dataLoading, + hasData: !!data, + hasTable: !!table, + hasFilteredData: !!filteredData, + filteredDataLength: filteredData?.length, + stateModelsLength: state.models.length, + hasFilters: Object.keys(state.filters).some((key) => { + if (Array.isArray(state.filters[key])) { + return state.filters[key].length > 0; + } + return !!state.filters[key]; + }), + }); + } + }, [ + dataLoading, + data, + table, + filteredData?.length, + state.models.length, + filteredData, + state.filters, + ]); + + // If an error occurred, display it + if (dataError || processingError) { + return ( + + + {(dataError || processingError)?.message || + "An error occurred while loading the data"} + + + ); + } + + return ( + + + + + {loadingStates.showSearchSkeleton ? ( + + ) : ( + memoizedSearchBar + )} + + {loadingStates.showFiltersSkeleton ? 
( + + ) : ( + + )} + + + + + + {tableComponent} + + + + + ); +}; + +export default Leaderboard; diff --git a/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/ColumnSelector/ColumnSelector.js b/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/ColumnSelector/ColumnSelector.js new file mode 100644 index 0000000000000000000000000000000000000000..5a67cacd3d1d3343d22abcf7fd083440bcb94881 --- /dev/null +++ b/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/ColumnSelector/ColumnSelector.js @@ -0,0 +1,217 @@ +import React from "react"; +import { Box, Typography } from "@mui/material"; +import ViewColumnIcon from "@mui/icons-material/ViewColumn"; +import CloseIcon from "@mui/icons-material/Close"; +import FilterTag from "../../../../../../components/shared/FilterTag"; +import RestartAltIcon from "@mui/icons-material/RestartAlt"; +import { TABLE_DEFAULTS } from "../../constants/defaults"; +import DropdownButton from "../shared/DropdownButton"; +import InfoIconWithTooltip from "../../../../../../components/shared/InfoIconWithTooltip"; +import { UI_TOOLTIPS } from "../../constants/tooltips"; + +const FilterGroup = ({ title, children, count, total }) => ( + + + {title} + {count !== undefined && total !== undefined && ( + + ({count}/{total}) + + )} + + + {children} + + +); + +const ColumnSelector = ({ + table, + onReset, + hasChanges, + onColumnVisibilityChange, + loading = false, +}) => { + const { getState, setColumnVisibility } = table; + const { columnVisibility } = getState(); + + // Filter columns to only show filterable ones + const filterableColumns = [ + ...TABLE_DEFAULTS.COLUMNS.EVALUATION, + ...TABLE_DEFAULTS.COLUMNS.OPTIONAL, + ]; + + const handleReset = (e) => { + e.preventDefault(); + e.stopPropagation(); + + if (!hasChanges) return; + + // Call onReset first + onReset?.(); + + // Create object with all columns set to false by default + const defaultVisibility = {}; + + // Set to true all columns that 
should be visible by default + TABLE_DEFAULTS.COLUMNS.DEFAULT_VISIBLE.forEach((col) => { + defaultVisibility[col] = true; + }); + + onColumnVisibilityChange?.(defaultVisibility); + setColumnVisibility(defaultVisibility); + }; + + const toggleColumn = (columnId) => { + if (TABLE_DEFAULTS.COLUMNS.FIXED.includes(columnId)) return; + + const newVisibility = { + ...columnVisibility, + [columnId]: !columnVisibility[columnId], + }; + + setColumnVisibility(newVisibility); + onColumnVisibilityChange?.(newVisibility); + }; + + return ( + + + + + Column Visibility + + + + + + + Reset + + + + + {Object.entries(TABLE_DEFAULTS.COLUMNS.COLUMN_GROUPS).map( + ([groupTitle, columns]) => { + // Calculer le nombre de colonnes cochées pour les évaluations + const isEvalGroup = groupTitle === "Evaluation Scores"; + const filteredColumns = columns.filter((col) => + filterableColumns.includes(col) + ); + const checkedCount = isEvalGroup + ? filteredColumns.filter((col) => columnVisibility[col]).length + : undefined; + const totalCount = isEvalGroup ? 
filteredColumns.length : undefined; + + return ( + + {filteredColumns.map((columnName) => { + const isFixed = + TABLE_DEFAULTS.COLUMNS.FIXED.includes(columnName); + return ( + toggleColumn(columnName)} + disabled={isFixed} + variant="tag" + /> + ); + })} + + ); + } + )} + + ); +}; + +export default ColumnSelector; diff --git a/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/DisplayOptions/DisplayOptions.js b/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/DisplayOptions/DisplayOptions.js new file mode 100644 index 0000000000000000000000000000000000000000..8ec6c2bf0b68a6f2372d867a5a6487128956fb4c --- /dev/null +++ b/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/DisplayOptions/DisplayOptions.js @@ -0,0 +1,238 @@ +import React from "react"; +import { Box, Typography } from "@mui/material"; +import TuneIcon from "@mui/icons-material/Tune"; +import CloseIcon from "@mui/icons-material/Close"; +import RestartAltIcon from "@mui/icons-material/RestartAlt"; +import FilterTag from "../../../../../../components/shared/FilterTag"; +import { + TABLE_DEFAULTS, + ROW_SIZES, + SCORE_DISPLAY_OPTIONS, + RANKING_MODE_OPTIONS, +} from "../../constants/defaults"; +import { UI_TOOLTIPS } from "../../constants/tooltips"; +import DropdownButton from "../shared/DropdownButton"; +import InfoIconWithTooltip from "../../../../../../components/shared/InfoIconWithTooltip"; + +const TableOptions = ({ + rowSize, + onRowSizeChange, + scoreDisplay = "normalized", + onScoreDisplayChange, + averageMode = "all", + onAverageModeChange, + rankingMode = "static", + onRankingModeChange, + hasChanges, + searchParams, + setSearchParams, + loading = false, +}) => { + const handleReset = () => { + onRowSizeChange(TABLE_DEFAULTS.ROW_SIZE); + onScoreDisplayChange(TABLE_DEFAULTS.SCORE_DISPLAY); + onAverageModeChange(TABLE_DEFAULTS.AVERAGE_MODE); + onRankingModeChange(TABLE_DEFAULTS.RANKING_MODE); + + const newParams = new 
URLSearchParams(searchParams); + ["rowSize", "scoreDisplay", "averageMode", "rankingMode"].forEach( + (param) => { + newParams.delete(param); + } + ); + setSearchParams(newParams); + }; + + return ( + + + + + Table Options + + + + + + + Reset + + + + + + + + + + {UI_TOOLTIPS.ROW_SIZE.title} + + + + + {Object.keys(ROW_SIZES).map((size) => ( + onRowSizeChange(size)} + variant="tag" + /> + ))} + + + + + + + {UI_TOOLTIPS.SCORE_DISPLAY.title} + + + + + {SCORE_DISPLAY_OPTIONS.map(({ value, label }) => ( + onScoreDisplayChange(value)} + variant="tag" + /> + ))} + + + + + + + {UI_TOOLTIPS.RANKING_MODE.title} + + + + + {RANKING_MODE_OPTIONS.map(({ value, label }) => ( + onRankingModeChange(value)} + variant="tag" + /> + ))} + + + + + + + {UI_TOOLTIPS.AVERAGE_SCORE.title} + + + + + onAverageModeChange("all")} + variant="tag" + /> + onAverageModeChange("visible")} + variant="tag" + /> + + + + + + ); +}; + +export default TableOptions; diff --git a/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Filters/FilteredModelCount.js b/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Filters/FilteredModelCount.js new file mode 100644 index 0000000000000000000000000000000000000000..f35223166eb572d3d09527bd60129a006d85f7c8 --- /dev/null +++ b/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Filters/FilteredModelCount.js @@ -0,0 +1,246 @@ +import React from "react"; +import { Box, Typography, Skeleton } from "@mui/material"; +import { useMemo } from "react"; +import { useLeaderboard } from "../../context/LeaderboardContext"; + +const useModelCount = ({ totalCount, filteredCount, data, table, loading }) => { + const { state } = useLeaderboard(); + const isOfficialProviderActive = state.filters.isOfficialProviderActive; + const { officialOnly: officialOnlyCounts } = state.filterCounts; + + return useMemo(() => { + if (loading) { + return { + displayCount: 0, + currentFilteredCount: 0, + totalPinnedCount: 0, + filteredPinnedCount: 
0, + isOfficialProviderActive, + }; + } + const displayCount = isOfficialProviderActive + ? officialOnlyCounts.officialProviders + : totalCount; + + // Calculate total number of pinned models + const totalPinnedCount = + data?.filter((model) => model.isPinned)?.length || 0; + + // Get current filter criteria + const filterConfig = { + selectedPrecisions: state.filters.precisions, + selectedTypes: state.filters.types, + paramsRange: state.filters.paramsRange, + searchValue: state.filters.search, + selectedBooleanFilters: state.filters.booleanFilters, + isOfficialProviderActive: state.filters.isOfficialProviderActive, + }; + + // Check each pinned model if it would pass filters without its pinned status + const filteredPinnedCount = + data?.filter((model) => { + if (!model.isPinned) return false; + + // Check each filter criteria + + // Filter by official providers + if (filterConfig.isOfficialProviderActive) { + if ( + !model.features?.is_official_provider && + !model.metadata?.is_official_provider + ) { + return false; + } + } + + // Filter by precision + if (filterConfig.selectedPrecisions.length > 0) { + if ( + !filterConfig.selectedPrecisions.includes(model.model.precision) + ) { + return false; + } + } + + // Filter by type + if (filterConfig.selectedTypes.length > 0) { + const modelType = model.model.type?.toLowerCase().trim(); + if ( + !filterConfig.selectedTypes.some((type) => + modelType?.includes(type) + ) + ) { + return false; + } + } + + // Filter by parameters + const params = model.metadata.params_billions; + if ( + params < filterConfig.paramsRange[0] || + params >= filterConfig.paramsRange[1] + ) { + return false; + } + + // Filter by search + if (filterConfig.searchValue) { + const searchLower = filterConfig.searchValue.toLowerCase(); + const modelName = model.model.name.toLowerCase(); + if (!modelName.includes(searchLower)) { + return false; + } + } + + // Filter by boolean flags + if (filterConfig.selectedBooleanFilters.length > 0) { + if ( + 
!filterConfig.selectedBooleanFilters.every((filter) => { + const filterValue = + typeof filter === "object" ? filter.value : filter; + + // Maintainer's Highlight keeps positive logic + if (filterValue === "is_official_provider") { + return model.features[filterValue]; + } + + // For all other filters, invert the logic + if (filterValue === "is_not_available_on_hub") { + return model.features[filterValue]; + } + + return !model.features[filterValue]; + }) + ) { + return false; + } + } + + // If we get here, the model passes all filters + return true; + })?.length || 0; + + return { + displayCount, + currentFilteredCount: filteredCount, + totalPinnedCount, + filteredPinnedCount, + isOfficialProviderActive, + }; + }, [ + loading, + totalCount, + filteredCount, + data, + state.filters, + isOfficialProviderActive, + officialOnlyCounts.officialProviders, + ]); +}; + +const CountTypography = ({ + value, + color = "text.primary", + loading = false, + pinnedCount = 0, + filteredPinnedCount = 0, + showPinned = false, +}) => { + if (loading) { + return ( + + ); + } + + return ( + + + {value} + + {showPinned && pinnedCount > 0 && ( + + {`+${pinnedCount}`} + + )} + + ); +}; + +const FilteredModelCount = React.memo( + ({ + totalCount = 0, + filteredCount = 0, + hasFilterChanges = false, + loading = false, + data = [], + table = null, + }) => { + const { + displayCount, + currentFilteredCount, + totalPinnedCount, + filteredPinnedCount, + isOfficialProviderActive, + } = useModelCount({ + totalCount, + filteredCount, + data, + table, + loading, + }); + + const shouldHighlight = + !loading && hasFilterChanges && currentFilteredCount !== displayCount; + + // Always show pinned models when they exist + const pinnedToShow = totalPinnedCount; + + return ( + + 0} + /> + + + + ); + } +); + +FilteredModelCount.displayName = "FilteredModelCount"; + +export default FilteredModelCount; diff --git a/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Filters/Filters.js 
b/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Filters/Filters.js new file mode 100644 index 0000000000000000000000000000000000000000..1fa0572d69fee9212d4bcd01058fc7acdb4d1de2 --- /dev/null +++ b/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Filters/Filters.js @@ -0,0 +1,850 @@ +import React, { + useState, + useEffect, + useMemo, + useRef, + forwardRef, + useCallback, +} from "react"; +import { + Box, + Typography, + Collapse, + Slider, + Grid, + Accordion, + AccordionDetails, + alpha, + useTheme, + TextField, +} from "@mui/material"; +import { + TABLE_DEFAULTS, + BOOLEAN_FILTER_OPTIONS, + FILTER_PRECISIONS, +} from "../../constants/defaults"; +import FilterTag from "../../../../../../components/shared/FilterTag"; +import { MODEL_TYPE_ORDER, MODEL_TYPES } from "../../constants/modelTypes"; +import { useLeaderboard } from "../../context/LeaderboardContext"; +import InfoIconWithTooltip from "../../../../../../components/shared/InfoIconWithTooltip"; +import { COLUMN_TOOLTIPS } from "../../constants/tooltips"; + +const getTooltipContent = (title) => { + switch (title) { + case "Model Type": + return COLUMN_TOOLTIPS.ARCHITECTURE; + case "Precision format": + return COLUMN_TOOLTIPS.PRECISION; + case "Flags": + return COLUMN_TOOLTIPS.FLAGS; + case "Parameters": + return COLUMN_TOOLTIPS.PARAMETERS; + default: + return null; + } +}; + +const FilterGroup = ({ + title, + tooltip, + children, + paramsRange, + onParamsRangeChange, +}) => { + const theme = useTheme(); + const [localParamsRange, setLocalParamsRange] = useState(paramsRange); + const stableTimerRef = useRef(null); + + // Handle local range change + const handleLocalRangeChange = useCallback((event, newValue) => { + setLocalParamsRange(newValue); + }, []); + + // Handle input change + const handleInputChange = useCallback( + (index) => (event) => { + const value = event.target.value === "" ? 
"" : Number(event.target.value); + if (value === "" || (value >= -1 && value <= 140)) { + const newRange = [...localParamsRange]; + newRange[index] = value; + setLocalParamsRange(newRange); + } + }, + [localParamsRange] + ); + + // Sync local state with props + useEffect(() => { + setLocalParamsRange(paramsRange); + }, [paramsRange]); + + // Propagate changes to parent after delay + useEffect(() => { + if (stableTimerRef.current) { + clearTimeout(stableTimerRef.current); + } + + stableTimerRef.current = setTimeout(() => { + if (Array.isArray(localParamsRange) && localParamsRange.length === 2) { + onParamsRangeChange(localParamsRange); + } + }, 300); + + return () => { + if (stableTimerRef.current) { + clearTimeout(stableTimerRef.current); + } + }; + }, [localParamsRange, onParamsRangeChange]); + + const renderContent = () => { + if (title === "Parameters") { + return ( + + + + + + + + (value === -1 ? "All" : `${value}B`)} + sx={{ + "& .MuiSlider-rail": { + height: 10, + backgroundColor: "background.paper", + border: "1px solid", + borderColor: "divider", + opacity: 1, + }, + "& .MuiSlider-track": { + height: 10, + border: "1px solid", + borderColor: (theme) => + alpha( + theme.palette.primary.main, + theme.palette.mode === "light" ? 0.3 : 0.5 + ), + backgroundColor: (theme) => + alpha( + theme.palette.primary.main, + theme.palette.mode === "light" ? 0.1 : 0.2 + ), + }, + "& .MuiSlider-thumb": { + width: 20, + height: 20, + backgroundColor: "background.paper", + border: "1px solid", + borderColor: "primary.main", + "&:hover, &.Mui-focusVisible": { + boxShadow: (theme) => + `0 0 0 8px ${alpha( + theme.palette.primary.main, + theme.palette.mode === "light" ? 0.08 : 0.16 + )}`, + }, + "&.Mui-active": { + boxShadow: (theme) => + `0 0 0 12px ${alpha( + theme.palette.primary.main, + theme.palette.mode === "light" ? 
0.08 : 0.16 + )}`, + }, + }, + "& .MuiSlider-valueLabel": { + backgroundColor: theme.palette.primary.main, + }, + "& .MuiSlider-mark": { + width: 2, + height: 10, + backgroundColor: "divider", + }, + "& .MuiSlider-markLabel": { + fontSize: "0.875rem", + "&::after": { + content: '"B"', + marginLeft: "1px", + opacity: 0.5, + }, + '&[data-index="0"]::after': { + content: '""', + }, + }, + }} + /> + + ); + } + return ( + + {children} + + ); + }; + + return ( + + + + {title} + + + + {renderContent()} + + ); +}; + +const CustomCollapse = forwardRef((props, ref) => { + const { children, style = {}, ...other } = props; + const collapsedHeight = "0px"; + const timeout = 300; + + const wrapperRef = useRef(null); + const [animatedHeight, setAnimatedHeight] = useState( + props.in ? "auto" : collapsedHeight + ); + + useEffect(() => { + if (!wrapperRef.current) return; + + if (props.in) { + const contentHeight = wrapperRef.current.scrollHeight; + setAnimatedHeight(`${contentHeight}px`); + } else { + setAnimatedHeight(collapsedHeight); + } + }, [props.in, children]); + + const handleEntered = (node) => { + setAnimatedHeight("auto"); + if (props.onEntered) { + props.onEntered(node); + } + }; + + return ( + +
    {children}
    +
    + ); +}); + +const LeaderboardFilters = ({ + selectedPrecisions = FILTER_PRECISIONS, + onPrecisionsChange = () => {}, + selectedTypes = MODEL_TYPE_ORDER, + onTypesChange = () => {}, + paramsRange = [-1, 140], + onParamsRangeChange = () => {}, + selectedBooleanFilters = [], + onBooleanFiltersChange = () => {}, + data = [], + expanded, + onToggleExpanded, + loading = false, +}) => { + const [localParamsRange, setLocalParamsRange] = useState(paramsRange); + const stableTimerRef = useRef(null); + const { state, actions } = useLeaderboard(); + const { normal: filterCounts, officialOnly: officialOnlyCounts } = + state.filterCounts; + const isOfficialProviderActive = state.filters.isOfficialProviderActive; + const currentCounts = useMemo( + () => (isOfficialProviderActive ? officialOnlyCounts : filterCounts), + [isOfficialProviderActive, officialOnlyCounts, filterCounts] + ); + + useEffect(() => { + setLocalParamsRange(paramsRange); + }, [paramsRange]); + + // Clean up timer when component unmounts + useEffect(() => { + return () => { + if (stableTimerRef.current) { + clearTimeout(stableTimerRef.current); + } + }; + }, []); + + const handleParamsRangeChange = (event, newValue) => { + setLocalParamsRange(newValue); + }; + + const handleParamsRangeChangeCommitted = (event, newValue) => { + // Reset timer on each change + if (stableTimerRef.current) { + clearTimeout(stableTimerRef.current); + } + + // Update URL immediately + onParamsRangeChange(newValue); + + // Trigger data update after debounce + stableTimerRef.current = setTimeout(() => { + actions.updateFilteredData(); + }, TABLE_DEFAULTS.DEBOUNCE.SEARCH); + }; + + const handlePrecisionToggle = (precision) => { + const newPrecisions = selectedPrecisions.includes(precision) + ? 
selectedPrecisions.filter((p) => p !== precision) + : [...selectedPrecisions, precision]; + onPrecisionsChange(newPrecisions); + }; + + const handleBooleanFilterToggle = (filter) => { + const newFilters = selectedBooleanFilters.includes(filter) + ? selectedBooleanFilters.filter((f) => f !== filter) + : [...selectedBooleanFilters, filter]; + onBooleanFiltersChange(newFilters); + }; + + // Filter options based on their hide property + const showFilterOptions = BOOLEAN_FILTER_OPTIONS.filter( + (option) => !option.hide + ); + const hideFilterOptions = BOOLEAN_FILTER_OPTIONS.filter( + (option) => option.hide + ); + + const handleOfficialProviderToggle = () => { + actions.toggleOfficialProvider(); + }; + + return loading ? null : ( + + + + + + + + alpha(theme.palette.primary.main, 0.02), + border: "1px solid", + borderColor: (theme) => + alpha(theme.palette.primary.main, 0.2), + borderRadius: 1, + p: 3, + position: "relative", + width: "100%", + display: "flex", + flexDirection: "column", + "&:hover": { + borderColor: (theme) => + alpha(theme.palette.primary.main, 0.3), + backgroundColor: (theme) => + alpha(theme.palette.primary.main, 0.03), + }, + transition: (theme) => + theme.transitions.create( + ["border-color", "background-color"], + { + duration: theme.transitions.duration.short, + } + ), + }} + > + + Advanced Filters + + + + + + + {FILTER_PRECISIONS.map((precision) => ( + + handlePrecisionToggle(precision) + } + count={currentCounts.precisions[precision]} + showCheckbox={true} + /> + ))} + + + + + + + + + + + alpha( + theme.palette.primary.main, + theme.palette.mode === "light" + ? 0.3 + : 0.5 + ), + backgroundColor: (theme) => + alpha( + theme.palette.primary.main, + theme.palette.mode === "light" + ? 
0.1 + : 0.2 + ), + }, + "& .MuiSlider-thumb": { + width: 20, + height: 20, + backgroundColor: "background.paper", + border: "1px solid", + borderColor: "primary.main", + "&:hover, &.Mui-focusVisible": { + boxShadow: (theme) => + `0 0 0 8px ${alpha( + theme.palette.primary.main, + theme.palette.mode === "light" + ? 0.08 + : 0.16 + )}`, + }, + "&.Mui-active": { + boxShadow: (theme) => + `0 0 0 12px ${alpha( + theme.palette.primary.main, + theme.palette.mode === "light" + ? 0.08 + : 0.16 + )}`, + }, + }, + "& .MuiSlider-mark": { + backgroundColor: "text.disabled", + height: 2, + width: 2, + borderRadius: "50%", + }, + "& .MuiSlider-markLabel": { + color: "text.secondary", + }, + }} + /> + + + + + + + {/* Deuxième ligne */} + + + + {MODEL_TYPE_ORDER.sort( + (a, b) => + MODEL_TYPES[a].order - MODEL_TYPES[b].order + ).map((type) => ( + { + const newTypes = selectedTypes.includes(type) + ? selectedTypes.filter((t) => t !== type) + : [...selectedTypes, type]; + onTypesChange(newTypes); + }} + count={currentCounts.modelTypes[type]} + variant="tag" + showCheckbox={true} + /> + ))} + + + + + + + + {hideFilterOptions.map((filter) => ( + { + const newFilters = + selectedBooleanFilters.includes( + filter.value + ) + ? selectedBooleanFilters.filter( + (f) => f !== filter.value + ) + : [ + ...selectedBooleanFilters, + filter.value, + ]; + onBooleanFiltersChange(newFilters); + }} + count={ + filter.value === "is_moe" + ? currentCounts.mixtureOfExperts + : filter.value === "is_flagged" + ? currentCounts.flagged + : filter.value === "is_merged" + ? currentCounts.merged + : filter.value === "is_not_available_on_hub" + ? 
currentCounts.notOnHub + : 0 + } + isHideFilter={false} + totalCount={data.length} + showCheckbox={true} + /> + ))} + + + + + + + + + + + alpha(theme.palette.secondary.main, 0.02), + border: "1px solid", + borderColor: (theme) => + alpha(theme.palette.secondary.main, 0.15), + borderRadius: 1, + p: 3, + position: "relative", + width: "100%", + display: "flex", + flexDirection: "column", + alignItems: "center", + justifyContent: "center", + textAlign: "center", + minHeight: "100%", + "&:hover": { + borderColor: (theme) => + alpha(theme.palette.secondary.main, 0.25), + backgroundColor: (theme) => + alpha(theme.palette.secondary.main, 0.03), + }, + transition: (theme) => + theme.transitions.create( + ["border-color", "background-color"], + { + duration: theme.transitions.duration.short, + } + ), + }} + > + + + Official Models + + + Show only models that are officially provided and + maintained by their original creators. + + + {showFilterOptions.map((filter) => ( + + handleBooleanFilterToggle(filter.value) + } + count={ + filter.value === "is_official_provider" + ? currentCounts.officialProviders + : 0 + } + showCheckbox={true} + variant="secondary" + /> + + + {( + filter.value === "is_official_provider" + ? isOfficialProviderActive + : selectedBooleanFilters.includes(filter.value) + ) + ? 
"Filter active" + : "Filter inactive"} + + + ))} + + + + + + + + + + ); +}; + +export default LeaderboardFilters; diff --git a/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Filters/QuickFilters.js b/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Filters/QuickFilters.js new file mode 100644 index 0000000000000000000000000000000000000000..91d074c6375e8129eda09cea299b6aa36e26c3f9 --- /dev/null +++ b/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Filters/QuickFilters.js @@ -0,0 +1,226 @@ +import React, { useCallback, useMemo } from "react"; +import { Box, Typography, Skeleton } from "@mui/material"; +import { alpha } from "@mui/material/styles"; +import { QUICK_FILTER_PRESETS } from "../../constants/quickFilters"; +import FilterTag from "../../../../../../components/shared/FilterTag"; +import { useLeaderboard } from "../../context/LeaderboardContext"; +import InfoIconWithTooltip from "../../../../../../components/shared/InfoIconWithTooltip"; +import { UI_TOOLTIPS } from "../../constants/tooltips"; + +const QuickFiltersTitle = ({ sx = {} }) => ( + + + Quick Filters + + + +); + +export const QuickFiltersSkeleton = () => ( + + ({ + xs: alpha(theme.palette.primary.main, 0.02), + lg: "transparent", + }), + borderColor: (theme) => ({ + xs: alpha(theme.palette.primary.main, 0.2), + lg: "transparent", + }), + border: "1px solid", + borderRadius: 1, + p: 3, + display: "flex", + flexDirection: { xs: "column", md: "column", lg: "row" }, + gap: 2, + mb: 2, + width: "100%", + }} + > + + + {[1, 2, 3, 4].map((i) => ( + + ))} + + + +); + +const QuickFilters = ({ totalCount = 0, loading = false }) => { + const { state, actions } = useLeaderboard(); + const { normal: filterCounts, officialOnly: officialOnlyCounts } = + state.filterCounts; + const isOfficialProviderActive = state.filters.isOfficialProviderActive; + const currentParams = state.filters.paramsRange; + + const currentCounts = useMemo( + () => 
(isOfficialProviderActive ? officialOnlyCounts : filterCounts), + [isOfficialProviderActive, officialOnlyCounts, filterCounts] + ); + + const modelSizePresets = useMemo( + () => + QUICK_FILTER_PRESETS.filter( + (preset) => preset.id !== "official_providers" + ), + [] + ); + + const officialProvidersPreset = useMemo( + () => + QUICK_FILTER_PRESETS.find((preset) => preset.id === "official_providers"), + [] + ); + + const handleSizePresetClick = useCallback( + (preset) => { + const isActive = + currentParams[0] === preset.filters.paramsRange[0] && + currentParams[1] === preset.filters.paramsRange[1]; + + if (isActive) { + actions.setFilter("paramsRange", [-1, 140]); // Reset to default + } else { + actions.setFilter("paramsRange", preset.filters.paramsRange); + } + }, + [currentParams, actions] + ); + + const getPresetCount = useCallback( + (preset) => { + const range = preset.id.split("_")[0]; + return currentCounts.parameterRanges[range] || 0; + }, + [currentCounts] + ); + + const handleOfficialProviderToggle = useCallback(() => { + actions.toggleOfficialProvider(); + }, [actions]); + + if (loading) { + return ; + } + + return ( + + ({ + xs: alpha(theme.palette.primary.main, 0.02), + lg: "transparent", + }), + borderColor: (theme) => ({ + xs: alpha(theme.palette.primary.main, 0.2), + lg: "transparent", + }), + border: "1px solid", + borderRadius: 1, + p: 3, + display: "flex", + flexDirection: { xs: "column", lg: "row" }, + alignItems: "center", + gap: 2, + width: "100%", + }} + > + + + + div": { + width: { xs: "100%", md: 0, lg: "auto" }, + flex: { + xs: "auto", + md: "1 1 0", + lg: "0 0 auto", + }, + }, + }} + > + {modelSizePresets.map((preset) => ( + handleSizePresetClick(preset)} + count={getPresetCount(preset)} + totalCount={totalCount} + /> + ))} + + + + {officialProvidersPreset && ( + + )} + + + + ); +}; + +QuickFilters.displayName = "QuickFilters"; + +export default React.memo(QuickFilters); diff --git 
a/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Filters/SearchBar.js b/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Filters/SearchBar.js new file mode 100644 index 0000000000000000000000000000000000000000..c32cd8f8640b0d2e8fa7c1928f76fcd8d53fe494 --- /dev/null +++ b/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Filters/SearchBar.js @@ -0,0 +1,329 @@ +import React, { useState, useEffect } from "react"; +import { Box, InputBase, Typography, Paper, Skeleton } from "@mui/material"; + +import SearchIcon from "@mui/icons-material/Search"; +import FilterListIcon from "@mui/icons-material/FilterList"; +import RestartAltIcon from "@mui/icons-material/RestartAlt"; +import { useTheme } from "@mui/material/styles"; +import { generateSearchDescription } from "../../utils/searchUtils"; +import { + HIGHLIGHT_COLORS, + TABLE_DEFAULTS, + FILTER_PRECISIONS, +} from "../../constants/defaults"; +import { MODEL_TYPE_ORDER } from "../../constants/modelTypes"; +import { alpha } from "@mui/material/styles"; +import FilteredModelCount from "./FilteredModelCount"; +import { useLeaderboard } from "../../context/LeaderboardContext"; +import InfoIconWithTooltip from "../../../../../../components/shared/InfoIconWithTooltip"; +import { UI_TOOLTIPS } from "../../constants/tooltips"; + +export const SearchBarSkeleton = () => ( + + alpha(theme.palette.background.paper, 0.8), + borderRadius: 1, + border: (theme) => + `1px solid ${alpha( + theme.palette.divider, + theme.palette.mode === "dark" ? 
0.05 : 0.1 + )}`, + display: "flex", + alignItems: "center", + px: 2, + gap: 2, + }} + > + + + + + + + + + + + Supports strict search and regex • Use semicolons for multiple terms + + + +); + +const SearchDescription = ({ searchValue }) => { + const searchGroups = generateSearchDescription(searchValue); + + if (!searchGroups || searchGroups.length === 0) return null; + + return ( + + + Showing models matching: + + {searchGroups.map(({ text, index }, i) => ( + + {i > 0 && ( + + and + + )} + + theme.palette.getContrastText( + HIGHLIGHT_COLORS[index % HIGHLIGHT_COLORS.length] + ), + padding: "2px 4px", + borderRadius: "4px", + fontSize: "0.85rem", + fontWeight: 500, + }} + > + {text} + + + ))} + + ); +}; + +const SearchBar = ({ + onToggleFilters, + filtersOpen, + loading = false, + data = [], + table = null, +}) => { + const theme = useTheme(); + const { state, actions } = useLeaderboard(); + const [localValue, setLocalValue] = useState(state.filters.search); + + useEffect(() => { + setLocalValue(state.filters.search); + }, [state.filters.search]); + + useEffect(() => { + const timer = setTimeout(() => { + if (localValue !== state.filters.search) { + actions.setFilter("search", localValue); + } + }, TABLE_DEFAULTS.DEBOUNCE.SEARCH); + + return () => clearTimeout(timer); + }, [localValue, state.filters.search, actions]); + + const handleLocalChange = (e) => { + setLocalValue(e.target.value); + }; + + const hasActiveFilters = + Object.values(state.filters.booleanFilters).some((value) => value) || + state.filters.precisions.length !== FILTER_PRECISIONS.length || + state.filters.types.length !== MODEL_TYPE_ORDER.length || + state.filters.paramsRange[0] !== -1 || + state.filters.paramsRange[1] !== 140 || + state.filters.isOfficialProviderActive; + + const shouldShowReset = localValue || hasActiveFilters; + + return ( + + + + + {!loading && ( + + )} + + {shouldShowReset && ( + { + setLocalValue(""); + actions.resetFilters(); + }} + sx={{ + display: "flex", + alignItems: 
import { useCallback, useState, useEffect, useRef } from "react";
import { useSearchParams } from "react-router-dom";

// react-router's useSearchParams throws when rendered outside a <Router>;
// fall back to a no-op [params, setter] pair so the hook degrades gracefully.
const useRouterSearchParams = () => {
  try {
    return useSearchParams();
  } catch {
    return [null, () => {}];
  }
};

// Query-param token that marks "official providers only" mode.
const HIGHLIGHT_FILTER = "is_official_provider";

// Split a comma-separated `filters` param into a clean array of filter ids.
const parseFilterList = (filtersStr) =>
  (filtersStr || "").split(",").filter(Boolean);

/**
 * Tracks and toggles the "official providers only" mode, persisted in the
 * `filters` URL search param. While the mode is active, the previously
 * selected ("normal") filters are remembered in a ref and restored when the
 * mode is switched off.
 *
 * @returns {{isOfficialProviderActive: boolean, toggleOfficialProviderMode: Function}}
 */
export const useOfficialProvidersMode = () => {
  const [isOfficialProviderActive, setIsOfficialProviderActive] =
    useState(false);
  const [searchParams, setSearchParams] = useRouterSearchParams();
  const normalFiltersRef = useRef(null); // filters to restore on deactivation
  const isInitialLoadRef = useRef(true);
  const lastToggleSourceRef = useRef(null); // {source, timestamp} for debouncing

  // Sync local state with the URL on initial load and on every param change.
  useEffect(() => {
    if (!searchParams) return;

    const filtersStr = searchParams.get("filters");
    // BUGFIX: exact membership check. The previous substring test
    // (`filters.includes(...)`) would false-positive on any filter id that
    // merely contains this token as a substring.
    const isHighlighted =
      parseFilterList(filtersStr).includes(HIGHLIGHT_FILTER);

    if (isInitialLoadRef.current) {
      isInitialLoadRef.current = false;

      // If the mode is already active at startup, remember the remaining
      // filters so deactivation can restore them later.
      if (isHighlighted && filtersStr) {
        const initialNormalFilters = parseFilterList(filtersStr).filter(
          (f) => f !== HIGHLIGHT_FILTER
        );
        if (initialNormalFilters.length > 0) {
          normalFiltersRef.current = initialNormalFilters.join(",");
        }
      }

      // Update state without triggering a URL change.
      setIsOfficialProviderActive(isHighlighted);
      return;
    }

    // Subsequent changes: keep the latest snapshot of the "normal" filters.
    if (!isHighlighted && filtersStr) {
      normalFiltersRef.current = filtersStr;
    }

    setIsOfficialProviderActive(isHighlighted);
  }, [searchParams]);

  const toggleOfficialProviderMode = useCallback(
    (source = null) => {
      if (!searchParams || !setSearchParams) return;

      // Debounce: ignore a repeat toggle from the same source within 100ms.
      const now = Date.now();
      if (
        source &&
        source === lastToggleSourceRef.current?.source &&
        now - (lastToggleSourceRef.current?.timestamp || 0) < 100
      ) {
        return;
      }

      const currentFiltersStr = searchParams.get("filters");
      const currentFilters = parseFilterList(currentFiltersStr);
      const newSearchParams = new URLSearchParams(searchParams);
      const wasActive = currentFilters.includes(HIGHLIGHT_FILTER);

      if (wasActive) {
        // Deactivating: restore the remembered normal filters when present,
        // otherwise just drop the highlight token from the current set.
        const restored = normalFiltersRef.current
          ? parseFilterList(normalFiltersRef.current)
          : currentFilters;
        const newFilters = restored.filter((f) => f !== HIGHLIGHT_FILTER);
        if (newFilters.length > 0) {
          newSearchParams.set("filters", newFilters.join(","));
        } else {
          newSearchParams.delete("filters");
        }
      } else {
        // Activating: snapshot current filters, then add the highlight token.
        if (currentFiltersStr) {
          normalFiltersRef.current = currentFiltersStr;
        }
        const filtersToSet = [
          ...new Set([...currentFilters, HIGHLIGHT_FILTER]),
        ];
        newSearchParams.set("filters", filtersToSet.join(","));
      }

      // Update state immediately; the router handles the URL change.
      setIsOfficialProviderActive(!wasActive);
      lastToggleSourceRef.current = { source, timestamp: now };
      setSearchParams(newSearchParams);
    },
    [searchParams, setSearchParams]
  );

  return {
    isOfficialProviderActive,
    toggleOfficialProviderMode,
  };
};
import { useCallback } from "react";
import { QUICK_FILTER_PRESETS } from "../../../constants/quickFilters";
import { TABLE_DEFAULTS } from "../../../constants/defaults";

// Baseline filter state applied when no preset is selected.
const DEFAULT_FILTERS = {
  searchValue: "",
  selectedPrecisions: TABLE_DEFAULTS.SEARCH.PRECISIONS,
  selectedTypes: TABLE_DEFAULTS.SEARCH.TYPES,
  paramsRange: TABLE_DEFAULTS.SEARCH.PARAMS_RANGE,
  selectedBooleanFilters: [],
};

// Compare two arrays ignoring element order, WITHOUT mutating them.
// BUGFIX: the previous code called `.sort()` directly on filter-state arrays
// and on the shared preset constants, silently reordering React state and
// QUICK_FILTER_PRESETS in place.
const sameMembers = (a, b) =>
  JSON.stringify([...a].sort()) === JSON.stringify([...b].sort());

// Return `value` when it is an array, otherwise the given fallback.
const asArray = (value, fallback) => (Array.isArray(value) ? value : fallback);

/**
 * Quick-filter preset management.
 *
 * @param {object} searchFilters - filter state exposing batchUpdateState()
 *   plus the current selections; may be undefined during initialization.
 * @returns {{activePreset: (string|null), handlePresetChange: Function, resetPreset: Function}}
 */
export const usePresets = (searchFilters) => {
  // Apply a preset, or reset to the defaults when `preset` is null.
  const handlePresetChange = useCallback(
    (preset) => {
      if (!searchFilters?.batchUpdateState) return;

      if (preset === null) {
        searchFilters.batchUpdateState(DEFAULT_FILTERS, true);
        return;
      }

      // Preset values override the defaults; everything is applied at once.
      searchFilters.batchUpdateState(
        { ...DEFAULT_FILTERS, ...preset.filters },
        true
      );
    },
    [searchFilters]
  );

  const resetPreset = useCallback(() => {
    handlePresetChange(null);
  }, [handlePresetChange]);

  // Dynamically detect which preset (if any) matches the current filters.
  const getActivePreset = useCallback(() => {
    // searchFilters may not be initialized yet.
    if (!searchFilters) return null;

    const currentParamsRange = asArray(
      searchFilters.paramsRange,
      DEFAULT_FILTERS.paramsRange
    );
    const currentBooleanFilters = asArray(
      searchFilters.selectedBooleanFilters,
      DEFAULT_FILTERS.selectedBooleanFilters
    );
    const currentPrecisions = asArray(
      searchFilters.selectedPrecisions,
      DEFAULT_FILTERS.selectedPrecisions
    );
    const currentTypes = asArray(
      searchFilters.selectedTypes,
      DEFAULT_FILTERS.selectedTypes
    );

    return (
      QUICK_FILTER_PRESETS.find((preset) => {
        const presetParamsRange = asArray(
          preset.filters.paramsRange,
          DEFAULT_FILTERS.paramsRange
        );
        const presetBooleanFilters = asArray(
          preset.filters.selectedBooleanFilters,
          DEFAULT_FILTERS.selectedBooleanFilters
        );

        // The params range is order-sensitive ([min, max]); compare as-is.
        const paramsMatch =
          JSON.stringify(presetParamsRange) ===
          JSON.stringify(currentParamsRange);
        const booleanFiltersMatch = sameMembers(
          presetBooleanFilters,
          currentBooleanFilters
        );

        // All remaining filters must sit at their default values.
        const precisionMatch = sameMembers(
          currentPrecisions,
          DEFAULT_FILTERS.selectedPrecisions
        );
        const typesMatch = sameMembers(
          currentTypes,
          DEFAULT_FILTERS.selectedTypes
        );

        return (
          paramsMatch && booleanFiltersMatch && precisionMatch && typesMatch
        );
      })?.id || null
    );
  }, [searchFilters]);

  return {
    activePreset: getActivePreset(),
    handlePresetChange,
    resetPreset,
  };
};
// Collect approximate GPU information via a throwaway WebGL context.
// Returns { vendor, renderer, usedMemory } or null when WebGL is unavailable.
// NOTE: `usedMemory` is only a rough estimate derived from one 1024x1024 RGBA
// test texture (~4 MB) — browsers expose no real GPU-memory API.
const getGPUStats = () => {
  try {
    const canvas = document.createElement("canvas");
    const gl =
      canvas.getContext("webgl") || canvas.getContext("experimental-webgl");

    if (!gl) {
      canvas.remove();
      return null;
    }

    // Vendor/renderer strings require the (optional) debug extension.
    // BUGFIX: read these BEFORE releasing the context — the previous version
    // queried them after loseContext(), when the context was already lost.
    const debugInfo = gl.getExtension("WEBGL_debug_renderer_info");
    const vendor = debugInfo
      ? gl.getParameter(debugInfo.UNMASKED_VENDOR_WEBGL)
      : "Unknown";
    const renderer = debugInfo
      ? gl.getParameter(debugInfo.UNMASKED_RENDERER_WEBGL)
      : "Unknown";

    // Estimate GPU memory usage (very approximate).
    let usedMemoryEstimate = 0;
    try {
      // Create a test texture.
      const testTexture = gl.createTexture();
      gl.bindTexture(gl.TEXTURE_2D, testTexture);

      // Test size: 1024x1024 RGBA.
      const testSize = 1024;
      const pixels = new Uint8Array(testSize * testSize * 4);
      gl.texImage2D(
        gl.TEXTURE_2D,
        0,
        gl.RGBA,
        testSize,
        testSize,
        0,
        gl.RGBA,
        gl.UNSIGNED_BYTE,
        pixels
      );

      usedMemoryEstimate = (testSize * testSize * 4) / (1024 * 1024); // MB

      gl.deleteTexture(testTexture);
    } catch (e) {
      console.warn("GPU memory estimation failed:", e);
    } finally {
      // BUGFIX: release the context exactly once. The previous version called
      // loseContext() both inside the try block and again here, triggering a
      // "context already lost" warning on the second call.
      gl.getExtension("WEBGL_lose_context")?.loseContext();
      gl.canvas.remove();
    }

    return {
      vendor,
      renderer,
      usedMemory: Math.round(usedMemoryEstimate),
    };
  } catch (e) {
    return null;
  }
};
// Format an integer with a plain space as the thousands separator
// (e.g. 1234567 -> "1 234 567").
const formatNumber = (num) => {
  // The lookahead matches every inter-digit position that is followed by
  // one or more complete groups of three digits, skipping the number start.
  const THOUSANDS_BOUNDARY = /\B(?=(\d{3})+(?!\d))/g;
  return String(num).replace(THOUSANDS_BOUNDARY, " ");
};
Math.round((1 - totalTransferSize / totalDecodedSize) * 100) + : 0; + + return { + transferSize: Math.round(totalTransferSize / 1024), + decodedBodySize: Math.round(totalDecodedSize / 1024), + compressionRatio, + }; + }; + + // Save original function + originalCreateElementRef.current = React.createElement; + + // Replace createElement + React.createElement = function (...args) { + renderCountRef.current++; + return originalCreateElementRef.current.apply(this, args); + }; + + const updateStats = () => { + frameCount++; + const now = performance.now(); + const delta = now - lastTime; + + if (delta >= 1000) { + const fps = Math.round((frameCount * 1000) / delta); + + const memory = window.performance?.memory + ? { + usedJSHeapSize: Math.round( + window.performance.memory.usedJSHeapSize / 1048576 + ), + totalJSHeapSize: Math.round( + window.performance.memory.totalJSHeapSize / 1048576 + ), + } + : null; + + const network = getNetworkStats(); + const gpu = getGPUStats(); + + setStats((prev) => ({ + ...prev, + fps, + memory: memory || prev.memory, + renders: renderCountRef.current, + network, + gpu, + })); + + frameCount = 0; + lastTime = now; + } + + animationFrameId = requestAnimationFrame(updateStats); + }; + + updateStats(); + + return () => { + cancelAnimationFrame(animationFrameId); + // Restore original function + if (originalCreateElementRef.current) { + React.createElement = originalCreateElementRef.current; + } + // Clean up counters + renderCountRef.current = 0; + delete window.__REACT_RENDERS__; + }; + }, []); + + useEffect(() => { + // Add FCP observer + if (window.PerformanceObserver) { + try { + const fcpObserver = new PerformanceObserver((entryList) => { + const entries = entryList.getEntries(); + if (entries.length > 0) { + const fcp = entries[0].startTime; + setStats((prev) => ({ + ...prev, + fcp, + })); + } + }); + + fcpObserver.observe({ entryTypes: ["paint"] }); + return () => fcpObserver.disconnect(); + } catch (e) { + console.warn("FCP observation 
failed:", e); + } + } + }, []); + + const getFpsColor = (fps) => { + if (fps >= 55) return "#4CAF50"; + if (fps >= 30) return "#FFC107"; + return "#F44336"; + }; + + return isVisible ? ( + + + + Performances{" "} + dev only + + + {/* Performance Metrics */} + + + } + label="FPS" + value={ + + {stats.fps} + + } + tooltip="Frames Per Second - Indicates how smooth the UI is running" + /> + + {stats.fcp !== null && ( + + } + label="FCP" + value={ + + {Math.round(stats.fcp)}ms + + } + tooltip="First Contentful Paint - Time until first content is rendered" + /> + )} + + ⚛️ + + } + label="React" + value={ + + {formatNumber(stats.renders)} + cycles + + } + tooltip="Total number of React render cycles" + /> + + + {/* Memory Metrics */} + + {window.performance?.memory && ( + } + label="Mem" + value={ + + {stats.memory.usedJSHeapSize} + / + {stats.memory.totalJSHeapSize} + MB + + } + tooltip="JavaScript heap memory usage (Used / Total)" + /> + )} + {stats.gpu && ( + } + label="GPU" + value={ + + {stats.gpu.usedMemory} + MB + + } + tooltip="Estimated GPU memory usage" + /> + )} + + + {/* Network Metrics */} + + + } + label="Net" + value={ + + {stats.network.transferSize} + KB + + } + tooltip="Network data transferred" + /> + } + label="Size" + value={ + + {formatNumber(stats.network.decodedBodySize)} + KB + 0 ? "#81C784" : "inherit", + fontSize: "0.7rem", + opacity: 0.8, + ml: 1, + }} + > + (-{stats.network.compressionRatio}%) + + + } + tooltip="Total decoded size and compression ratio" + /> + + + Press "P" to show/hide + + + +
    + ) : null; +}; + +export default React.memo(PerformanceMonitor); diff --git a/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Table/Table.js b/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Table/Table.js new file mode 100644 index 0000000000000000000000000000000000000000..b9279247881135a2d4cf2122ed542474fc20f6be --- /dev/null +++ b/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Table/Table.js @@ -0,0 +1,720 @@ +import React, { useRef, useCallback, useMemo } from "react"; +import { + Paper, + Table, + TableContainer, + TableHead, + TableBody, + TableRow, + TableCell, + Box, + Typography, + Skeleton, +} from "@mui/material"; +import { flexRender } from "@tanstack/react-table"; +import { useVirtualizer } from "@tanstack/react-virtual"; +import KeyboardArrowUpIcon from "@mui/icons-material/KeyboardArrowUp"; +import KeyboardArrowDownIcon from "@mui/icons-material/KeyboardArrowDown"; +import UnfoldMoreIcon from "@mui/icons-material/UnfoldMore"; +import SearchOffIcon from "@mui/icons-material/SearchOff"; +import { + TABLE_DEFAULTS, + ROW_SIZES, + SKELETON_COLUMNS, +} from "../../constants/defaults"; +import { alpha } from "@mui/material/styles"; +import TableOptions from "../DisplayOptions/DisplayOptions"; +import ColumnSelector from "../ColumnSelector/ColumnSelector"; + +const NoResultsFound = () => ( + + + + No models found + + + Try modifying your filters or search to see more models. + + +); + +const TableSkeleton = ({ rowSize = "normal" }) => { + const currentRowHeight = Math.floor(ROW_SIZES[rowSize]); + const headerHeight = Math.floor(currentRowHeight * 1.25); + const skeletonRows = 10; + + return ( + + `1px solid ${alpha( + theme.palette.divider, + theme.palette.mode === "dark" ? 0.05 : 0.1 + )}`, + borderRadius: 1, + }} + > + + + + {SKELETON_COLUMNS.map((width, index) => ( + 3 ? 
"right" : "left", + borderRight: (theme) => `1px solid ${theme.palette.divider}`, + "&:last-child": { + borderRight: "none", + }, + position: "sticky", + top: 0, + backgroundColor: (theme) => theme.palette.background.paper, + zIndex: 2, + }} + /> + ))} + + + + {[...Array(skeletonRows)].map((_, index) => ( + + index % 2 === 0 ? "transparent" : theme.palette.action.hover, + }} + > + {SKELETON_COLUMNS.map((width, cellIndex) => ( + + `1px solid ${theme.palette.divider}`, + "&:last-child": { + borderRight: "none", + }, + }} + > + 3 ? "auto" : 0, + backgroundColor: (theme) => + alpha(theme.palette.text.primary, 0.11), + "&::after": { + background: (theme) => + `linear-gradient(90deg, ${alpha( + theme.palette.text.primary, + 0.11 + )}, ${alpha( + theme.palette.text.primary, + 0.14 + )}, ${alpha(theme.palette.text.primary, 0.11)})`, + }, + }} + /> + + ))} + + ))} + +
    +
    + ); +}; + +const TableControls = React.memo( + ({ + loading, + rowSize, + onRowSizeChange, + scoreDisplay, + onScoreDisplayChange, + averageMode, + onAverageModeChange, + rankingMode, + onRankingModeChange, + hasTableOptionsChanges, + searchParams, + setSearchParams, + table, + handleColumnReset, + hasColumnFilterChanges, + onColumnVisibilityChange, + }) => ( + + + + + ) +); + +TableControls.displayName = "TableControls"; + +const LeaderboardTable = ({ + table, + rowSize = "normal", + loading = false, + hasTableOptionsChanges, + hasColumnFilterChanges, + onColumnVisibilityChange, + scoreDisplay, + onScoreDisplayChange, + averageMode, + onAverageModeChange, + rankingMode, + onRankingModeChange, + onRowSizeChange, + searchParams, + setSearchParams, + pinnedModels = [], +}) => { + const { rows } = table.getRowModel(); + const parentRef = useRef(null); + + const currentRowHeight = useMemo(() => ROW_SIZES[rowSize], [rowSize]); + const headerHeight = useMemo( + () => Math.floor(currentRowHeight * 1.25), + [currentRowHeight] + ); + + // Separate pinned rows from normal rows while preserving original order + const pinnedRows = useMemo(() => { + const pinnedModelRows = rows.filter((row) => row.original.isPinned); + // Sort pinned models according to their original order in pinnedModels + return pinnedModelRows.sort((a, b) => { + const aIndex = pinnedModels.indexOf(a.original.id); + const bIndex = pinnedModels.indexOf(b.original.id); + return aIndex - bIndex; + }); + }, [rows, pinnedModels]); + + const unpinnedRows = useMemo( + () => rows.filter((row) => !row.original.isPinned), + [rows] + ); + const pinnedHeight = useMemo( + () => pinnedRows.length * currentRowHeight, + [pinnedRows.length, currentRowHeight] + ); + + const virtualizerOptions = useMemo( + () => ({ + count: unpinnedRows.length, + getScrollElement: () => parentRef.current, + estimateSize: () => currentRowHeight, + overscan: 15, + scrollMode: "sync", + scrollPaddingStart: pinnedHeight, + scrollPaddingEnd: 
0, + initialRect: { width: 0, height: currentRowHeight * 15 }, + }), + [currentRowHeight, unpinnedRows.length, pinnedHeight] + ); + + const rowVirtualizer = useVirtualizer(virtualizerOptions); + + const virtualRows = rowVirtualizer.getVirtualItems(); + + // Adjust paddings to account for pinned rows + const paddingTop = virtualRows.length > 0 ? virtualRows[0].start : 0; + const paddingBottom = + virtualRows.length > 0 + ? unpinnedRows.length * currentRowHeight - + virtualRows[virtualRows.length - 1].end + : 0; + + // Handle column reset + const handleColumnReset = useCallback(() => { + onColumnVisibilityChange(TABLE_DEFAULTS.COLUMNS.DEFAULT_VISIBLE); + }, [onColumnVisibilityChange]); + + const cellStyles = (theme) => ({ + borderRight: `1px solid ${alpha( + theme.palette.divider, + theme.palette.mode === "dark" ? 0.05 : 0.1 + )}`, + "&:last-child": { + borderRight: "none", + }, + whiteSpace: "nowrap", + overflow: "hidden", + textOverflow: "ellipsis", + padding: "8px 16px", + }); + + const headerCellStyles = (theme) => ({ + ...cellStyles(theme), + padding: "6px 16px", + height: "36px", + position: "sticky !important", + top: 0, + zIndex: 10, + "& > .header-content": { + display: "flex", + alignItems: "center", + width: "100%", + gap: "4px", + flexDirection: "row", + }, + }); + + const getSortingIcon = (column) => { + if ( + column.id === "rank" || + column.id === "model_type" || + column.id === "isPinned" + ) { + return null; + } + + if (!column.getIsSorted()) { + return ; + } + return column.getIsSorted() === "desc" ? 
( + + ) : ( + + ); + }; + + const renderHeaderContent = (header) => { + const sortIcon = getSortingIcon(header.column); + return ( + + {flexRender(header.column.columnDef.header, header.getContext())} + + {sortIcon || } + + + ); + }; + + const renderRow = (row, isSticky = false, stickyIndex = 0) => { + // Get row index in the sorted data model + const sortedIndex = table + .getSortedRowModel() + .rows.findIndex((r) => r.id === row.id); + + return ( + ({ + height: `${currentRowHeight}px !important`, + backgroundColor: isSticky + ? theme.palette.background.paper + : (sortedIndex + 1) % 2 === 0 + ? "transparent" + : alpha(theme.palette.mode === "dark" ? "#fff" : "#000", 0.02), + position: isSticky ? "sticky" : "relative", + top: isSticky + ? `${headerHeight + stickyIndex * currentRowHeight}px` + : "auto", + zIndex: isSticky ? 2 : 1, + boxShadow: isSticky + ? `0 1px 1px ${alpha( + theme.palette.common.black, + theme.palette.mode === "dark" ? 0.1 : 0.05 + )}` + : "none", + "&::after": isSticky + ? { + content: '""', + position: "absolute", + left: 0, + right: 0, + height: "1px", + bottom: -1, + backgroundColor: alpha( + theme.palette.divider, + theme.palette.mode === "dark" ? 0.1 : 0.2 + ), + zIndex: 1, + } + : {}, + })} + > + {row.getVisibleCells().map((cell) => ( + ({ + width: `${cell.column.columnDef.size}px !important`, + minWidth: `${cell.column.columnDef.size}px !important`, + height: `${currentRowHeight}px`, + backgroundColor: isSticky + ? theme.palette.background.paper + : "inherit", + borderBottom: isSticky + ? 
"none" + : `1px solid ${theme.palette.divider}`, + ...cellStyles(theme), + ...(cell.column.columnDef.meta?.cellStyle?.(cell.getValue()) || + {}), + "& .MuiBox-root": { + overflow: "visible", + }, + })} + > + {flexRender(cell.column.columnDef.cell, cell.getContext())} + + ))} + + ); + }; + + if (!loading && (!rows || rows.length === 0)) { + return ( + + + + + + + ); + } + + if (loading) { + return ( + + + + + + + ); + } + + return ( + + + + ({ + height: "100%", + overflow: "auto", + border: "none", + boxShadow: "none", + "&::-webkit-scrollbar": { + width: "8px", + height: "8px", + }, + "&::-webkit-scrollbar-thumb": { + backgroundColor: alpha( + theme.palette.common.black, + theme.palette.mode === "dark" ? 0.4 : 0.2 + ), + borderRadius: "4px", + }, + "&::-webkit-scrollbar-corner": { + backgroundColor: theme.palette.background.paper, + }, + willChange: "transform", + transform: "translateZ(0)", + WebkitOverflowScrolling: "touch", + scrollBehavior: "auto", + })} + > + 0 ? "fixed" : "fixed", + border: "none", + "& td, & th": + pinnedRows.length > 0 + ? { + width: `${100 / table.getAllColumns().length}%`, + } + : {}, + }} + > + + {table.getAllColumns().map((column, index) => ( + + ))} + + + theme.palette.background.paper, + "& th": { + backgroundColor: (theme) => theme.palette.background.paper, + }, + }} + > + {table.getHeaderGroups().map((headerGroup) => ( + + {headerGroup.headers.map((header) => ( + ({ + cursor: header.column.getCanSort() + ? "pointer" + : "default", + width: header.column.columnDef.size, + minWidth: header.column.columnDef.size, + ...headerCellStyles(theme), + textAlign: "left", + fontWeight: header.column.getIsSorted() ? 
700 : 400, + userSelect: "none", + height: `${headerHeight}px`, + padding: `${headerHeight * 0.25}px 16px`, + backgroundColor: theme.palette.background.paper, + })} + > + {renderHeaderContent(header)} + + ))} + + ))} + + + + {/* Pinned rows */} + {pinnedRows.map((row, index) => renderRow(row, true, index))} + + {/* Padding for virtualized rows */} + {paddingTop > 0 && ( + + + + )} + + {/* Virtualized unpinned rows */} + {virtualRows.map((virtualRow) => { + const row = unpinnedRows[virtualRow.index]; + if (!row) return null; + return renderRow(row); + })} + + {/* Bottom padding */} + {paddingBottom > 0 && ( + + + + )} + +
    +
    +
    +
    + ); +}; + +export { TableSkeleton }; +export default LeaderboardTable; diff --git a/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Table/hooks/useDataProcessing.js b/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Table/hooks/useDataProcessing.js new file mode 100644 index 0000000000000000000000000000000000000000..6f5463755578ae260d6639403706e5f6071eb614 --- /dev/null +++ b/frontend/src/pages/LeaderboardPage/components/Leaderboard/components/Table/hooks/useDataProcessing.js @@ -0,0 +1,161 @@ +import { useMemo } from "react"; +import { + useReactTable, + getSortedRowModel, + getCoreRowModel, + getFilteredRowModel, +} from "@tanstack/react-table"; +import { createColumns } from "../../../utils/columnUtils"; +import { + useAverageRange, + useColorGenerator, + useProcessedData, + useFilteredData, + useColumnVisibility, +} from "../../../hooks/useDataUtils"; + +export const useDataProcessing = ( + data, + searchValue, + selectedPrecisions, + selectedTypes, + paramsRange, + selectedBooleanFilters, + sorting, + rankingMode, + averageMode, + visibleColumns, + scoreDisplay, + pinnedModels, + onTogglePin, + setSorting, + isOfficialProviderActive +) => { + // Call hooks directly at root level + const { minAverage, maxAverage } = useAverageRange(data); + const getColorForValue = useColorGenerator(minAverage, maxAverage); + const processedData = useProcessedData(data, averageMode, visibleColumns); + const columnVisibility = useColumnVisibility(visibleColumns); + + // Memoize filters + const filterConfig = useMemo( + () => ({ + selectedPrecisions, + selectedTypes, + paramsRange, + searchValue, + selectedBooleanFilters, + rankingMode, + pinnedModels, + isOfficialProviderActive, + }), + [ + selectedPrecisions, + selectedTypes, + paramsRange, + searchValue, + selectedBooleanFilters, + rankingMode, + pinnedModels, + isOfficialProviderActive, + ] + ); + + // Call useFilteredData at root level + const filteredData = useFilteredData( + 
processedData, + filterConfig.selectedPrecisions, + filterConfig.selectedTypes, + filterConfig.paramsRange, + filterConfig.searchValue, + filterConfig.selectedBooleanFilters, + filterConfig.rankingMode, + filterConfig.pinnedModels, + filterConfig.isOfficialProviderActive + ); + + // Memoize columns creation + const columns = useMemo( + () => + createColumns( + getColorForValue, + scoreDisplay, + columnVisibility, + data.length, + averageMode, + searchValue, + rankingMode, + onTogglePin + ), + [ + getColorForValue, + scoreDisplay, + columnVisibility, + data.length, + averageMode, + searchValue, + rankingMode, + onTogglePin, + ] + ); + + // Memoize table configuration + const tableConfig = useMemo( + () => ({ + data: filteredData, + columns, + state: { + sorting: Array.isArray(sorting) ? sorting : [], + columnVisibility, + }, + getCoreRowModel: getCoreRowModel(), + getFilteredRowModel: getFilteredRowModel(), + getSortedRowModel: getSortedRowModel(), + onSortingChange: setSorting, + enableColumnVisibility: true, + defaultColumn: { + sortingFn: (rowA, rowB, columnId) => { + const isDesc = sorting?.[0]?.desc; + + if (rowA.original.isPinned && rowB.original.isPinned) { + return ( + pinnedModels.indexOf(rowA.original.id) - + pinnedModels.indexOf(rowB.original.id) + ); + } + + if (isDesc) { + if (rowA.original.isPinned) return -1; + if (rowB.original.isPinned) return 1; + } else { + if (rowA.original.isPinned) return -1; + if (rowB.original.isPinned) return 1; + } + + const aValue = rowA.getValue(columnId); + const bValue = rowB.getValue(columnId); + + if (typeof aValue === "number" && typeof bValue === "number") { + return aValue - bValue; + } + + return String(aValue).localeCompare(String(bValue)); + }, + }, + }), + [filteredData, columns, sorting, columnVisibility, pinnedModels, setSorting] + ); + + const table = useReactTable(tableConfig); + + return { + table, + minAverage, + maxAverage, + getColorForValue, + processedData, + filteredData, + columns, + 
// Comparator for the "model_type" column. Two arrays compare by their first
// element (as strings); an array sorts before a non-array; anything else
// compares as strings. Null/undefined values coerce to "".
export const typeColumnSort = (rowA, rowB) => {
  const left = rowA.getValue("model_type");
  const right = rowB.getValue("model_type");
  const leftIsArray = Array.isArray(left);
  const rightIsArray = Array.isArray(right);

  if (leftIsArray && rightIsArray) {
    return String(left[0] || "").localeCompare(String(right[0] || ""));
  }
  if (leftIsArray !== rightIsArray) {
    // Exactly one side is an array: arrays come first.
    return leftIsArray ? -1 : 1;
  }
  return String(left || "").localeCompare(String(right || ""));
};
import { MODEL_TYPE_ORDER } from "./modelTypes";

// Time constants (in milliseconds)
const TIME = {
  CACHE_DURATION: 5 * 60 * 1000, // 5 minutes
  DEBOUNCE: {
    URL_PARAMS: 100,
    SEARCH: 150,
    RANGE_PICKER: 350,
  },
};

// Display constants
const DISPLAY = {
  ROW_SIZES: {
    normal: 45,
    large: 60,
  },
  SCORE_DISPLAY_OPTIONS: [
    { value: "normalized", label: "Normalized" },
    { value: "raw", label: "Raw" },
  ],
  RANKING_MODE_OPTIONS: [
    { value: "static", label: "Static" },
    { value: "dynamic", label: "Dynamic" },
  ],
};

// Filter constants
const FILTERS = {
  PRECISIONS: ["bfloat16", "float16", "4bit"],
  SUBMISSION_PRECISIONS: [
    { value: "float16", label: "float16" },
    { value: "bfloat16", label: "bfloat16" },
    { value: "8bit", label: "8-bit" },
    { value: "4bit", label: "4-bit" },
    { value: "gptq", label: "GPTQ" },
  ],
  // [-1, 500] is the sentinel "no parameter filter" range (millions of params)
  PARAMS_RANGE: [-1, 500],
  BOOLEAN_OPTIONS: [
    {
      value: "is_not_available_on_hub",
      label: "Unavailable model",
      hide: true,
    },
  ],
};

// Column size constants (pixels)
const COLUMN_SIZES = {
  RANK: 65,
  TYPE_ICON: 65,
  MODEL: 400,
  AVERAGE_SCORE: 150,
  BENCHMARK: 110,
  HUB_HEARTS: 140,
  ARCHITECTURE: 210,
  PRECISION: 140,
  PARAMS: 160,
  LICENSE: 160,
  UPLOAD_DATE: 160,
  SUBMISSION_DATE: 200,
  GENERATION: 160,
  BASE_MODEL: 390,
  HUB_AVAILABILITY: 180,
};

// Column definitions grouped by section; keys are dotted data paths into the
// model rows, values carry the table metadata for each column.
const COLUMNS = {
  FIXED: {
    rank: {
      group: "fixed",
      size: COLUMN_SIZES.RANK,
      defaultVisible: true,
      label: "Rank",
    },
    "model.type_icon": {
      group: "fixed",
      size: COLUMN_SIZES.TYPE_ICON,
      defaultVisible: true,
      label: "Type",
    },
    id: {
      group: "fixed",
      size: COLUMN_SIZES.MODEL,
      defaultVisible: true,
      label: "Model",
    },
    "model.average_score": {
      group: "fixed",
      size: COLUMN_SIZES.AVERAGE_SCORE,
      defaultVisible: true,
      label: "Average Score",
    },
  },
  EVALUATION: {
    "evaluations.anli.normalized_score": {
      group: "evaluation",
      size: COLUMN_SIZES.BENCHMARK,
      defaultVisible: true,
      label: "ANLI",
    },
    "evaluations.logiqa.normalized_score": {
      group: "evaluation",
      size: COLUMN_SIZES.BENCHMARK,
      defaultVisible: true,
      label: "LogiQA",
    },
  },
  MODEL_INFO: {
    "metadata.hub_hearts": {
      group: "model_info",
      size: COLUMN_SIZES.HUB_HEARTS,
      defaultVisible: false,
      label: "Hub \u2764\uFE0F",
    },
    "model.architecture": {
      group: "model_info",
      size: COLUMN_SIZES.ARCHITECTURE,
      defaultVisible: false,
      label: "Architecture",
    },
    "model.precision": {
      group: "model_info",
      size: COLUMN_SIZES.PRECISION,
      defaultVisible: false,
      label: "Precision",
    },
    "metadata.params_millions": {
      group: "model_info",
      size: COLUMN_SIZES.PARAMS,
      defaultVisible: false,
      label: "Parameters (M)",
    },
    "metadata.hub_license": {
      group: "model_info",
      size: COLUMN_SIZES.LICENSE,
      defaultVisible: false,
      label: "License",
    },
  },
  ADDITIONAL_INFO: {
    "metadata.upload_date": {
      group: "additional_info",
      size: COLUMN_SIZES.UPLOAD_DATE,
      defaultVisible: false,
      label: "Upload Date",
    },
    "metadata.submission_date": {
      group: "additional_info",
      size: COLUMN_SIZES.SUBMISSION_DATE,
      defaultVisible: false,
      label: "Submission Date",
    },
    "metadata.generation": {
      group: "additional_info",
      size: COLUMN_SIZES.GENERATION,
      defaultVisible: false,
      label: "Generation",
    },
    "metadata.base_model": {
      group: "additional_info",
      size: COLUMN_SIZES.BASE_MODEL,
      defaultVisible: false,
      label: "Base Model",
    },
    "features.is_not_available_on_hub": {
      group: "additional_info",
      size: COLUMN_SIZES.HUB_AVAILABILITY,
      defaultVisible: false,
      label: "Hub Availability",
    },
  },
};

// Combine all columns for backward compatibility
const ALL_COLUMNS = {
  ...COLUMNS.FIXED,
  ...COLUMNS.EVALUATION,
  ...COLUMNS.MODEL_INFO,
  ...COLUMNS.ADDITIONAL_INFO,
};

// Column definitions for external use (maintaining the same interface)
const COLUMN_DEFINITIONS = {
  ALL_COLUMNS,
  COLUMN_GROUPS: {
    "Evaluation Scores": Object.keys(COLUMNS.EVALUATION),
    "Model Information": Object.keys(COLUMNS.MODEL_INFO),
    "Additional Information": Object.keys(COLUMNS.ADDITIONAL_INFO),
  },
  COLUMN_LABELS: Object.entries(ALL_COLUMNS).reduce((acc, [key, value]) => {
    acc[key] = value.label;
    return acc;
  }, {}),
  DEFAULT_VISIBLE: Object.entries(ALL_COLUMNS)
    .filter(([_, value]) => value.defaultVisible)
    .map(([key]) => key),

  get FIXED() {
    return Object.entries(ALL_COLUMNS)
      .filter(([_, def]) => def.group === "fixed")
      .map(([key]) => key);
  },

  get EVALUATION() {
    return Object.entries(ALL_COLUMNS)
      .filter(([_, def]) => def.group === "evaluation")
      .map(([key]) => key);
  },

  get OPTIONAL() {
    return Object.entries(ALL_COLUMNS)
      .filter(([_, def]) => def.group !== "fixed" && def.group !== "evaluation")
      .map(([key]) => key);
  },

  get COLUMN_SIZES() {
    // PERF: build the key→size map in one pass with Object.fromEntries.
    // The original re-spread the accumulator on every reduce iteration,
    // which is accidentally O(n²) and runs on every property access.
    return Object.fromEntries(
      Object.entries(ALL_COLUMNS).map(([key, def]) => [key, def.size])
    );
  },
};

// Export constants maintaining the same interface
export const FILTER_PRECISIONS = FILTERS.PRECISIONS;
export const SUBMISSION_PRECISIONS = FILTERS.SUBMISSION_PRECISIONS;
export const PARAMS_RANGE = FILTERS.PARAMS_RANGE;
export const CACHE_SETTINGS = { DURATION: TIME.CACHE_DURATION };
export const PINNED_MODELS = [];
export const DEBOUNCE_TIMINGS = TIME.DEBOUNCE;
export const ROW_SIZES = DISPLAY.ROW_SIZES;
export const SCORE_DISPLAY_OPTIONS = DISPLAY.SCORE_DISPLAY_OPTIONS;
export const RANKING_MODE_OPTIONS = DISPLAY.RANKING_MODE_OPTIONS;
export const BOOLEAN_FILTER_OPTIONS = FILTERS.BOOLEAN_OPTIONS;
export { COLUMN_DEFINITIONS };

// Export defaults for backward compatibility
export const TABLE_DEFAULTS = {
  ROW_SIZE: "normal",
  SCORE_DISPLAY: "normalized",
  AVERAGE_MODE: "all",
  RANKING_MODE: "static",
  SEARCH: {
    PRECISIONS: FILTERS.PRECISIONS,
    TYPES: MODEL_TYPE_ORDER,
    PARAMS_RANGE: FILTERS.PARAMS_RANGE,
  },
  DEFAULT_SELECTED: {
    searchValue: "",
    selectedPrecisions: FILTERS.PRECISIONS,
    selectedTypes: MODEL_TYPE_ORDER,
    paramsRange: FILTERS.PARAMS_RANGE,
    selectedBooleanFilters: [],
  },
  DEBOUNCE: TIME.DEBOUNCE,
  COLUMNS: COLUMN_DEFINITIONS,
  PINNED_MODELS: [],
  CACHE_DURATION: TIME.CACHE_DURATION,
};

// Highlight colors for search and table
export const HIGHLIGHT_COLORS = [
  "#1f77b4", // blue
  "#ff7f0e", // orange
  "#2ca02c", // green
  "#d62728", // red
  "#9467bd", // violet
  "#8c564b", // brown
  "#e377c2", // pink
  "#7f7f7f", // grey
  "#bcbd22", // olive
  "#17becf", // cyan
];
// Display order of EEG model types in filters and sorting.
export const MODEL_TYPE_ORDER = [
  "pretrained",
  "fine-tuned",
  "task-specific",
  "foundation",
];

// Metadata for each model type: icon, human label, long description, and
// sort order (mirrors MODEL_TYPE_ORDER).
export const MODEL_TYPES = {
  pretrained: {
    icon: "\u{1F7E2}",
    label: "Pretrained",
    description: "Base EEG models trained with self-supervised learning",
    order: 0,
  },
  "fine-tuned": {
    icon: "\u{1F536}",
    label: "Fine-tuned",
    description: "Models fine-tuned on specific EEG datasets",
    order: 1,
  },
  "task-specific": {
    icon: "\u{1F9EA}",
    label: "Task-specific",
    description:
      "Models designed for specific EEG tasks (e.g., sleep staging, motor imagery)",
    order: 2,
  },
  foundation: {
    icon: "\u{1F3D7}\u{FE0F}",
    label: "Foundation",
    description: "Large-scale EEG foundation models",
    order: 3,
  },
};

// Resolve a raw type string to its MODEL_TYPES entry via substring match
// against the cleaned (lowercased, trimmed) input; null when unknown.
// BUGFIX: the four public getters below each duplicated this lookup and all
// threw on null/undefined input (`type.toLowerCase()`); non-string input now
// falls through to each getter's "unknown" fallback instead.
const matchModelType = (type) => {
  if (typeof type !== "string") return null;
  const cleanType = type.toLowerCase().trim();
  const matched = Object.entries(MODEL_TYPES).find(([key]) =>
    cleanType.includes(key)
  );
  return matched ? matched[1] : null;
};

// Icon for the type; "❓" when unknown.
export const getModelTypeIcon = (type) =>
  matchModelType(type)?.icon ?? "\u{2753}";

// Human label for the type; the raw input when unknown.
export const getModelTypeLabel = (type) => matchModelType(type)?.label ?? type;

// Long description for the type; generic text when unknown.
export const getModelTypeDescription = (type) =>
  matchModelType(type)?.description ?? "Unknown model type";

// Sort key for the type; Infinity sorts unknown types last.
export const getModelTypeOrder = (type) =>
  matchModelType(type)?.order ?? Infinity;
diff --git a/frontend/src/pages/LeaderboardPage/components/Leaderboard/constants/tooltips.js b/frontend/src/pages/LeaderboardPage/components/Leaderboard/constants/tooltips.js new file mode 100644 index 0000000000000000000000000000000000000000..d20833df1aa7de045f2c5fbbf74bd6b955b0115f --- /dev/null +++ b/frontend/src/pages/LeaderboardPage/components/Leaderboard/constants/tooltips.js @@ -0,0 +1,268 @@ +import { Box, Typography } from "@mui/material"; + +const createTooltipContent = (title, items) => ( + + + {title} + + + {items.map(({ label, description, subItems }, index) => ( +
  • + + {label}: {description} + {subItems && ( + + {subItems.map((item, subIndex) => ( +
  • + + {item} + +
  • + ))} +
    + )} + + + ))} +
    + +); + +export const COLUMN_TOOLTIPS = { + AVERAGE: createTooltipContent("Average score across all benchmarks:", [ + { + label: "Calculation", + description: "Weighted average of normalized scores from all benchmarks", + subItems: [ + "Each benchmark is normalized to a 0-100 scale", + "All normalised benchmarks are then averaged together", + ], + }, + ]), + + ANLI: createTooltipContent("Adversarial Natural Language Inference (ANLI):", [ + { + label: "Purpose", + description: + "Tests the model's ability to perform natural language inference on adversarially constructed examples", + subItems: [ + "Entailment, contradiction, and neutral classification", + "Adversarially-mined examples for robustness", + ], + }, + { + label: "Scoring: Accuracy", + description: "Was the correct label predicted for each example.", + }, + ]), + + LOGIQA: createTooltipContent("Logical Reasoning QA (LogiQA):", [ + { + label: "Purpose", + description: + "Evaluates logical reasoning abilities through multiple-choice questions", + subItems: [ + "Categorical reasoning", + "Sufficient conditional reasoning", + "Necessary conditional reasoning", + "Disjunctive reasoning", + "Conjunctive reasoning", + ], + }, + { + label: "Scoring: Accuracy", + description: + "Was the correct choice selected among the options.", + }, + ]), + + ARCHITECTURE: createTooltipContent("Model Architecture Information:", [ + { + label: "Definition", + description: "The fundamental structure and design of the EEG model", + subItems: [ + "Pretrained: Base EEG models trained with self-supervised learning on raw EEG data.", + "Fine-tuned: Models fine-tuned on specific EEG datasets for particular downstream tasks.", + "Task-specific: Models designed for specific EEG tasks such as sleep staging, motor imagery, or seizure detection.", + "Foundation: Large-scale EEG foundation models trained on diverse EEG datasets.", + ], + }, + { + label: "Impact", + description: "How architecture affects model capabilities", + subItems: [ 
+ "Pretrained models provide general EEG representations but may need fine-tuning for specific tasks.", + "Fine-tuned models are optimized for particular datasets or paradigms.", + "Task-specific models achieve strong performance on their target task but may not generalize.", + "Foundation models aim for broad generalization across EEG tasks.", + ], + }, + ]), + + PRECISION: createTooltipContent("Numerical Precision Format:", [ + { + label: "Overview", + description: + "Data format used to store model weights and perform computations", + subItems: [ + "bfloat16: Half precision (Brain Float format), good for stability", + "float16: Half precision", + "8bit/4bit: Quantized formats, for efficiency", + "GPTQ/AWQ: Quantized methods", + ], + }, + { + label: "Impact", + description: "How precision affects model deployment", + subItems: [ + "Higher precision = better accuracy but more memory usage", + "Lower precision = faster inference and smaller size", + "Trade-off between model quality and resource usage", + ], + }, + ]), + + FLAGS: createTooltipContent("Model Flags:", [ + { + label: "Filters", + subItems: [ + "Unavailable: No longer on the hub (private, deleted) or missing a license tag", + ], + }, + ]), + + PARAMETERS: createTooltipContent("Model Parameters:", [ + { + label: "Measurement", + description: "Total number of trainable parameters in millions", + subItems: [ + "Indicates model capacity and complexity", + "Correlates with computational requirements", + "Influences memory usage and inference speed", + ], + }, + ]), + + LICENSE: createTooltipContent("Model License Information:", [ + { + label: "Importance", + description: "Legal terms governing model usage and distribution", + subItems: [ + "Commercial vs non-commercial use", + "Attribution requirements", + "Modification and redistribution rights", + "Liability and warranty terms", + ], + }, + ]), +}; + +export const UI_TOOLTIPS = { + COLUMN_SELECTOR: "Choose which columns to display in the table", + 
DISPLAY_OPTIONS: createTooltipContent("Table Display Options", [ + { + label: "Overview", + description: "Configure how the table displays data and information", + subItems: [ + "Row size and layout", + "Score display format", + "Ranking calculation", + "Average score computation", + ], + }, + ]), + SEARCH_BAR: createTooltipContent("Advanced Model Search", [ + { + label: "Name Search", + description: "Search directly by model name", + subItems: [ + "Supports regular expressions (e.g., ^eegnet.*v4)", + "Case sensitive", + ], + }, + { + label: "Field Search", + description: "Use @field:value syntax for precise filtering", + subItems: [ + "@architecture:eegnet - Filter by architecture", + "@license:mit - Filter by license", + "@precision:float16 - Filter by precision", + "@type:pretrained - Filter by model type", + ], + }, + { + label: "Multiple Searches", + description: "Combine multiple criteria using semicolons", + subItems: [ + "braindecode @license:mit; @architecture:eegnet", + "^shallow.*net; @precision:float16", + ], + }, + ]), + QUICK_FILTERS: createTooltipContent( + "Filter models based on their size:", + [ + { + label: "Tiny (Up to 10M)", + description: + "Lightweight models suitable for edge devices and quick experiments.", + }, + { + label: "Small (10M-50M)", + description: + "Compact models balancing performance and efficiency.", + }, + { + label: "Medium (50M-200M)", + description: + "Mid-range models with good capacity for complex EEG tasks.", + }, + { + label: "Large (200M+)", + description: + "Large-scale models offering the best performance but requiring more resources.", + }, + ] + ), + ROW_SIZE: { + title: "Row Size", + description: + "Adjust the height of table rows. 
Compact is ideal for viewing more data at once, while Large provides better readability and touch targets.", + }, + SCORE_DISPLAY: { + title: "Score Display", + description: + "Choose between normalized scores (0-100% scale for easy comparison) or raw scores (actual benchmark results). Normalized scores help compare performance across different benchmarks, while raw scores show actual benchmark outputs.", + }, + RANKING_MODE: { + title: "Ranking Mode", + description: + "Choose between static ranking (original position in the full leaderboard) or dynamic ranking (position based on current filters and sorting).", + }, + AVERAGE_SCORE: { + title: "Average Score Calculation", + description: + "Define how the average score is calculated. 'All Scores' uses all benchmarks, while 'Visible Only' calculates the average using only the visible benchmark columns.", + }, +}; + +export const getTooltipStyle = {}; + +export const TABLE_TOOLTIPS = { + HUB_LINK: (modelName) => `View ${modelName} on Hugging Face Hub`, + EVAL_RESULTS: (modelName) => + `View detailed evaluation results for ${modelName}`, + POSITION_CHANGE: (change) => + `${Math.abs(change)} position${Math.abs(change) > 1 ? "s" : ""} ${ + change > 0 ? 
"up" : "down" + }`, + METADATA: { + TYPE: (type) => type || "-", + ARCHITECTURE: (arch) => arch || "-", + PRECISION: (precision) => precision || "-", + LICENSE: (license) => license || "-", + UPLOAD_DATE: (date) => date || "-", + SUBMISSION_DATE: (date) => date || "-", + BASE_MODEL: (model) => model || "-", + }, +}; diff --git a/frontend/src/pages/LeaderboardPage/components/Leaderboard/context/LeaderboardContext.js b/frontend/src/pages/LeaderboardPage/components/Leaderboard/context/LeaderboardContext.js new file mode 100644 index 0000000000000000000000000000000000000000..df394f4382d671cb642880cabfd8f981a7cc6ee2 --- /dev/null +++ b/frontend/src/pages/LeaderboardPage/components/Leaderboard/context/LeaderboardContext.js @@ -0,0 +1,687 @@ +import React, { + createContext, + useContext, + useReducer, + useEffect, + useMemo, + useCallback, +} from "react"; +import { useSearchParams, useLocation } from "react-router-dom"; +import { MODEL_TYPE_ORDER } from "../constants/modelTypes"; +import { FILTER_PRECISIONS, TABLE_DEFAULTS } from "../constants/defaults"; + +// Create context +const LeaderboardContext = createContext(); + +// Define default filter values +const DEFAULT_FILTERS = { + search: "", + precisions: FILTER_PRECISIONS, + types: MODEL_TYPE_ORDER, + paramsRange: [-1, 500], + booleanFilters: [], +}; + +// Define default display values +const DEFAULT_DISPLAY = { + rowSize: TABLE_DEFAULTS.ROW_SIZE, + scoreDisplay: TABLE_DEFAULTS.SCORE_DISPLAY, + averageMode: TABLE_DEFAULTS.AVERAGE_MODE, + rankingMode: TABLE_DEFAULTS.RANKING_MODE, + visibleColumns: TABLE_DEFAULTS.COLUMNS.DEFAULT_VISIBLE, +}; + +// Create initial counter structure +const createInitialCounts = () => { + const modelTypes = {}; + MODEL_TYPE_ORDER.forEach((type) => { + modelTypes[type] = 0; + }); + + const precisions = {}; + FILTER_PRECISIONS.forEach((precision) => { + precisions[precision] = 0; + }); + + return { + modelTypes, + precisions, + notOnHub: 0, + parameterRanges: { + tiny: 0, + small: 0, + 
// Round a raw parameter count (in millions) to 2 decimals; null when the
// input is not numeric.
const normalizeParams = (params) => {
  const numParams = Number(params);
  if (isNaN(numParams)) return null;
  return Math.round(numParams * 100) / 100;
};

// True when `params` (millions) falls in the half-open interval
// [range[0], range[1]). The sentinel [-1, 500] means "no filter" and always
// matches, even for models with unparseable parameter counts.
const isInParamRange = (params, range) => {
  if (range[0] === -1 && range[1] === 500) return true;
  const normalizedParams = normalizeParams(params);
  if (normalizedParams === null) return false;
  return normalizedParams >= range[0] && normalizedParams < range[1];
};

// True when `model` passes every active filter in `filters`
// (precision, type, parameter range, name search, boolean feature flags).
const modelMatchesFilters = (model, filters) => {
  // Precision: empty selection means "no precision filter".
  if (
    filters.precisions.length > 0 &&
    !filters.precisions.includes(model.model.precision)
  ) {
    return false;
  }

  // Type: substring match against the cleaned type string.
  if (filters.types.length > 0) {
    const modelType = model.model.type?.toLowerCase().trim();
    if (!filters.types.some((type) => modelType?.includes(type))) {
      return false;
    }
  }

  // Parameter count (in millions for EEG).
  const params = Number(
    model.metadata?.params_millions || model.features?.params_millions
  );
  if (!isInParamRange(params, filters.paramsRange)) return false;

  // Case-insensitive substring search on the model name.
  if (filters.search) {
    const searchLower = filters.search.toLowerCase();
    const modelName = model.model.name.toLowerCase();
    if (!modelName.includes(searchLower)) return false;
  }

  // Boolean feature filters: every active filter must hold.
  if (filters.booleanFilters.length > 0) {
    return filters.booleanFilters.every((filter) => {
      const filterValue = typeof filter === "object" ? filter.value : filter;

      // BUGFIX: use optional chaining on `features` — models without a
      // `features` object previously threw here, while the rest of this
      // module already guards with `model.features?.…`.
      if (filterValue === "is_not_available_on_hub") {
        // This flag keeps unavailable models; all others exclude flagged ones.
        return Boolean(model.features?.[filterValue]);
      }

      return !model.features?.[filterValue];
    });
  }

  return true;
};

// Compute the "N models shown (+M pinned)" counts for the footer.
// When `allRows` is absent (table not mounted yet) fall back to the raw
// `filteredCount`; otherwise subtract pinned rows that also match the
// filters so they are not double-counted.
const calculateFilteredCounts = (
  allRows,
  totalPinnedCount,
  filters,
  filteredCount
) => {
  if (!allRows) {
    return {
      currentFilteredCount:
        typeof filteredCount === "number" ? filteredCount : 0,
      totalPinnedCount: totalPinnedCount || 0,
    };
  }

  const totalFilteredCount = allRows.length;

  const pinnedMatchingFilters = allRows.filter((row) => {
    const model = row.original;
    return model.isPinned && modelMatchesFilters(model, filters);
  }).length;

  return {
    currentFilteredCount: totalFilteredCount - pinnedMatchingFilters,
    totalPinnedCount: totalPinnedCount || 0,
  };
};
// State reducer for the leaderboard store. Every branch returns a fresh
// state object (no in-place mutation); unrecognized actions return the
// current state unchanged.
const reducer = (state, action) => {
  switch (action.type) {
    case "SET_MODELS": {
      // Replacing the model list also recomputes every filter counter.
      const counts = calculateModelCounts(action.payload);
      return {
        ...state,
        models: action.payload,
        filterCounts: counts,
        countsReady: true,
        loading: false,
      };
    }

    case "SET_LOADING": {
      // Entering the loading state invalidates the counters until the next
      // SET_MODELS; leaving it does not touch countsReady.
      const invalidate = action.payload ? { countsReady: false } : {};
      return { ...state, loading: action.payload, ...invalidate };
    }

    case "SET_ERROR":
      return { ...state, error: action.payload, loading: false };

    case "SET_FILTER":
      return {
        ...state,
        filters: { ...state.filters, [action.key]: action.value },
      };

    case "SET_DISPLAY_OPTION":
      return {
        ...state,
        display: { ...state.display, [action.key]: action.value },
      };

    case "TOGGLE_PINNED_MODEL": {
      // Append the key when absent, drop it when present (order preserved).
      const key = action.payload;
      const pinnedModels = state.pinnedModels.includes(key)
        ? state.pinnedModels.filter((k) => k !== key)
        : [...state.pinnedModels, key];
      return { ...state, pinnedModels };
    }

    case "SET_PINNED_MODELS":
      return { ...state, pinnedModels: action.payload };

    case "TOGGLE_FILTERS_EXPANDED":
      return { ...state, filtersExpanded: !state.filtersExpanded };

    case "RESET_FILTERS":
      return { ...state, filters: DEFAULT_FILTERS };

    case "RESET_ALL":
      return {
        ...state,
        filters: DEFAULT_FILTERS,
        display: DEFAULT_DISPLAY,
        pinnedModels: [],
      };

    default:
      return state;
  }
};
useLocation(); + + // Effect to load initial values from URL + useEffect(() => { + // Skip URL sync if we're resetting + if (location.state?.skipUrlSync) return; + + const loadFromUrl = () => { + // Load filters + const searchFromUrl = searchParams.get("search"); + if (searchFromUrl) { + dispatch({ type: "SET_FILTER", key: "search", value: searchFromUrl }); + } + + const paramsFromUrl = searchParams.get("params")?.split(",").map(Number); + if (paramsFromUrl?.length === 2) { + dispatch({ + type: "SET_FILTER", + key: "paramsRange", + value: paramsFromUrl, + }); + } + + const filtersFromUrl = + searchParams.get("filters")?.split(",").filter(Boolean) || []; + if (filtersFromUrl.length > 0) { + dispatch({ + type: "SET_FILTER", + key: "booleanFilters", + value: filtersFromUrl, + }); + } + + const precisionsFromUrl = searchParams + .get("precision") + ?.split(",") + .filter(Boolean); + if (precisionsFromUrl) { + dispatch({ + type: "SET_FILTER", + key: "precisions", + value: precisionsFromUrl, + }); + } + + const typesFromUrl = searchParams + .get("types") + ?.split(",") + .filter(Boolean); + if (typesFromUrl) { + dispatch({ type: "SET_FILTER", key: "types", value: typesFromUrl }); + } + + // Load pinned models + const pinnedFromUrl = + searchParams.get("pinned")?.split(",").filter(Boolean) || []; + if (pinnedFromUrl.length > 0) { + dispatch({ type: "SET_PINNED_MODELS", payload: pinnedFromUrl }); + } + + // Load visible columns + const columnsFromUrl = searchParams + .get("columns") + ?.split(",") + .filter(Boolean); + if (columnsFromUrl) { + dispatch({ + type: "SET_DISPLAY_OPTION", + key: "visibleColumns", + value: columnsFromUrl, + }); + } + + // Load table options + const rowSizeFromUrl = searchParams.get("rowSize"); + if (rowSizeFromUrl) { + dispatch({ + type: "SET_DISPLAY_OPTION", + key: "rowSize", + value: rowSizeFromUrl, + }); + } + + const scoreDisplayFromUrl = searchParams.get("scoreDisplay"); + if (scoreDisplayFromUrl) { + dispatch({ + type: "SET_DISPLAY_OPTION", 
+ key: "scoreDisplay", + value: scoreDisplayFromUrl, + }); + } + + const averageModeFromUrl = searchParams.get("averageMode"); + if (averageModeFromUrl) { + dispatch({ + type: "SET_DISPLAY_OPTION", + key: "averageMode", + value: averageModeFromUrl, + }); + } + + const rankingModeFromUrl = searchParams.get("rankingMode"); + if (rankingModeFromUrl) { + dispatch({ + type: "SET_DISPLAY_OPTION", + key: "rankingMode", + value: rankingModeFromUrl, + }); + } + }; + + loadFromUrl(); + }, [searchParams, location.state]); + + // Effect to synchronize filters with URL + useEffect(() => { + // Skip URL sync if we're resetting + if (location.state?.skipUrlSync) return; + + const newSearchParams = new URLSearchParams(searchParams); + const currentParams = searchParams.get("params")?.split(",").map(Number); + const currentFilters = + searchParams.get("filters")?.split(",").filter(Boolean) || []; + const currentSearch = searchParams.get("search"); + const currentPinned = + searchParams.get("pinned")?.split(",").filter(Boolean) || []; + const currentColumns = + searchParams.get("columns")?.split(",").filter(Boolean) || []; + const currentRowSize = searchParams.get("rowSize"); + const currentScoreDisplay = searchParams.get("scoreDisplay"); + const currentAverageMode = searchParams.get("averageMode"); + const currentRankingMode = searchParams.get("rankingMode"); + const currentPrecisions = + searchParams.get("precision")?.split(",").filter(Boolean) || []; + const currentTypes = + searchParams.get("types")?.split(",").filter(Boolean) || []; + + // Only update URL if values have changed + const paramsChanged = + !currentParams || + currentParams[0] !== state.filters.paramsRange[0] || + currentParams[1] !== state.filters.paramsRange[1]; + + const filtersChanged = + state.filters.booleanFilters.length !== currentFilters.length || + state.filters.booleanFilters.some((f) => !currentFilters.includes(f)); + + const searchChanged = state.filters.search !== currentSearch; + + const 
pinnedChanged = + state.pinnedModels.length !== currentPinned.length || + state.pinnedModels.some((m) => !currentPinned.includes(m)); + + const columnsChanged = + state.display.visibleColumns.length !== currentColumns.length || + state.display.visibleColumns.some((c) => !currentColumns.includes(c)); + + const rowSizeChanged = state.display.rowSize !== currentRowSize; + const scoreDisplayChanged = + state.display.scoreDisplay !== currentScoreDisplay; + const averageModeChanged = state.display.averageMode !== currentAverageMode; + const rankingModeChanged = state.display.rankingMode !== currentRankingMode; + const precisionsChanged = + state.filters.precisions.length !== currentPrecisions.length || + state.filters.precisions.some((p) => !currentPrecisions.includes(p)); + const typesChanged = + state.filters.types.length !== currentTypes.length || + state.filters.types.some((t) => !currentTypes.includes(t)); + + if (paramsChanged) { + if ( + state.filters.paramsRange[0] !== -1 || + state.filters.paramsRange[1] !== 500 + ) { + newSearchParams.set("params", state.filters.paramsRange.join(",")); + } else { + newSearchParams.delete("params"); + } + } + + if (filtersChanged) { + if (state.filters.booleanFilters.length > 0) { + newSearchParams.set("filters", state.filters.booleanFilters.join(",")); + } else { + newSearchParams.delete("filters"); + } + } + + if (searchChanged) { + if (state.filters.search) { + newSearchParams.set("search", state.filters.search); + } else { + newSearchParams.delete("search"); + } + } + + if (pinnedChanged) { + if (state.pinnedModels.length > 0) { + newSearchParams.set("pinned", state.pinnedModels.join(",")); + } else { + newSearchParams.delete("pinned"); + } + } + + if (columnsChanged) { + if ( + JSON.stringify([...state.display.visibleColumns].sort()) !== + JSON.stringify([...TABLE_DEFAULTS.COLUMNS.DEFAULT_VISIBLE].sort()) + ) { + newSearchParams.set("columns", state.display.visibleColumns.join(",")); + } else { + 
newSearchParams.delete("columns"); + } + } + + if (rowSizeChanged) { + if (state.display.rowSize !== TABLE_DEFAULTS.ROW_SIZE) { + newSearchParams.set("rowSize", state.display.rowSize); + } else { + newSearchParams.delete("rowSize"); + } + } + + if (scoreDisplayChanged) { + if (state.display.scoreDisplay !== TABLE_DEFAULTS.SCORE_DISPLAY) { + newSearchParams.set("scoreDisplay", state.display.scoreDisplay); + } else { + newSearchParams.delete("scoreDisplay"); + } + } + + if (averageModeChanged) { + if (state.display.averageMode !== TABLE_DEFAULTS.AVERAGE_MODE) { + newSearchParams.set("averageMode", state.display.averageMode); + } else { + newSearchParams.delete("averageMode"); + } + } + + if (rankingModeChanged) { + if (state.display.rankingMode !== TABLE_DEFAULTS.RANKING_MODE) { + newSearchParams.set("rankingMode", state.display.rankingMode); + } else { + newSearchParams.delete("rankingMode"); + } + } + + if (precisionsChanged) { + if ( + JSON.stringify([...state.filters.precisions].sort()) !== + JSON.stringify([...FILTER_PRECISIONS].sort()) + ) { + newSearchParams.set("precision", state.filters.precisions.join(",")); + } else { + newSearchParams.delete("precision"); + } + } + + if (typesChanged) { + if ( + JSON.stringify([...state.filters.types].sort()) !== + JSON.stringify([...MODEL_TYPE_ORDER].sort()) + ) { + newSearchParams.set("types", state.filters.types.join(",")); + } else { + newSearchParams.delete("types"); + } + } + + if ( + paramsChanged || + filtersChanged || + searchChanged || + pinnedChanged || + columnsChanged || + rowSizeChanged || + scoreDisplayChanged || + averageModeChanged || + rankingModeChanged || + precisionsChanged || + typesChanged + ) { + // Update search params and let HashRouter handle the URL + setSearchParams(newSearchParams); + } + }, [state, searchParams, location.state]); + + const actions = useMemo( + () => ({ + setModels: (models) => dispatch({ type: "SET_MODELS", payload: models }), + setLoading: (loading) => + dispatch({ type: 
"SET_LOADING", payload: loading }), + setError: (error) => dispatch({ type: "SET_ERROR", payload: error }), + setFilter: (key, value) => dispatch({ type: "SET_FILTER", key, value }), + setDisplayOption: (key, value) => + dispatch({ type: "SET_DISPLAY_OPTION", key, value }), + togglePinnedModel: (modelKey) => + dispatch({ type: "TOGGLE_PINNED_MODEL", payload: modelKey }), + toggleFiltersExpanded: () => + dispatch({ type: "TOGGLE_FILTERS_EXPANDED" }), + resetFilters: () => { + dispatch({ type: "RESET_FILTERS" }); + const newParams = new URLSearchParams(searchParams); + [ + "filters", + "params", + "precision", + "types", + "search", + ].forEach((param) => { + newParams.delete(param); + }); + setSearchParams(newParams); + }, + resetAll: () => { + // Reset all state + dispatch({ type: "RESET_ALL" }); + // Clear all URL params with skipUrlSync flag + setSearchParams({}, { state: { skipUrlSync: true } }); + }, + }), + [searchParams, setSearchParams] + ); + + // Function to calculate counts (exposed via context) + const getFilteredCounts = useCallback( + (allRows, totalPinnedCount, filteredCount) => { + return calculateFilteredCounts( + allRows, + totalPinnedCount, + state.filters, + filteredCount + ); + }, + [state.filters] + ); + + // Also expose filtering function for reuse elsewhere + const checkModelMatchesFilters = useCallback( + (model) => { + return modelMatchesFilters(model, state.filters); + }, + [state.filters] + ); + + const value = useMemo( + () => ({ + state: { + ...state, + loading: state.loading || !state.countsReady, + }, + actions, + utils: { + getFilteredCounts, + checkModelMatchesFilters, + }, + }), + [state, actions, getFilteredCounts, checkModelMatchesFilters] + ); + + return ( + + {children} + + ); +}; + +// Hook to use context +const useLeaderboard = () => { + const context = useContext(LeaderboardContext); + if (!context) { + throw new Error("useLeaderboard must be used within a LeaderboardProvider"); + } + return context; +}; + +export { 
useLeaderboard }; +export default LeaderboardProvider; diff --git a/frontend/src/pages/LeaderboardPage/components/Leaderboard/hooks/useBatchedState.js b/frontend/src/pages/LeaderboardPage/components/Leaderboard/hooks/useBatchedState.js new file mode 100644 index 0000000000000000000000000000000000000000..ad11c91393ca9e413853ae440154b948293103e9 --- /dev/null +++ b/frontend/src/pages/LeaderboardPage/components/Leaderboard/hooks/useBatchedState.js @@ -0,0 +1,31 @@ +import { useState, useCallback, useTransition } from 'react'; + +export const useBatchedState = (initialState, options = {}) => { + const { batchDelay = 0, useTransitions = false } = options; + const [state, setState] = useState(typeof initialState === 'function' ? initialState() : initialState); + const [isPending, startTransition] = useTransition(); + + const setBatchedState = useCallback((newState) => { + if (useTransitions) { + startTransition(() => { + if (batchDelay > 0) { + setTimeout(() => { + setState(newState); + }, batchDelay); + } else { + setState(newState); + } + }); + } else { + if (batchDelay > 0) { + setTimeout(() => { + setState(newState); + }, batchDelay); + } else { + setState(newState); + } + } + }, [batchDelay, useTransitions]); + + return [state, setBatchedState, isPending]; +}; \ No newline at end of file diff --git a/frontend/src/pages/LeaderboardPage/components/Leaderboard/hooks/useDataUtils.js b/frontend/src/pages/LeaderboardPage/components/Leaderboard/hooks/useDataUtils.js new file mode 100644 index 0000000000000000000000000000000000000000..5313812d5341be56ff4539c6be0cc1f98efcd2d4 --- /dev/null +++ b/frontend/src/pages/LeaderboardPage/components/Leaderboard/hooks/useDataUtils.js @@ -0,0 +1,310 @@ +import { useMemo } from "react"; +import { + looksLikeRegex, + parseSearchQuery, + getValueByPath, +} from "../utils/searchUtils"; +import { MODEL_TYPE_ORDER } from "../constants/modelTypes"; + +// Calculate min/max averages +export const useAverageRange = (data) => { + return 
useMemo(() => { + const averages = data.map((item) => item.model.average_score); + return { + minAverage: Math.min(...averages), + maxAverage: Math.max(...averages), + }; + }, [data]); +}; + +// Generate colors for scores +export const useColorGenerator = (minAverage, maxAverage) => { + return useMemo(() => { + const colorCache = new Map(); + return (value) => { + const cached = colorCache.get(value); + if (cached) return cached; + + const normalizedValue = (value - minAverage) / (maxAverage - minAverage); + const red = Math.round(255 * (1 - normalizedValue) * 1); + const green = Math.round(255 * normalizedValue) * 1; + const color = `rgba(${red}, ${green}, 0, 1)`; + colorCache.set(value, color); + return color; + }; + }, [minAverage, maxAverage]); +}; + +// Process data with boolean standardization +export const useProcessedData = (data, averageMode, visibleColumns) => { + return useMemo(() => { + let processed = data.map((item) => { + const evaluationScores = Object.entries(item.evaluations) + .filter(([key]) => { + if (averageMode === "all") return true; + return visibleColumns.includes(`evaluations.${key}.normalized_score`); + }) + .map(([, value]) => value.normalized_score); + + const average = + evaluationScores.length > 0 + ? evaluationScores.reduce((a, b) => a + b, 0) / + evaluationScores.length + : averageMode === "visible" + ? 
null + : 0; + + // Boolean standardization + const standardizedFeatures = { + ...item.features, + is_moe: Boolean(item.features.is_moe), + is_flagged: Boolean(item.features.is_flagged), + is_official_provider: Boolean(item.features.is_official_provider), + is_merged: Boolean(item.features.is_merged), + is_not_available_on_hub: Boolean(item.features.is_not_available_on_hub), + }; + + return { + ...item, + features: standardizedFeatures, + model: { + ...item.model, + has_chat_template: Boolean(item.model.has_chat_template), + average_score: average, + }, + }; + }); + + processed.sort((a, b) => { + if (a.model.average_score === null && b.model.average_score === null) + return 0; + if (a.model.average_score === null) return 1; + if (b.model.average_score === null) return -1; + return b.model.average_score - a.model.average_score; + }); + + const result = processed.map((item, index) => ({ + ...item, + static_rank: index + 1, + })); + + return result; + }, [data, averageMode, visibleColumns]); +}; + +// Common filtering logic +export const useFilteredData = ( + processedData, + selectedPrecisions, + selectedTypes, + paramsRange, + searchValue, + selectedBooleanFilters, + rankingMode, + pinnedModels = [], + isOfficialProviderActive = false +) => { + return useMemo(() => { + const pinnedData = processedData.filter((row) => { + return pinnedModels.includes(row.id); + }); + + const unpinnedData = processedData.filter((row) => { + return !pinnedModels.includes(row.id); + }); + + let filteredUnpinned = unpinnedData; + + // Filter by official providers + if (isOfficialProviderActive) { + filteredUnpinned = filteredUnpinned.filter( + (row) => + row.features?.is_official_provider || + row.metadata?.is_official_provider + ); + } + + // Filter by precision + if (selectedPrecisions.length > 0) { + filteredUnpinned = filteredUnpinned.filter((row) => + selectedPrecisions.includes(row.model.precision) + ); + } + + // Filter by type + if ( + selectedTypes.length > 0 && + 
selectedTypes.length < MODEL_TYPE_ORDER.length + ) { + filteredUnpinned = filteredUnpinned.filter((row) => { + const modelType = row.model.type?.toLowerCase().trim(); + return selectedTypes.some((type) => modelType?.includes(type)); + }); + } + + // Filter by parameters + if (!(paramsRange[0] === -1 && paramsRange[1] === 140)) { + filteredUnpinned = filteredUnpinned.filter((row) => { + const params = + row.metadata?.params_billions || row.features?.params_billions; + if (params === undefined || params === null) return false; + return params >= paramsRange[0] && params < paramsRange[1]; + }); + } + + // Filter by search + if (searchValue) { + const searchQueries = searchValue + .split(";") + .map((q) => q.trim()) + .filter((q) => q); + if (searchQueries.length > 0) { + filteredUnpinned = filteredUnpinned.filter((row) => { + return searchQueries.some((query) => { + const { specialSearches, textSearch } = parseSearchQuery(query); + + const specialSearchMatch = specialSearches.every( + ({ field, value }) => { + const fieldValue = getValueByPath(row, field) + ?.toString() + .toLowerCase(); + return fieldValue?.includes(value.toLowerCase()); + } + ); + + if (!specialSearchMatch) return false; + if (!textSearch) return true; + + const modelName = row.model.name.toLowerCase(); + const searchLower = textSearch.toLowerCase(); + + if (looksLikeRegex(textSearch)) { + try { + const regex = new RegExp(textSearch, "i"); + return regex.test(modelName); + } catch (e) { + return modelName.includes(searchLower); + } + } else { + return modelName.includes(searchLower); + } + }); + }); + } + } + + // Filter by booleans + if (selectedBooleanFilters.length > 0) { + filteredUnpinned = filteredUnpinned.filter((row) => { + return selectedBooleanFilters.every((filter) => { + const filterValue = + typeof filter === "object" ? 
filter.value : filter; + + // Maintainer's Highlight keeps positive logic + if (filterValue === "is_official_provider") { + return row.features[filterValue]; + } + + // For all other filters, invert the logic + if (filterValue === "is_not_available_on_hub") { + return row.features[filterValue]; + } + + return !row.features[filterValue]; + }); + }); + } + + // Create ordered array of pinned models respecting pinnedModels order + const orderedPinnedData = pinnedModels + .map((pinnedModelId) => + pinnedData.find((item) => item.id === pinnedModelId) + ) + .filter(Boolean); + + // Combine all filtered data + const allFilteredData = [...filteredUnpinned, ...orderedPinnedData]; + + // Sort all data by average_score for dynamic_rank + const sortedByScore = [...allFilteredData].sort((a, b) => { + // Si les scores moyens sont différents, trier par score + if (a.model.average_score !== b.model.average_score) { + if (a.model.average_score === null && b.model.average_score === null) + return 0; + if (a.model.average_score === null) return 1; + if (b.model.average_score === null) return -1; + return b.model.average_score - a.model.average_score; + } + + // Si les scores sont égaux, comparer le nom du modèle et la date de soumission + if (a.model.name === b.model.name) { + // Si même nom, trier par date de soumission (la plus récente d'abord) + const dateA = new Date(a.metadata?.submission_date || 0); + const dateB = new Date(b.metadata?.submission_date || 0); + return dateB - dateA; + } + + // Si noms différents, trier par nom + return a.model.name.localeCompare(b.model.name); + }); + + // Create Map to store dynamic_ranks + const dynamicRankMap = new Map(); + sortedByScore.forEach((item, index) => { + dynamicRankMap.set(item.id, index + 1); + }); + + // Add ranks to final data + const finalData = [...orderedPinnedData, ...filteredUnpinned].map( + (item) => { + return { + ...item, + dynamic_rank: dynamicRankMap.get(item.id), + rank: item.isPinned + ? 
pinnedModels.indexOf(item.id) + 1 + : rankingMode === "static" + ? item.static_rank + : dynamicRankMap.get(item.id), + isPinned: pinnedModels.includes(item.id), + }; + } + ); + + return finalData; + }, [ + processedData, + selectedPrecisions, + selectedTypes, + paramsRange, + searchValue, + selectedBooleanFilters, + rankingMode, + pinnedModels, + isOfficialProviderActive, + ]); +}; + +// Column visibility management +export const useColumnVisibility = (visibleColumns = []) => { + // Create secure visibility object + const columnVisibility = useMemo(() => { + // Check visible columns + const safeVisibleColumns = Array.isArray(visibleColumns) + ? visibleColumns + : []; + + const visibility = {}; + try { + safeVisibleColumns.forEach((columnKey) => { + if (typeof columnKey === "string") { + visibility[columnKey] = true; + } + }); + } catch (error) { + console.warn("Error in useColumnVisibility:", error); + } + return visibility; + }, [visibleColumns]); + + return columnVisibility; +}; diff --git a/frontend/src/pages/LeaderboardPage/components/Leaderboard/hooks/useLeaderboardData.js b/frontend/src/pages/LeaderboardPage/components/Leaderboard/hooks/useLeaderboardData.js new file mode 100644 index 0000000000000000000000000000000000000000..8642fe72c7aa443ac6b6cae2637b24a3b884432b --- /dev/null +++ b/frontend/src/pages/LeaderboardPage/components/Leaderboard/hooks/useLeaderboardData.js @@ -0,0 +1,127 @@ +import { useMemo, useRef, useState } from "react"; +import { useQuery, useQueryClient } from "@tanstack/react-query"; +import { useSearchParams } from "react-router-dom"; +import { useLeaderboard } from "../context/LeaderboardContext"; +import { useDataProcessing } from "../components/Table/hooks/useDataProcessing"; + +export const useLeaderboardData = () => { + const queryClient = useQueryClient(); + const [searchParams] = useSearchParams(); + const isInitialLoadRef = useRef(true); + + const { data, isLoading, error } = useQuery({ + queryKey: ["leaderboard"], + queryFn: 
async () => { + console.log("🔄 Starting API fetch attempt..."); + try { + console.log("🌐 Fetching from API..."); + const response = await fetch("/api/leaderboard/formatted"); + console.log("📡 API Response status:", response.status); + + if (!response.ok) { + const errorText = await response.text(); + console.error("🚨 API Error:", { + status: response.status, + statusText: response.statusText, + body: errorText, + }); + throw new Error(`HTTP error! status: ${response.status}`); + } + + const newData = await response.json(); + console.log("📥 Received data size:", JSON.stringify(newData).length); + return newData; + } catch (error) { + console.error("🔥 Detailed error:", { + name: error.name, + message: error.message, + stack: error.stack, + }); + throw error; + } + }, + refetchOnWindowFocus: false, + enabled: isInitialLoadRef.current || !!searchParams.toString(), + }); + + useMemo(() => { + if (data && isInitialLoadRef.current) { + console.log("🎯 Initial load complete"); + isInitialLoadRef.current = false; + } + }, [data]); + + return { + data, + isLoading, + error, + refetch: () => queryClient.invalidateQueries(["leaderboard"]), + }; +}; + +export const useLeaderboardProcessing = () => { + const { state, actions } = useLeaderboard(); + const [sorting, setSorting] = useState([ + { id: "model.average_score", desc: true }, + ]); + + const memoizedData = useMemo(() => state.models, [state.models]); + const memoizedFilters = useMemo( + () => ({ + search: state.filters.search, + precisions: state.filters.precisions, + types: state.filters.types, + paramsRange: state.filters.paramsRange, + booleanFilters: state.filters.booleanFilters, + isOfficialProviderActive: state.filters.isOfficialProviderActive, + }), + [ + state.filters.search, + state.filters.precisions, + state.filters.types, + state.filters.paramsRange, + state.filters.booleanFilters, + state.filters.isOfficialProviderActive, + ] + ); + + const { + table, + minAverage, + maxAverage, + getColorForValue, + 
processedData, + filteredData, + columns, + columnVisibility, + } = useDataProcessing( + memoizedData, + memoizedFilters.search, + memoizedFilters.precisions, + memoizedFilters.types, + memoizedFilters.paramsRange, + memoizedFilters.booleanFilters, + sorting, + state.display.rankingMode, + state.display.averageMode, + state.display.visibleColumns, + state.display.scoreDisplay, + state.pinnedModels, + actions.togglePinnedModel, + setSorting, + memoizedFilters.isOfficialProviderActive + ); + + return { + table, + minAverage, + maxAverage, + getColorForValue, + processedData, + filteredData, + columns, + columnVisibility, + loading: state.loading, + error: state.error, + }; +}; diff --git a/frontend/src/pages/LeaderboardPage/components/Leaderboard/styles/common.js b/frontend/src/pages/LeaderboardPage/components/Leaderboard/styles/common.js new file mode 100644 index 0000000000000000000000000000000000000000..06648e526979fd7c992ea3f3721468e261448593 --- /dev/null +++ b/frontend/src/pages/LeaderboardPage/components/Leaderboard/styles/common.js @@ -0,0 +1,153 @@ +import { alpha } from "@mui/material"; + +export const commonStyles = { + // Tooltips + tooltip: { + sx: { + bgcolor: "background.tooltip", + "& .MuiTooltip-arrow": { + color: "background.tooltip", + }, + padding: "12px 16px", + maxWidth: 300, + fontSize: "0.875rem", + lineHeight: 1.4, + }, + }, + + // Progress bars + progressBar: { + position: "absolute", + left: -16, + top: -8, + height: "calc(100% + 16px)", + opacity: (theme) => (theme.palette.mode === "light" ? 0.1 : 0.2), + transition: "width 0.3s ease", + zIndex: 0, + }, + + // Cell containers + cellContainer: { + display: "flex", + alignItems: "center", + height: "100%", + width: "100%", + position: "relative", + }, + + // Hover effects + hoverEffect: (theme, isActive = false) => ({ + backgroundColor: isActive + ? alpha( + theme.palette.primary.main, + theme.palette.mode === "light" ? 
0.08 : 0.16 + ) + : theme.palette.action.hover, + "& .MuiTypography-root": { + color: isActive ? "primary.main" : "text.primary", + }, + "& .MuiSvgIcon-root": { + color: isActive ? "primary.main" : "text.primary", + }, + }), + + // Filter groups + filterGroup: { + title: { + mb: 1, + fontSize: "0.8rem", + fontWeight: 700, + color: "text.primary", + display: "flex", + alignItems: "center", + gap: 0.5, + }, + container: { + display: "flex", + flexWrap: "wrap", + gap: 0.5, + alignItems: "center", + }, + }, + + // Option buttons (like in DisplayOptions) + optionButton: { + display: "flex", + alignItems: "center", + gap: 0.8, + cursor: "pointer", + padding: "4px 10px", + borderRadius: 1, + height: "32px", + "& .MuiSvgIcon-root": { + fontSize: "0.9rem", + }, + "& .MuiTypography-root": { + fontSize: "0.85rem", + }, + }, + + // Score indicators + scoreIndicator: { + dot: { + width: 10, + height: 10, + borderRadius: "50%", + marginLeft: -1, + }, + bar: { + position: "absolute", + left: -16, + top: -8, + height: "calc(100% + 16px)", + opacity: (theme) => (theme.palette.mode === "light" ? 0.1 : 0.2), + transition: "width 0.3s ease", + }, + }, + + // Popover content + popoverContent: { + p: 3, + width: 280, + maxHeight: 400, + overflowY: "auto", + }, +}; + +// Composant styles +export const componentStyles = { + // Table header cell + headerCell: { + borderRight: (theme) => + `1px solid ${alpha( + theme.palette.divider, + theme.palette.mode === "dark" ? 0.05 : 0.1 + )}`, + "&:last-child": { + borderRight: "none", + }, + whiteSpace: "nowrap", + overflow: "hidden", + textOverflow: "ellipsis", + padding: "8px 16px", + backgroundColor: (theme) => theme.palette.background.paper, + position: "sticky !important", + top: 0, + zIndex: 10, + }, + + // Table cell + tableCell: { + borderRight: (theme) => + `1px solid ${alpha( + theme.palette.divider, + theme.palette.mode === "dark" ? 
0.05 : 0.1 + )}`, + "&:last-child": { + borderRight: "none", + }, + whiteSpace: "nowrap", + overflow: "hidden", + textOverflow: "ellipsis", + }, +}; diff --git a/frontend/src/pages/LeaderboardPage/components/Leaderboard/utils/columnUtils.js b/frontend/src/pages/LeaderboardPage/components/Leaderboard/utils/columnUtils.js new file mode 100644 index 0000000000000000000000000000000000000000..526c015c6684569bc6581e1931b28e7dd09b3bac --- /dev/null +++ b/frontend/src/pages/LeaderboardPage/components/Leaderboard/utils/columnUtils.js @@ -0,0 +1,1073 @@ +import React from "react"; +import { Box, Typography, Link, Tooltip, IconButton } from "@mui/material"; +import { getModelTypeIcon } from "../constants/modelTypes"; +import TrendingUpIcon from "@mui/icons-material/TrendingUp"; +import TrendingDownIcon from "@mui/icons-material/TrendingDown"; +import RemoveIcon from "@mui/icons-material/Remove"; +import PushPinIcon from "@mui/icons-material/PushPin"; +import PushPinOutlinedIcon from "@mui/icons-material/PushPinOutlined"; +import { TABLE_DEFAULTS, HIGHLIGHT_COLORS } from "../constants/defaults"; +import { looksLikeRegex, extractTextSearch } from "./searchUtils"; +import { commonStyles } from "../styles/common"; +import { typeColumnSort } from "../components/Table/hooks/useSorting"; +import { + COLUMN_TOOLTIPS, + getTooltipStyle, + TABLE_TOOLTIPS, +} from "../constants/tooltips"; +import OpenInNewIcon from "@mui/icons-material/OpenInNew"; +import { alpha } from "@mui/material/styles"; +import InfoIconWithTooltip from "../../../../../components/shared/InfoIconWithTooltip"; + +const DatabaseIcon = () => ( + +); + +const HighlightedText = ({ text, searchValue }) => { + if (!searchValue) return text; + + const searches = searchValue + .split(";") + .map((s) => s.trim()) + .filter(Boolean); + let result = text; + let fragments = [{ text: result, isMatch: false }]; + + searches.forEach((search, searchIndex) => { + if (!search) return; + + try { + let regex; + if 
(looksLikeRegex(search)) { + regex = new RegExp(search, "gi"); + } else { + regex = new RegExp(search.replace(/[.*+?^${}()|[\]\\]/g, "\\$&"), "gi"); + } + + const newFragments = []; + fragments.forEach((fragment) => { + if (fragment.isMatch) { + newFragments.push(fragment); + return; + } + + const parts = fragment.text.split(regex); + const matches = fragment.text.match(regex); + + if (!matches) { + newFragments.push(fragment); + return; + } + + parts.forEach((part, i) => { + if (part) newFragments.push({ text: part, isMatch: false }); + if (i < parts.length - 1) { + newFragments.push({ + text: matches[i], + isMatch: true, + colorIndex: searchIndex % HIGHLIGHT_COLORS.length, + }); + } + }); + }); + + fragments = newFragments; + } catch (e) { + console.warn("Invalid regex:", search); + } + }); + + return ( + <> + {fragments.map((fragment, i) => + fragment.isMatch ? ( + + theme.palette.getContrastText( + HIGHLIGHT_COLORS[fragment.colorIndex] + ), + fontWeight: 500, + px: 0.5, + py: "2px", + borderRadius: "3px", + mx: "1px", + overflow: "visible", + display: "inline-block", + }} + > + {fragment.text} + + ) : ( + {fragment.text} + ) + )} + + ); +}; + +const MEDAL_STYLES = { + 1: { + color: "#B58A1B", + background: "linear-gradient(135deg, #FFF7E0 0%, #FFD700 100%)", + borderColor: "rgba(212, 160, 23, 0.35)", + shadowColor: "rgba(212, 160, 23, 0.8)", + }, + 2: { + color: "#667380", + background: "linear-gradient(135deg, #FFFFFF 0%, #D8E3ED 100%)", + borderColor: "rgba(124, 139, 153, 0.35)", + shadowColor: "rgba(124, 139, 153, 0.8)", + }, + 3: { + color: "#B85C2F", + background: "linear-gradient(135deg, #FDF0E9 0%, #FFBC8C 100%)", + borderColor: "rgba(204, 108, 61, 0.35)", + shadowColor: "rgba(204, 108, 61, 0.8)", + }, +}; + +const getMedalStyle = (rank) => { + if (rank <= 3) { + const medalStyle = MEDAL_STYLES[rank]; + return { + color: medalStyle.color, + fontWeight: 900, + fontStretch: "150%", + fontFamily: '"Inter", -apple-system, sans-serif', + width: "24px", + 
height: "24px", + background: medalStyle.background, + border: "1px solid", + borderColor: medalStyle.borderColor, + borderRadius: "50%", + display: "flex", + alignItems: "center", + justifyContent: "center", + fontSize: "0.95rem", + lineHeight: 1, + padding: 0, + boxShadow: `1px 1px 0 ${medalStyle.shadowColor}`, + position: "relative", + }; + } + return { + color: "inherit", + fontWeight: rank <= 10 ? 600 : 400, + }; +}; + +const getRankStyle = (rank) => getMedalStyle(rank); + +const RankIndicator = ({ rank, previousRank, mode }) => { + const rankChange = previousRank ? previousRank - rank : 0; + + const RankChangeIndicator = ({ change }) => { + if (!change || mode === "dynamic") return null; + + const getChangeColor = (change) => { + if (change > 0) return "success.main"; + if (change < 0) return "error.main"; + return "grey.500"; + }; + + const getChangeIcon = (change) => { + if (change > 0) return ; + if (change < 0) return ; + return ; + }; + + return ( + 1 ? "s" : "" + } ${change > 0 ? "up" : "down"}`} + arrow + placement="right" + > + + {getChangeIcon(change)} + + + ); + }; + + return ( + + + {rank <= 3 ? 
( + <> + + {rank} + + + + ) : ( + <> + + {rank} + + + + )} + + + ); +}; + +const getDetailsUrl = (modelName) => { + const formattedName = modelName.replace("/", "__"); + return `https://huggingface.co/datasets/open-llm-leaderboard/${formattedName}-details`; +}; + +const HeaderLabel = ({ label, tooltip, className, isSorted }) => ( + + + {label} + + +); + +const InfoIcon = ({ tooltip }) => ( + + + +); + +const createHeaderCell = (label, tooltip) => (header) => + ( + + + + + {tooltip && } + + + ); + +const createModelHeader = + (totalModels, officialProvidersCount = 0, isOfficialProviderActive = false) => + ({ table }) => { + return ( + + + + Model + + + + ); + }; + +const BooleanValue = ({ value }) => { + if (value === null || value === undefined) + return -; + + return ( + ({ + display: "flex", + alignItems: "center", + justifyContent: "center", + borderRadius: "4px", + px: 1, + py: 0.5, + backgroundColor: value + ? theme.palette.mode === "dark" + ? alpha(theme.palette.success.main, 0.1) + : alpha(theme.palette.success.main, 0.1) + : theme.palette.mode === "dark" + ? alpha(theme.palette.error.main, 0.1) + : alpha(theme.palette.error.main, 0.1), + })} + > + ({ + color: value + ? theme.palette.mode === "dark" + ? theme.palette.success.light + : theme.palette.success.dark + : theme.palette.mode === "dark" + ? theme.palette.error.light + : theme.palette.error.dark, + })} + > + {value ? "Yes" : "No"} + + + ); +}; + +export const createColumns = ( + getColorForValue, + scoreDisplay = "normalized", + columnVisibility = {}, + totalModels, + averageMode = "all", + searchValue = "", + rankingMode = "static", + onTogglePin, + hasPinnedRows = false +) => { + // Ajuster les tailles des colonnes en fonction de la présence de lignes épinglées + const getColumnSize = (defaultSize) => + hasPinnedRows ? 
"auto" : `${defaultSize}px`; + + const baseColumns = [ + { + accessorKey: "isPinned", + header: () => null, + cell: ({ row }) => ( + + { + e.stopPropagation(); + e.preventDefault(); + onTogglePin(row.original.id); + }} + sx={{ + padding: 0.5, + color: row.original.isPinned ? "primary.main" : "grey.400", + "&:hover": { + color: "primary.main", + }, + }} + > + {row.original.isPinned ? ( + + ) : ( + + )} + + + ), + enableSorting: false, + size: getColumnSize(40), + }, + { + accessorKey: "rank", + header: createHeaderCell("Rank"), + cell: ({ row }) => { + const rank = + rankingMode === "static" + ? row.original.static_rank + : row.original.dynamic_rank; + + return ( + + ); + }, + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES["rank"], + }, + { + id: "model_type", + accessorFn: (row) => row.model.type, + header: createHeaderCell("Type"), + sortingFn: typeColumnSort, + cell: ({ row }) => ( + + + + {getModelTypeIcon(row.original.model.type)} + + + + ), + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES["model.type_icon"], + }, + { + accessorKey: "id", + header: createModelHeader(totalModels), + cell: ({ row }) => { + const textSearch = extractTextSearch(searchValue); + const modelName = row.original.model.name; + + return ( + + + + theme.palette.mode === "dark" + ? theme.palette.info.light + : theme.palette.info.dark, + "& svg": { + opacity: 0.8, + }, + }, + overflow: "hidden", + textOverflow: "ellipsis", + whiteSpace: "nowrap", + flex: 1, + minWidth: 0, + fontWeight: row.original.static_rank <= 3 ? 
600 : "inherit", + }} + > + + + + + + + + + ); + }, + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES["id"], + }, + { + accessorKey: "model.average_score", + header: createHeaderCell("Average", COLUMN_TOOLTIPS.AVERAGE), + cell: ({ row, getValue }) => + createScoreCell(getValue, row, "model.average_score"), + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES["model.average_score"], + meta: { + headerStyle: { + borderLeft: (theme) => + `2px solid ${alpha( + theme.palette.divider, + theme.palette.mode === "dark" ? 0.1 : 0.2 + )}`, + borderRight: (theme) => + `2px solid ${alpha( + theme.palette.divider, + theme.palette.mode === "dark" ? 0.1 : 0.2 + )}`, + }, + cellStyle: (value) => ({ + position: "relative", + overflow: "hidden", + padding: "8px 16px", + borderLeft: (theme) => + `2px solid ${alpha( + theme.palette.divider, + theme.palette.mode === "dark" ? 0.1 : 0.2 + )}`, + borderRight: (theme) => + `2px solid ${alpha( + theme.palette.divider, + theme.palette.mode === "dark" ? 0.1 : 0.2 + )}`, + }), + }, + }, + ]; + const createScoreCell = (getValue, row, field) => { + const value = getValue(); + const rawValue = field.includes("normalized") + ? row.original.evaluations[field.split(".")[1]]?.value + : value; + + const isAverageColumn = field === "model.average_score"; + const hasNoValue = value === null || value === undefined; + + return ( + + {!hasNoValue && (scoreDisplay === "normalized" || isAverageColumn) && ( + (theme.palette.mode === "light" ? 0.1 : 0.2), + transition: "width 0.3s ease", + zIndex: 0, + }} + /> + )} + + {isAverageColumn && !hasNoValue && ( + + )} + + {hasNoValue ? ( + "-" + ) : ( + <> + {isAverageColumn ? ( + <> + {value.toFixed(2)} + % + + ) : scoreDisplay === "normalized" ? 
( + <> + {value.toFixed(2)} + % + + ) : ( + <>{rawValue.toFixed(2)} + )} + + )} + + + + ); + }; + + const evaluationColumns = [ + { + accessorKey: "evaluations.ifeval.normalized_score", + header: createHeaderCell("IFEval", COLUMN_TOOLTIPS.IFEVAL), + cell: ({ row, getValue }) => + createScoreCell(getValue, row, "evaluations.ifeval.normalized_score"), + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES[ + "evaluations.ifeval.normalized_score" + ], + }, + { + accessorKey: "evaluations.bbh.normalized_score", + header: createHeaderCell("BBH", COLUMN_TOOLTIPS.BBH), + cell: ({ row, getValue }) => + createScoreCell(getValue, row, "evaluations.bbh.normalized_score"), + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES[ + "evaluations.bbh.normalized_score" + ], + }, + { + accessorKey: "evaluations.math.normalized_score", + header: createHeaderCell("MATH", COLUMN_TOOLTIPS.MATH), + cell: ({ row, getValue }) => + createScoreCell(getValue, row, "evaluations.math.normalized_score"), + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES[ + "evaluations.math.normalized_score" + ], + }, + { + accessorKey: "evaluations.gpqa.normalized_score", + header: createHeaderCell("GPQA", COLUMN_TOOLTIPS.GPQA), + cell: ({ row, getValue }) => + createScoreCell(getValue, row, "evaluations.gpqa.normalized_score"), + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES[ + "evaluations.gpqa.normalized_score" + ], + }, + { + accessorKey: "evaluations.musr.normalized_score", + header: createHeaderCell("MUSR", COLUMN_TOOLTIPS.MUSR), + cell: ({ row, getValue }) => + createScoreCell(getValue, row, "evaluations.musr.normalized_score"), + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES[ + "evaluations.musr.normalized_score" + ], + }, + { + accessorKey: "evaluations.mmlu_pro.normalized_score", + header: createHeaderCell("MMLU-PRO", COLUMN_TOOLTIPS.MMLU_PRO), + cell: ({ row, getValue }) => + createScoreCell(getValue, row, "evaluations.mmlu_pro.normalized_score"), + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES[ + "evaluations.mmlu_pro.normalized_score" + 
], + }, + ]; + + const optionalColumns = [ + { + accessorKey: "model.architecture", + header: createHeaderCell("Architecture", COLUMN_TOOLTIPS.ARCHITECTURE), + accessorFn: (row) => row.model.architecture, + cell: ({ row }) => ( + + {row.original.model.architecture || "-"} + + ), + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES["model.architecture"], + }, + { + accessorKey: "model.precision", + header: createHeaderCell("Precision", COLUMN_TOOLTIPS.PRECISION), + accessorFn: (row) => row.model.precision, + cell: ({ row }) => ( + + {row.original.model.precision || "-"} + + ), + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES["model.precision"], + }, + { + accessorKey: "metadata.params_billions", + header: createHeaderCell("Parameters", COLUMN_TOOLTIPS.PARAMETERS), + cell: ({ row }) => ( + + + {row.original.metadata.params_billions} + B + + + ), + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES["metadata.params_billions"], + }, + { + accessorKey: "metadata.hub_license", + header: createHeaderCell("License", COLUMN_TOOLTIPS.LICENSE), + cell: ({ row }) => ( + + + {row.original.metadata.hub_license || "-"} + + + ), + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES["metadata.hub_license"], + }, + { + accessorKey: "metadata.hub_hearts", + header: createHeaderCell( + "Hub ❤️", + "Number of likes received on the Hugging Face Hub" + ), + cell: ({ row }) => ( + + {row.original.metadata.hub_hearts} + + ), + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES["metadata.hub_hearts"], + }, + { + accessorKey: "metadata.upload_date", + header: createHeaderCell( + "Upload Date", + "Date when the model was uploaded to the Hugging Face Hub" + ), + cell: ({ row }) => ( + + + {row.original.metadata.upload_date || "-"} + + + ), + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES["metadata.upload_date"], + }, + { + accessorKey: "metadata.submission_date", + header: createHeaderCell( + "Submission Date", + "Date when the model was submitted to the leaderboard" + ), + cell: ({ row }) => ( + + + {row.original.metadata.submission_date 
|| "-"} + + + ), + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES["metadata.submission_date"], + }, + { + accessorKey: "metadata.generation", + header: createHeaderCell( + "Generation", + "The generation or version number of the model" + ), + cell: ({ row }) => ( + + {row.original.metadata.generation} + + ), + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES["metadata.generation"], + }, + { + accessorKey: "metadata.base_model", + header: createHeaderCell( + "Base Model", + "The original model this model was derived from" + ), + cell: ({ row }) => ( + + + {row.original.metadata.base_model || "-"} + + + ), + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES["metadata.base_model"], + }, + { + accessorKey: "metadata.co2_cost", + header: createHeaderCell("CO₂ Cost", COLUMN_TOOLTIPS.CO2_COST), + cell: ({ row }) => ( + + + {row.original.metadata.co2_cost?.toFixed(2) || "0"} + kg + + + ), + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES["metadata.co2_cost"], + }, + { + accessorKey: "model.has_chat_template", + header: createHeaderCell( + "Chat Template", + "Whether this model has a chat template defined" + ), + cell: ({ row }) => ( + + ), + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES["model.has_chat_template"], + }, + { + accessorKey: "features.is_not_available_on_hub", + header: createHeaderCell( + "Hub Availability", + "Whether the model is available on the Hugging Face Hub" + ), + cell: ({ row }) => ( + + ), + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES[ + "features.is_not_available_on_hub" + ], + }, + { + accessorKey: "features.is_official_provider", + header: createHeaderCell( + "Official Providers", + "Models that are officially provided and maintained by their original creators or organizations" + ), + cell: ({ row }) => ( + + ), + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES[ + "features.is_official_provider" + ], + enableSorting: true, + }, + { + accessorKey: "features.is_moe", + header: createHeaderCell( + "Mixture of Experts", + "Whether this model uses a Mixture of Experts architecture" + ), 
+ cell: ({ row }) => , + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES["features.is_moe"], + }, + { + accessorKey: "features.is_flagged", + header: createHeaderCell( + "Flag Status", + "Whether this model has been flagged for any issues" + ), + cell: ({ row }) => ( + + ), + size: TABLE_DEFAULTS.COLUMNS.COLUMN_SIZES["features.is_flagged"], + }, + ]; + + // Utiliser directement columnVisibility + const finalColumns = [ + ...baseColumns, + ...evaluationColumns.filter((col) => columnVisibility[col.accessorKey]), + ...optionalColumns + .filter((col) => columnVisibility[col.accessorKey]) + .sort((a, b) => { + // Définir l'ordre personnalisé des colonnes + const order = { + "model.architecture": 1, + "model.precision": 2, + "metadata.params_billions": 3, + "metadata.hub_license": 4, + "metadata.co2_cost": 5, + "metadata.hub_hearts": 6, + "metadata.upload_date": 7, + "metadata.submission_date": 8, + "metadata.generation": 9, + "metadata.base_model": 10, + "model.has_chat_template": 11, + "features.is_not_available_on_hub": 12, + "features.is_official_provider": 13, + "features.is_moe": 14, + "features.is_flagged": 15, + }; + return order[a.accessorKey] - order[b.accessorKey]; + }), + ]; + + return finalColumns; +}; diff --git a/frontend/src/pages/LeaderboardPage/components/Leaderboard/utils/searchUtils.js b/frontend/src/pages/LeaderboardPage/components/Leaderboard/utils/searchUtils.js new file mode 100644 index 0000000000000000000000000000000000000000..091796b7a7a3721b4d7f790f0fda75ca151a838d --- /dev/null +++ b/frontend/src/pages/LeaderboardPage/components/Leaderboard/utils/searchUtils.js @@ -0,0 +1,92 @@ +// Utility function to detect if a string looks like a regex +export const looksLikeRegex = (str) => { + const regexSpecialChars = /[\\^$.*+?()[\]{}|]/; + return regexSpecialChars.test(str); +}; + +// Function to map search fields to correct paths +const getFieldPath = (field) => { + const fieldMappings = { + precision: "model.precision", + architecture: 
"model.architecture", + license: "metadata.hub_license", + type: "model.type", + }; + return fieldMappings[field] || field; +}; + +// Function to extract special searches and normal text +export const parseSearchQuery = (query) => { + const specialSearches = []; + let remainingText = query; + + // Look for all @field:value patterns + const prefixRegex = /@\w+:/g; + const matches = query.match(prefixRegex) || []; + + matches.forEach((prefix) => { + const regex = new RegExp(`${prefix}([^\\s@]+)`, "g"); + remainingText = remainingText.replace(regex, (match, value) => { + const field = prefix.slice(1, -1); + specialSearches.push({ + field: getFieldPath(field), + displayField: field, + value, + }); + return ""; + }); + }); + + return { + specialSearches, + textSearch: remainingText.trim(), + }; +}; + +// Function to extract simple text search +export const extractTextSearch = (searchValue) => { + return searchValue + .split(";") + .map((query) => { + const { textSearch } = parseSearchQuery(query); + return textSearch; + }) + .filter(Boolean) + .join(";"); +}; + +// Utility function to access nested object properties +export const getValueByPath = (obj, path) => { + return path.split(".").reduce((acc, part) => acc?.[part], obj); +}; + +// Function to generate natural language description of the search +export const generateSearchDescription = (searchValue) => { + if (!searchValue) return null; + + const searchGroups = searchValue + .split(";") + .map((group) => group.trim()) + .filter(Boolean); + + return searchGroups.map((group, index) => { + const { specialSearches, textSearch } = parseSearchQuery(group); + + let parts = []; + if (textSearch) { + parts.push(textSearch); + } + + if (specialSearches.length > 0) { + const specialParts = specialSearches.map( + ({ displayField, value }) => `@${displayField}:${value}` + ); + parts = parts.concat(specialParts); + } + + return { + text: parts.join(" "), + index, + }; + }); +}; diff --git 
a/frontend/src/pages/QuotePage/QuotePage.js b/frontend/src/pages/QuotePage/QuotePage.js new file mode 100644 index 0000000000000000000000000000000000000000..f1aa5f3d214e42d40ac5d29d0979841b3844e3b5 --- /dev/null +++ b/frontend/src/pages/QuotePage/QuotePage.js @@ -0,0 +1,180 @@ +import React from "react"; +import { + Box, + Typography, + Paper, + IconButton, + Tooltip, + Alert, + Link, +} from "@mui/material"; +import ContentCopyIcon from "@mui/icons-material/ContentCopy"; +import PageHeader from "../../components/shared/PageHeader"; + +const citations = [ + { + title: "Braindecode: A Deep Learning Toolbox for EEG Decoding", + authors: + "Robin Tibor Schirrmeister, Jost Tobias Springenberg, Lukas Dominique Josef Fiederer, Martin Glasstetter, Katharina Eggensperger, Michael Tangermann, Frank Hutter, Wolfram Burgard, Tonio Ball", + citation: `@article{schirrmeister2017deep, + title={Deep learning with convolutional neural networks for EEG decoding and visualization}, + author={Schirrmeister, Robin Tibor and Springenberg, Jost Tobias and Fiederer, Lukas Dominique Josef and Glasstetter, Martin and Eggensperger, Katharina and Tangermann, Michael and Hutter, Frank and Burgard, Wolfram and Ball, Tonio}, + journal={Human brain mapping}, + volume={38}, + number={11}, + pages={5391--5420}, + year={2017}, + publisher={Wiley Online Library} +}`, + url: "https://onlinelibrary.wiley.com/doi/full/10.1002/hbm.23730", + type: "main", + }, +]; + +const benchmarks = [ + { + title: "ANLI: Adversarial Natural Language Inference", + authors: "Nie et al.", + citation: `@inproceedings{nie2020adversarial, + title={Adversarial NLI: A New Benchmark for Natural Language Understanding}, + author={Nie, Yixin and Williams, Adina and Dinan, Emily and Bansal, Mohit and Weston, Jason and Kiela, Douwe}, + booktitle={Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics}, + pages={4885--4901}, + year={2020} +}`, + url: "https://arxiv.org/abs/1910.14599", + }, + { + 
title: "LogiQA: A Challenge Dataset for Machine Reading Comprehension with Logical Reasoning", + authors: "Liu et al.", + citation: `@inproceedings{liu2020logiqa, + title={LogiQA: A Challenge Dataset for Machine Reading Comprehension with Logical Reasoning}, + author={Liu, Jian and Cui, Leyang and Liu, Hanmeng and Huang, Dandan and Wang, Yile and Zhang, Yue}, + booktitle={Proceedings of the Twenty-Ninth International Joint Conference on Artificial Intelligence}, + pages={3622--3628}, + year={2020} +}`, + url: "https://arxiv.org/abs/2007.08124", + }, +]; + +const CitationBlock = ({ citation, title, authors, url, type }) => { + const handleCopy = () => { + navigator.clipboard.writeText(citation); + }; + + return ( + + + + {title} + + + {authors} + + {url && ( + + View paper + + )} + + + + + + + + + {citation} + + + + ); +}; + +function QuotePage() { + return ( + + + + + + The citations below include both the EEG Finetune Arena and the + individual benchmarks used in our evaluation suite. + + + + + + EEG Finetune Arena + + + {citations.map((citation, index) => ( + + ))} + + + + + + Benchmarks + + + {benchmarks.map((benchmark, index) => ( + + ))} + + + + ); +} + +export default QuotePage; diff --git a/frontend/src/pages/VoteModelPage/VoteModelPage.js b/frontend/src/pages/VoteModelPage/VoteModelPage.js new file mode 100644 index 0000000000000000000000000000000000000000..cbb0d14c55e8b0521bdea1c22f2af5b4f1e5667c --- /dev/null +++ b/frontend/src/pages/VoteModelPage/VoteModelPage.js @@ -0,0 +1,896 @@ +import React, { useState, useEffect } from "react"; +import { + Box, + Typography, + Paper, + Button, + Alert, + List, + ListItem, + CircularProgress, + Chip, + Divider, + IconButton, + Stack, + Link, + useTheme, + useMediaQuery, +} from "@mui/material"; +import AccessTimeIcon from "@mui/icons-material/AccessTime"; +import PersonIcon from "@mui/icons-material/Person"; +import OpenInNewIcon from "@mui/icons-material/OpenInNew"; +import HowToVoteIcon from 
"@mui/icons-material/HowToVote"; +import { useAuth } from "../../hooks/useAuth"; +import PageHeader from "../../components/shared/PageHeader"; +import AuthContainer from "../../components/shared/AuthContainer"; +import { alpha } from "@mui/material/styles"; +import CheckIcon from "@mui/icons-material/Check"; + +const NoModelsToVote = () => ( + + + + No Models to Vote + + + There are currently no models waiting for votes. +
    + Check back later! +
    +
    +); + +const LOCAL_STORAGE_KEY = "pending_votes"; + +function VoteModelPage() { + const { isAuthenticated, user, loading: authLoading } = useAuth(); + const [pendingModels, setPendingModels] = useState([]); + const [loadingModels, setLoadingModels] = useState(true); + const [error, setError] = useState(null); + const [userVotes, setUserVotes] = useState(new Set()); + const [loadingVotes, setLoadingVotes] = useState({}); + const [localVotes, setLocalVotes] = useState(new Set()); + const theme = useTheme(); + const isMobile = useMediaQuery(theme.breakpoints.down("sm")); + + // Create a unique identifier for a model + const getModelUniqueId = (model) => { + return `${model.name}_${model.precision}_${model.revision}`; + }; + + const formatWaitTime = (submissionTime) => { + if (!submissionTime) return "N/A"; + + const now = new Date(); + const submitted = new Date(submissionTime); + const diffInHours = Math.floor((now - submitted) / (1000 * 60 * 60)); + + // Less than 24 hours: show in hours + if (diffInHours < 24) { + return `${diffInHours}h`; + } + + // Less than 7 days: show in days + const diffInDays = Math.floor(diffInHours / 24); + if (diffInDays < 7) { + return `${diffInDays}d`; + } + + // More than 7 days: show in weeks + const diffInWeeks = Math.floor(diffInDays / 7); + return `${diffInWeeks}w`; + }; + + const getConfigVotes = (votesData, model) => { + // Créer l'identifiant unique du modèle + const modelUniqueId = getModelUniqueId(model); + + // Compter les votes du serveur + let serverVotes = 0; + for (const [key, config] of Object.entries(votesData.votes_by_config)) { + if ( + config.precision === model.precision && + config.revision === model.revision + ) { + serverVotes = config.count; + break; + } + } + + // Ajouter les votes en attente du localStorage + const pendingVote = localVotes.has(modelUniqueId) ? 
1 : 0; + + return serverVotes + pendingVote; + }; + + const sortModels = (models) => { + // Trier d'abord par nombre de votes décroissant, puis par soumission de l'utilisateur + return [...models].sort((a, b) => { + // Comparer d'abord le nombre de votes + if (b.votes !== a.votes) { + return b.votes - a.votes; + } + + // Si l'utilisateur est connecté, mettre ses modèles en priorité + if (user) { + const aIsUserModel = a.submitter === user.username; + const bIsUserModel = b.submitter === user.username; + + if (aIsUserModel && !bIsUserModel) return -1; + if (!aIsUserModel && bIsUserModel) return 1; + } + + // Si égalité, trier par date de soumission (le plus récent d'abord) + return new Date(b.submission_time) - new Date(a.submission_time); + }); + }; + + // Add this function to handle localStorage + const updateLocalVotes = (modelUniqueId, action = "add") => { + const storedVotes = JSON.parse( + localStorage.getItem(LOCAL_STORAGE_KEY) || "[]" + ); + if (action === "add") { + if (!storedVotes.includes(modelUniqueId)) { + storedVotes.push(modelUniqueId); + } + } else { + const index = storedVotes.indexOf(modelUniqueId); + if (index > -1) { + storedVotes.splice(index, 1); + } + } + localStorage.setItem(LOCAL_STORAGE_KEY, JSON.stringify(storedVotes)); + setLocalVotes(new Set(storedVotes)); + }; + + useEffect(() => { + const fetchData = async () => { + try { + // Ne pas afficher le loading si on a déjà des données + if (pendingModels.length === 0) { + setLoadingModels(true); + } + setError(null); + + // Charger d'abord les votes en attente du localStorage + const storedVotes = JSON.parse( + localStorage.getItem(LOCAL_STORAGE_KEY) || "[]" + ); + const localVotesSet = new Set(storedVotes); + + // Préparer toutes les requêtes en parallèle + const [pendingModelsResponse, userVotesResponse] = await Promise.all([ + fetch("/api/models/pending"), + isAuthenticated && user + ? 
fetch(`/api/votes/user/${user.username}`) + : Promise.resolve(null), + ]); + + if (!pendingModelsResponse.ok) { + throw new Error("Failed to fetch pending models"); + } + + const modelsData = await pendingModelsResponse.json(); + const votedModels = new Set(); + + // Traiter les votes de l'utilisateur si connecté + if (userVotesResponse && userVotesResponse.ok) { + const votesData = await userVotesResponse.json(); + const userVotes = Array.isArray(votesData) ? votesData : []; + + userVotes.forEach((vote) => { + const uniqueId = `${vote.model}_${vote.precision || "unknown"}_${ + vote.revision || "main" + }`; + votedModels.add(uniqueId); + if (localVotesSet.has(uniqueId)) { + localVotesSet.delete(uniqueId); + updateLocalVotes(uniqueId, "remove"); + } + }); + } + + // Préparer et exécuter toutes les requêtes de votes en une seule fois + const modelVotesResponses = await Promise.all( + modelsData.map((model) => { + const [provider, modelName] = model.name.split("/"); + return fetch(`/api/votes/model/${provider}/${modelName}`) + .then((response) => + response.ok + ? 
response.json() + : { total_votes: 0, votes_by_config: {} } + ) + .catch(() => ({ total_votes: 0, votes_by_config: {} })); + }) + ); + + // Construire les modèles avec toutes les données + const modelsWithVotes = modelsData.map((model, index) => { + const votesData = modelVotesResponses[index]; + const modelUniqueId = getModelUniqueId(model); + const isVotedByUser = + votedModels.has(modelUniqueId) || localVotesSet.has(modelUniqueId); + + return { + ...model, + votes: getConfigVotes( + { + ...votesData, + votes_by_config: votesData.votes_by_config || {}, + }, + model + ), + votes_by_config: votesData.votes_by_config || {}, + wait_time: formatWaitTime(model.submission_time), + hasVoted: isVotedByUser, + }; + }); + + // Mettre à jour tous les états en une seule fois + const sortedModels = sortModels(modelsWithVotes); + + // Batch updates + const updates = () => { + setPendingModels(sortedModels); + setUserVotes(votedModels); + setLocalVotes(localVotesSet); + setLoadingModels(false); + }; + + updates(); + } catch (err) { + console.error("Error fetching data:", err); + setError(err.message); + setLoadingModels(false); + } + }; + + fetchData(); + }, [isAuthenticated, user]); + + // Modify the handleVote function + const handleVote = async (model) => { + if (!isAuthenticated) return; + + const modelUniqueId = getModelUniqueId(model); + + try { + setError(null); + setLoadingVotes((prev) => ({ ...prev, [modelUniqueId]: true })); + + // Add to localStorage immediately + updateLocalVotes(modelUniqueId, "add"); + + // Encode model name for URL + const encodedModelName = encodeURIComponent(model.name); + + const response = await fetch( + `/api/votes/${encodedModelName}?vote_type=up&user_id=${user.username}`, + { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + precision: model.precision, + revision: model.revision, + }), + } + ); + + if (!response.ok) { + // If the request fails, remove from localStorage + 
updateLocalVotes(modelUniqueId, "remove"); + throw new Error("Failed to submit vote"); + } + + // Refresh votes for this model with cache bypass + const [provider, modelName] = model.name.split("/"); + const timestamp = Date.now(); + const votesResponse = await fetch( + `/api/votes/model/${provider}/${modelName}?nocache=${timestamp}` + ); + + if (!votesResponse.ok) { + throw new Error("Failed to fetch updated votes"); + } + + const votesData = await votesResponse.json(); + console.log(`Updated votes for ${model.name}:`, votesData); // Debug log + + // Update model and resort the list + setPendingModels((models) => { + const updatedModels = models.map((m) => + getModelUniqueId(m) === getModelUniqueId(model) + ? { + ...m, + votes: getConfigVotes(votesData, m), + votes_by_config: votesData.votes_by_config || {}, + hasVoted: true, + } + : m + ); + const sortedModels = sortModels(updatedModels); + console.log("Updated and sorted models:", sortedModels); // Debug log + return sortedModels; + }); + + // Update user votes with unique ID + setUserVotes((prev) => new Set([...prev, getModelUniqueId(model)])); + } catch (err) { + console.error("Error voting:", err); + setError(err.message); + } finally { + // Clear loading state for this model + setLoadingVotes((prev) => ({ + ...prev, + [modelUniqueId]: false, + })); + } + }; + + // Modify the rendering logic to consider both server and local votes + // Inside the map function where you render models + const isVoted = (model) => { + const modelUniqueId = getModelUniqueId(model); + return userVotes.has(modelUniqueId) || localVotes.has(modelUniqueId); + }; + + if (authLoading || (loadingModels && pendingModels.length === 0)) { + return ( + + + + ); + } + + return ( + + + Help us prioritize which + models to evaluate next + + } + /> + + {error && ( + + {error} + + )} + + {/* Auth Status */} + {/* + {isAuthenticated ? 
( + + + + + Connected as {user?.username} + + + + + + + ) : ( + + + Login to Vote + + + You need to be logged in with your Hugging Face account to vote + for models + + + + )} + */} + + + {/* Models List */} + + {/* Header - Always visible */} + + theme.palette.mode === "dark" + ? alpha(theme.palette.divider, 0.1) + : "grey.200", + bgcolor: (theme) => + theme.palette.mode === "dark" + ? alpha(theme.palette.background.paper, 0.5) + : "grey.50", + }} + > + + Models Pending Evaluation + + + + {/* Table Header */} + + + + Model + + + + + Votes + + + + + Priority + + + + + {/* Content */} + {loadingModels ? ( + + + + ) : pendingModels.length === 0 && !loadingModels ? ( + + ) : ( + + {pendingModels.map((model, index) => { + const isTopThree = index < 3; + return ( + + {index > 0 && } + + {/* Left side - Model info */} + + + {/* Model name and link */} + + + + {model.name} + + + + + + + + + + + {/* Metadata row */} + + + + + {model.wait_time} + + + + + + {model.submitter} + + + + + + + {/* Vote Column */} + + + + + + + + + + {model.votes > 999 ? "999" : model.votes} + + + + votes + + + + + + + {/* Priority Column */} + + + {isTopThree && ( + + HIGH + + )} + + #{index + 1} + + + } + size="medium" + variant={isTopThree ? "filled" : "outlined"} + sx={{ + height: 36, + minWidth: "100px", + bgcolor: isTopThree + ? (theme) => alpha(theme.palette.primary.main, 0.1) + : "transparent", + borderColor: isTopThree ? "primary.main" : "grey.300", + borderWidth: 2, + "& .MuiChip-label": { + px: 2, + fontSize: "0.95rem", + }, + }} + /> + + + + ); + })} + + )} + + + ); +} + +export default VoteModelPage; diff --git a/publications/README.md b/publications/README.md deleted file mode 100644 index 512c1df4f0a9877a49da0271436b39e23ab4958d..0000000000000000000000000000000000000000 --- a/publications/README.md +++ /dev/null @@ -1,10 +0,0 @@ -# Publications - -This directory contains benchmark publications that were created using the `EEG-finetune-arena`. 
- -To cite these articles, please use the following references: -```bib -@inproceedings{ - title = {Placeholder} -} -``` \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml deleted file mode 100644 index 3b4737924b5a7d81c962a4e28b66ac6cdcc3b004..0000000000000000000000000000000000000000 --- a/pyproject.toml +++ /dev/null @@ -1,13 +0,0 @@ -[tool.ruff] -# Enable pycodestyle (`E`) and Pyflakes (`F`) codes by default. -select = ["E", "F"] -ignore = ["E501"] # line too long (black is taking care of this) -line-length = 119 -fixable = ["A", "B", "C", "D", "E", "F", "G", "I", "N", "Q", "S", "T", "W", "ANN", "ARG", "BLE", "COM", "DJ", "DTZ", "EM", "ERA", "EXE", "FBT", "ICN", "INP", "ISC", "NPY", "PD", "PGH", "PIE", "PL", "PT", "PTH", "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TCH", "TID", "TRY", "UP", "YTT"] - -[tool.isort] -profile = "black" -line_length = 119 - -[tool.black] -line-length = 119 diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 3cacab3e9afab55f2ce3493ac25d7a0ea5c96255..0000000000000000000000000000000000000000 --- a/requirements.txt +++ /dev/null @@ -1,16 +0,0 @@ -APScheduler -black -datasets -gradio -gradio[oauth] -gradio_leaderboard==0.0.13 -gradio_client -huggingface-hub>=0.18.0 -matplotlib -numpy -pandas -python-dateutil -tqdm -transformers -tokenizers>=0.15.0 -sentencepiece \ No newline at end of file diff --git a/src/about.py b/src/about.py deleted file mode 100644 index 09fd449bbb70f55af0188a6c94af5ad0ff625f94..0000000000000000000000000000000000000000 --- a/src/about.py +++ /dev/null @@ -1,72 +0,0 @@ -from dataclasses import dataclass -from enum import Enum - -@dataclass -class Task: - benchmark: str - metric: str - col_name: str - - -# Select your tasks here -# --------------------------------------------------- -class Tasks(Enum): - # task_key in the json file, metric_key in the json file, name to display in the leaderboard - task0 = Task("anli_r1", "acc", "ANLI") - task1 = Task("logiqa", 
"acc_norm", "LogiQA") - -NUM_FEWSHOT = 0 # Change with your few shot -# --------------------------------------------------- - - - -# Your leaderboard name -TITLE = """

    Demo leaderboard

    """ - -# What does your leaderboard evaluate? -INTRODUCTION_TEXT = """ -Intro text -""" - -# Which evaluations are you running? how can people reproduce what you have? -LLM_BENCHMARKS_TEXT = f""" -## How it works - -## Reproducibility -To reproduce our results, here is the commands you can run: - -""" - -EVALUATION_QUEUE_TEXT = """ -## Some good practices before submitting a model - -### 1) Make sure you can load your model and tokenizer using AutoClasses: -```python -from transformers import AutoConfig, AutoModel, AutoTokenizer -config = AutoConfig.from_pretrained("your model name", revision=revision) -model = AutoModel.from_pretrained("your model name", revision=revision) -tokenizer = AutoTokenizer.from_pretrained("your model name", revision=revision) -``` -If this step fails, follow the error messages to debug your model before submitting it. It's likely your model has been improperly uploaded. - -Note: make sure your model is public! -Note: if your model needs `use_remote_code=True`, we do not support this option yet but we are working on adding it, stay posted! - -### 2) Convert your model weights to [safetensors](https://huggingface.co/docs/safetensors/index) -It's a new format for storing weights which is safer and faster to load and use. It will also allow us to add the number of parameters of your model to the `Extended Viewer`! - -### 3) Make sure your model has an open license! -This is a leaderboard for Open LLMs, and we'd love for as many people as possible to know they can use your model 🤗 - -### 4) Fill up your model card -When we add extra information about models to the leaderboard, it will be automatically taken from the model card - -## In case of model failure -If your model is displayed in the `FAILED` category, its execution stopped. -Make sure you have followed the above steps first. 
-If everything is done, check you can launch the EleutherAIHarness on your model locally, using the above command without modifications (you can add `--limit` to limit the number of examples per task). -""" - -CITATION_BUTTON_LABEL = "Copy the following snippet to cite these results" -CITATION_BUTTON_TEXT = r""" -""" diff --git a/src/display/css_html_js.py b/src/display/css_html_js.py deleted file mode 100644 index 951a1fa3664c0006ecc0d164cf3f0362087cf71f..0000000000000000000000000000000000000000 --- a/src/display/css_html_js.py +++ /dev/null @@ -1,105 +0,0 @@ -custom_css = """ - -.markdown-text { - font-size: 16px !important; -} - -#models-to-add-text { - font-size: 18px !important; -} - -#citation-button span { - font-size: 16px !important; -} - -#citation-button textarea { - font-size: 16px !important; -} - -#citation-button > label > button { - margin: 6px; - transform: scale(1.3); -} - -#leaderboard-table { - margin-top: 15px -} - -#leaderboard-table-lite { - margin-top: 15px -} - -#search-bar-table-box > div:first-child { - background: none; - border: none; -} - -#search-bar { - padding: 0px; -} - -/* Limit the width of the first AutoEvalColumn so that names don't expand too much */ -#leaderboard-table td:nth-child(2), -#leaderboard-table th:nth-child(2) { - max-width: 400px; - overflow: auto; - white-space: nowrap; -} - -.tab-buttons button { - font-size: 20px; -} - -#scale-logo { - border-style: none !important; - box-shadow: none; - display: block; - margin-left: auto; - margin-right: auto; - max-width: 600px; -} - -#scale-logo .download { - display: none; -} -#filter_type{ - border: 0; - padding-left: 0; - padding-top: 0; -} -#filter_type label { - display: flex; -} -#filter_type label > span{ - margin-top: var(--spacing-lg); - margin-right: 0.5em; -} -#filter_type label > .wrap{ - width: 103px; -} -#filter_type label > .wrap .wrap-inner{ - padding: 2px; -} -#filter_type label > .wrap .wrap-inner input{ - width: 1px -} -#filter-columns-type{ - border:0; - 
padding:0.5; -} -#filter-columns-size{ - border:0; - padding:0.5; -} -#box-filter > .form{ - border: 0 -} -""" - -get_window_url_params = """ - function(url_params) { - const params = new URLSearchParams(window.location.search); - url_params = Object.fromEntries(params); - return url_params; - } - """ diff --git a/src/display/formatting.py b/src/display/formatting.py deleted file mode 100644 index b46d29c9dba71be80866bfe46c5a77acd0dc50ce..0000000000000000000000000000000000000000 --- a/src/display/formatting.py +++ /dev/null @@ -1,27 +0,0 @@ -def model_hyperlink(link, model_name): - return f'{model_name}' - - -def make_clickable_model(model_name): - link = f"https://huggingface.co/{model_name}" - return model_hyperlink(link, model_name) - - -def styled_error(error): - return f"

    {error}

    " - - -def styled_warning(warn): - return f"

    {warn}

    " - - -def styled_message(message): - return f"

    {message}

    " - - -def has_no_nan_values(df, columns): - return df[columns].notna().all(axis=1) - - -def has_nan_values(df, columns): - return df[columns].isna().any(axis=1) diff --git a/src/display/utils.py b/src/display/utils.py deleted file mode 100644 index 93df13e57a63ee679f863260185b34a43e4f040d..0000000000000000000000000000000000000000 --- a/src/display/utils.py +++ /dev/null @@ -1,110 +0,0 @@ -from dataclasses import dataclass, make_dataclass -from enum import Enum - -import pandas as pd - -from src.about import Tasks - -def fields(raw_class): - return [v for k, v in raw_class.__dict__.items() if k[:2] != "__" and k[-2:] != "__"] - - -# These classes are for user facing column names, -# to avoid having to change them all around the code -# when a modif is needed -@dataclass -class ColumnContent: - name: str - type: str - displayed_by_default: bool - hidden: bool = False - never_hidden: bool = False - -## Leaderboard columns -auto_eval_column_dict = [] -# Init -auto_eval_column_dict.append(["model_type_symbol", ColumnContent, ColumnContent("T", "str", True, never_hidden=True)]) -auto_eval_column_dict.append(["model", ColumnContent, ColumnContent("Model", "markdown", True, never_hidden=True)]) -#Scores -auto_eval_column_dict.append(["average", ColumnContent, ColumnContent("Average ⬆️", "number", True)]) -for task in Tasks: - auto_eval_column_dict.append([task.name, ColumnContent, ColumnContent(task.value.col_name, "number", True)]) -# Model information -auto_eval_column_dict.append(["model_type", ColumnContent, ColumnContent("Type", "str", False)]) -auto_eval_column_dict.append(["architecture", ColumnContent, ColumnContent("Architecture", "str", False)]) -auto_eval_column_dict.append(["weight_type", ColumnContent, ColumnContent("Weight type", "str", False, True)]) -auto_eval_column_dict.append(["precision", ColumnContent, ColumnContent("Precision", "str", False)]) -auto_eval_column_dict.append(["license", ColumnContent, ColumnContent("Hub License", "str", False)]) 
-auto_eval_column_dict.append(["params", ColumnContent, ColumnContent("#Params (B)", "number", False)]) -auto_eval_column_dict.append(["likes", ColumnContent, ColumnContent("Hub ❤️", "number", False)]) -auto_eval_column_dict.append(["still_on_hub", ColumnContent, ColumnContent("Available on the hub", "bool", False)]) -auto_eval_column_dict.append(["revision", ColumnContent, ColumnContent("Model sha", "str", False, False)]) - -# We use make dataclass to dynamically fill the scores from Tasks -AutoEvalColumn = make_dataclass("AutoEvalColumn", auto_eval_column_dict, frozen=True) - -## For the queue columns in the submission tab -@dataclass(frozen=True) -class EvalQueueColumn: # Queue column - model = ColumnContent("model", "markdown", True) - revision = ColumnContent("revision", "str", True) - private = ColumnContent("private", "bool", True) - precision = ColumnContent("precision", "str", True) - weight_type = ColumnContent("weight_type", "str", "Original") - status = ColumnContent("status", "str", True) - -## All the model information that we might need -@dataclass -class ModelDetails: - name: str - display_name: str = "" - symbol: str = "" # emoji - - -class ModelType(Enum): - PT = ModelDetails(name="pretrained", symbol="🟢") - FT = ModelDetails(name="fine-tuned", symbol="🔶") - IFT = ModelDetails(name="instruction-tuned", symbol="⭕") - RL = ModelDetails(name="RL-tuned", symbol="🟦") - Unknown = ModelDetails(name="", symbol="?") - - def to_str(self, separator=" "): - return f"{self.value.symbol}{separator}{self.value.name}" - - @staticmethod - def from_str(type): - if "fine-tuned" in type or "🔶" in type: - return ModelType.FT - if "pretrained" in type or "🟢" in type: - return ModelType.PT - if "RL-tuned" in type or "🟦" in type: - return ModelType.RL - if "instruction-tuned" in type or "⭕" in type: - return ModelType.IFT - return ModelType.Unknown - -class WeightType(Enum): - Adapter = ModelDetails("Adapter") - Original = ModelDetails("Original") - Delta = 
ModelDetails("Delta") - -class Precision(Enum): - float16 = ModelDetails("float16") - bfloat16 = ModelDetails("bfloat16") - Unknown = ModelDetails("?") - - def from_str(precision): - if precision in ["torch.float16", "float16"]: - return Precision.float16 - if precision in ["torch.bfloat16", "bfloat16"]: - return Precision.bfloat16 - return Precision.Unknown - -# Column selection -COLS = [c.name for c in fields(AutoEvalColumn) if not c.hidden] - -EVAL_COLS = [c.name for c in fields(EvalQueueColumn)] -EVAL_TYPES = [c.type for c in fields(EvalQueueColumn)] - -BENCHMARK_COLS = [t.value.col_name for t in Tasks] - diff --git a/src/envs.py b/src/envs.py deleted file mode 100644 index d761858069abf7ff590445e4770c4c3ce08b9222..0000000000000000000000000000000000000000 --- a/src/envs.py +++ /dev/null @@ -1,25 +0,0 @@ -import os - -from huggingface_hub import HfApi - -# Info to change for your repository -# ---------------------------------- -TOKEN = os.environ.get("HF_TOKEN") # A read/write token for your org - -OWNER = "demo-leaderboard-backend" # Change to your org - don't forget to create a results and request dataset, with the correct format! 
-# ---------------------------------- - -REPO_ID = f"{OWNER}/leaderboard" -QUEUE_REPO = f"{OWNER}/requests" -RESULTS_REPO = f"{OWNER}/results" - -# If you setup a cache later, just change HF_HOME -CACHE_PATH=os.getenv("HF_HOME", ".") - -# Local caches -EVAL_REQUESTS_PATH = os.path.join(CACHE_PATH, "eval-queue") -EVAL_RESULTS_PATH = os.path.join(CACHE_PATH, "eval-results") -EVAL_REQUESTS_PATH_BACKEND = os.path.join(CACHE_PATH, "eval-queue-bk") -EVAL_RESULTS_PATH_BACKEND = os.path.join(CACHE_PATH, "eval-results-bk") - -API = HfApi(token=TOKEN) diff --git a/src/leaderboard/read_evals.py b/src/leaderboard/read_evals.py deleted file mode 100644 index f90129f4768433ff52ba083f53cc501128a00430..0000000000000000000000000000000000000000 --- a/src/leaderboard/read_evals.py +++ /dev/null @@ -1,196 +0,0 @@ -import glob -import json -import math -import os -from dataclasses import dataclass - -import dateutil -import numpy as np - -from src.display.formatting import make_clickable_model -from src.display.utils import AutoEvalColumn, ModelType, Tasks, Precision, WeightType -from src.submission.check_validity import is_model_on_hub - - -@dataclass -class EvalResult: - """Represents one full evaluation. Built from a combination of the result and request file for a given run. - """ - eval_name: str # org_model_precision (uid) - full_model: str # org/model (path on hub) - org: str - model: str - revision: str # commit hash, "" if main - results: dict - precision: Precision = Precision.Unknown - model_type: ModelType = ModelType.Unknown # Pretrained, fine tuned, ... - weight_type: WeightType = WeightType.Original # Original or Adapter - architecture: str = "Unknown" - license: str = "?" 
# --- src/leaderboard/read_evals.py ---
import glob
import json
import math
import os
from dataclasses import dataclass

import dateutil
import numpy as np

from src.display.formatting import make_clickable_model
from src.display.utils import AutoEvalColumn, ModelType, Tasks, Precision, WeightType
from src.submission.check_validity import is_model_on_hub


@dataclass
class EvalResult:
    """Represents one full evaluation. Built from a combination of the result and request file for a given run."""

    eval_name: str  # org_model_precision (uid)
    full_model: str  # org/model (path on hub)
    org: str
    model: str
    revision: str  # commit hash, "" if main
    results: dict
    precision: Precision = Precision.Unknown
    model_type: ModelType = ModelType.Unknown  # Pretrained, fine tuned, ...
    weight_type: WeightType = WeightType.Original  # Original or Adapter
    architecture: str = "Unknown"
    license: str = "?"
    likes: int = 0
    num_params: int = 0
    date: str = ""  # submission date of request file
    still_on_hub: bool = False

    @classmethod
    def init_from_json_file(cls, json_filepath):
        """Inits the result from the specific model result file.

        Parses the `config` section for model identity/precision, probes the Hub
        for liveness/architecture, and averages each task metric found in `results`.
        """
        with open(json_filepath) as fp:
            data = json.load(fp)

        config = data.get("config")

        # Precision
        precision = Precision.from_str(config.get("model_dtype"))

        # Get model and org ("org/model" or bare "model")
        org_and_model = config.get("model_name", config.get("model_args", None))
        org_and_model = org_and_model.split("/", 1)

        if len(org_and_model) == 1:
            org = None
            model = org_and_model[0]
            result_key = f"{model}_{precision.value.name}"
        else:
            org = org_and_model[0]
            model = org_and_model[1]
            result_key = f"{org}_{model}_{precision.value.name}"
        full_model = "/".join(org_and_model)

        still_on_hub, _, model_config = is_model_on_hub(
            full_model, config.get("model_sha", "main"), trust_remote_code=True, test_tokenizer=False
        )
        architecture = "?"
        if model_config is not None:
            architectures = getattr(model_config, "architectures", None)
            if architectures:
                architecture = ";".join(architectures)

        # Extract results available in this file (some results are split in several files)
        results = {}
        for task in Tasks:
            task = task.value

            # We average all scores of a given metric (not all metrics are present in all files)
            accs = np.array([v.get(task.metric, None) for k, v in data["results"].items() if task.benchmark == k])
            if accs.size == 0 or any(acc is None for acc in accs):
                continue

            mean_acc = np.mean(accs) * 100.0
            results[task.benchmark] = mean_acc

        return cls(
            eval_name=result_key,
            full_model=full_model,
            org=org,
            model=model,
            results=results,
            precision=precision,
            revision=config.get("model_sha", ""),
            still_on_hub=still_on_hub,
            architecture=architecture,
        )

    def update_with_request_file(self, requests_path):
        """Finds the relevant request file for the current model and updates info with it."""
        request_file = get_request_file_for_model(requests_path, self.full_model, self.precision.value.name)

        try:
            with open(request_file, "r") as f:
                request = json.load(f)
            self.model_type = ModelType.from_str(request.get("model_type", ""))
            self.weight_type = WeightType[request.get("weight_type", "Original")]
            self.license = request.get("license", "?")
            self.likes = request.get("likes", 0)
            self.num_params = request.get("params", 0)
            self.date = request.get("submitted_time", "")
        except Exception:
            # Best-effort: a missing/unreadable request file leaves the defaults in place.
            print(f"Could not find request file for {self.org}/{self.model} with precision {self.precision.value.name}")

    def to_dict(self):
        """Converts the Eval Result to a dict compatible with our dataframe display.

        Raises KeyError when a benchmark result is missing — callers use that
        to filter out incomplete evals (see get_raw_eval_results).
        """
        # Average over ALL tasks, so missing scores drag the average down.
        average = sum([v for v in self.results.values() if v is not None]) / len(Tasks)
        data_dict = {
            "eval_name": self.eval_name,  # not a column, just a save name,
            AutoEvalColumn.precision.name: self.precision.value.name,
            AutoEvalColumn.model_type.name: self.model_type.value.name,
            AutoEvalColumn.model_type_symbol.name: self.model_type.value.symbol,
            AutoEvalColumn.weight_type.name: self.weight_type.value.name,
            AutoEvalColumn.architecture.name: self.architecture,
            AutoEvalColumn.model.name: make_clickable_model(self.full_model),
            AutoEvalColumn.revision.name: self.revision,
            AutoEvalColumn.average.name: average,
            AutoEvalColumn.license.name: self.license,
            AutoEvalColumn.likes.name: self.likes,
            AutoEvalColumn.params.name: self.num_params,
            AutoEvalColumn.still_on_hub.name: self.still_on_hub,
        }

        for task in Tasks:
            data_dict[task.value.col_name] = self.results[task.value.benchmark]

        return data_dict
Only keeps runs tagged as FINISHED""" - request_files = os.path.join( - requests_path, - f"{model_name}_eval_request_*.json", - ) - request_files = glob.glob(request_files) - - # Select correct request file (precision) - request_file = "" - request_files = sorted(request_files, reverse=True) - for tmp_request_file in request_files: - with open(tmp_request_file, "r") as f: - req_content = json.load(f) - if ( - req_content["status"] in ["FINISHED"] - and req_content["precision"] == precision.split(".")[-1] - ): - request_file = tmp_request_file - return request_file - - -def get_raw_eval_results(results_path: str, requests_path: str) -> list[EvalResult]: - """From the path of the results folder root, extract all needed info for results""" - model_result_filepaths = [] - - for root, _, files in os.walk(results_path): - # We should only have json files in model results - if len(files) == 0 or any([not f.endswith(".json") for f in files]): - continue - - # Sort the files by date - try: - files.sort(key=lambda x: x.removesuffix(".json").removeprefix("results_")[:-7]) - except dateutil.parser._parser.ParserError: - files = [files[-1]] - - for file in files: - model_result_filepaths.append(os.path.join(root, file)) - - eval_results = {} - for model_result_filepath in model_result_filepaths: - # Creation of result - eval_result = EvalResult.init_from_json_file(model_result_filepath) - eval_result.update_with_request_file(requests_path) - - # Store results of same eval together - eval_name = eval_result.eval_name - if eval_name in eval_results.keys(): - eval_results[eval_name].results.update({k: v for k, v in eval_result.results.items() if v is not None}) - else: - eval_results[eval_name] = eval_result - - results = [] - for v in eval_results.values(): - try: - v.to_dict() # we test if the dict version is complete - results.append(v) - except KeyError: # not all eval values present - continue - - return results diff --git a/src/populate.py b/src/populate.py deleted file mode 
# --- src/populate.py ---
import json
import os

import pandas as pd

from src.display.formatting import has_no_nan_values, make_clickable_model
from src.display.utils import AutoEvalColumn, EvalQueueColumn
from src.leaderboard.read_evals import get_raw_eval_results


def get_leaderboard_df(results_path: str, requests_path: str, cols: list, benchmark_cols: list) -> pd.DataFrame:
    """Creates a dataframe from all the individual experiment results."""
    raw_data = get_raw_eval_results(results_path, requests_path)
    all_data_json = [v.to_dict() for v in raw_data]

    df = pd.DataFrame.from_records(all_data_json)
    df = df.sort_values(by=[AutoEvalColumn.average.name], ascending=False)
    df = df[cols].round(decimals=2)

    # filter out if any of the benchmarks have not been produced
    df = df[has_no_nan_values(df, benchmark_cols)]
    return df


def get_evaluation_queue_df(save_path: str, cols: list) -> list[pd.DataFrame]:
    """Creates the different dataframes for the evaluation queue requests.

    Returns (finished, running, pending) dataframes built from the request
    json files found at the top level of save_path and one folder deep.
    """
    entries = [entry for entry in os.listdir(save_path) if not entry.startswith(".")]
    all_evals = []

    for entry in entries:
        if ".json" in entry:
            file_path = os.path.join(save_path, entry)
            with open(file_path) as fp:
                data = json.load(fp)

            data[EvalQueueColumn.model.name] = make_clickable_model(data["model"])
            data[EvalQueueColumn.revision.name] = data.get("revision", "main")

            all_evals.append(data)
        elif ".md" not in entry:
            # this is a folder
            # BUG FIX: check the FULL path — `os.path.isfile(e)` tested the bare
            # filename against the current working directory, so sub-entries were
            # silently skipped.
            sub_entries = [
                e
                for e in os.listdir(os.path.join(save_path, entry))
                if os.path.isfile(os.path.join(save_path, entry, e)) and not e.startswith(".")
            ]
            for sub_entry in sub_entries:
                file_path = os.path.join(save_path, entry, sub_entry)
                with open(file_path) as fp:
                    data = json.load(fp)

                data[EvalQueueColumn.model.name] = make_clickable_model(data["model"])
                data[EvalQueueColumn.revision.name] = data.get("revision", "main")
                all_evals.append(data)

    pending_list = [e for e in all_evals if e["status"] in ["PENDING", "RERUN"]]
    running_list = [e for e in all_evals if e["status"] == "RUNNING"]
    finished_list = [e for e in all_evals if e["status"].startswith("FINISHED") or e["status"] == "PENDING_NEW_EVAL"]
    df_pending = pd.DataFrame.from_records(pending_list, columns=cols)
    df_running = pd.DataFrame.from_records(running_list, columns=cols)
    df_finished = pd.DataFrame.from_records(finished_list, columns=cols)
    return df_finished[cols], df_running[cols], df_pending[cols]

# --- src/submission/check_validity.py ---
import json
import os
import re
from collections import defaultdict
from datetime import datetime, timedelta, timezone

import huggingface_hub
from huggingface_hub import ModelCard
from huggingface_hub.hf_api import ModelInfo
from transformers import AutoConfig
from transformers.models.auto.tokenization_auto import AutoTokenizer
def check_model_card(repo_id: str) -> tuple[bool, str]:
    """Checks if the model card and license exist and have been filled."""
    try:
        card = ModelCard.load(repo_id)
    except huggingface_hub.utils.EntryNotFoundError:
        return False, "Please add a model card to your model to explain how you trained/fine-tuned it."

    # Enforce license metadata
    if card.data.license is None:
        if not ("license_name" in card.data and "license_link" in card.data):
            return False, (
                "License not found. Please add a license to your model card using the `license` metadata or a"
                " `license_name`/`license_link` pair."
            )

    # Enforce card content
    if len(card.text) < 200:
        return False, "Please add a description to your model card, it is too short."

    return True, ""

def is_model_on_hub(model_name: str, revision: str, token: str = None, trust_remote_code=False, test_tokenizer=False) -> "tuple[bool, str | None, object]":
    """Checks if the model model_name is on the hub, and whether it (and its tokenizer) can be loaded with AutoClasses.

    Returns (ok, error_message_or_None, config_or_None) — the original annotation
    claimed a 2-tuple, but every return path yields three values.
    """
    try:
        config = AutoConfig.from_pretrained(model_name, revision=revision, trust_remote_code=trust_remote_code, token=token)
        if test_tokenizer:
            try:
                # Loaded only for the side effect of validating the tokenizer.
                AutoTokenizer.from_pretrained(model_name, revision=revision, trust_remote_code=trust_remote_code, token=token)
            except ValueError as e:
                return (
                    False,
                    f"uses a tokenizer which is not in a transformers release: {e}",
                    None
                )
            except Exception:
                return (False, "'s tokenizer cannot be loaded. Is your tokenizer class in a stable transformers release, and correctly configured?", None)
        return True, None, config

    except ValueError:
        return (
            False,
            "needs to be launched with `trust_remote_code=True`. For safety reason, we do not allow these models to be automatically submitted to the leaderboard.",
            None
        )

    except Exception:
        return False, "was not found on hub!", None


def get_model_size(model_info: "ModelInfo", precision: str):
    """Gets the model size from the configuration, or the model name if the configuration does not contain the information.

    Size is billions of parameters from the safetensors index; GPTQ-quantized
    models are scaled by 8. Unknown sizes are reported as 0.
    """
    try:
        model_size = round(model_info.safetensors["total"] / 1e9, 3)
    except (AttributeError, TypeError):
        return 0  # Unknown model sizes are indicated as 0, see NUMERIC_INTERVALS in app.py

    size_factor = 8 if (precision == "GPTQ" or "gptq" in model_info.modelId.lower()) else 1
    model_size = size_factor * model_size
    return model_size

def get_model_arch(model_info: "ModelInfo"):
    """Gets the model architecture from the configuration."""
    return model_info.config.get("architectures", "Unknown")
def already_submitted_models(requested_models_dir: str) -> "tuple[set[str], dict[str, list[str]]]":
    """Gather a list of already submitted models to avoid duplicates.

    Returns (file_names, users_to_submission_dates): a set of
    "model_revision_precision" keys, and a mapping org -> submission timestamps.
    (The original annotation claimed `set[str]` but a 2-tuple is returned.)
    """
    depth = 1  # request files live exactly one folder below the root (per-user dirs)
    file_names = []
    users_to_submission_dates = defaultdict(list)

    for root, _, files in os.walk(requested_models_dir):
        current_depth = root.count(os.sep) - requested_models_dir.count(os.sep)
        if current_depth == depth:
            for file in files:
                if not file.endswith(".json"):
                    continue
                with open(os.path.join(root, file), "r") as f:
                    info = json.load(f)
                    file_names.append(f"{info['model']}_{info['revision']}_{info['precision']}")

                    # Select organisation
                    if info["model"].count("/") == 0 or "submitted_time" not in info:
                        continue
                    # Take the first path segment only — avoids an unpack crash
                    # if a model id ever contains more than one slash.
                    organisation = info["model"].split("/")[0]
                    users_to_submission_dates[organisation].append(info["submitted_time"])

    return set(file_names), users_to_submission_dates
# --- src/submission/submit.py ---
import json
import os
from datetime import datetime, timezone

from src.display.formatting import styled_error, styled_message, styled_warning
from src.envs import API, EVAL_REQUESTS_PATH, TOKEN, QUEUE_REPO
from src.submission.check_validity import (
    already_submitted_models,
    check_model_card,
    get_model_size,
    is_model_on_hub,
)

# Lazily-populated caches of prior submissions (filled on first call).
REQUESTED_MODELS = None
USERS_TO_SUBMISSION_DATES = None

def add_new_eval(
    model: str,
    base_model: str,
    revision: str,
    precision: str,
    weight_type: str,
    model_type: str,
):
    """Validates a submission and pushes a PENDING eval request file to the queue repo.

    Returns a styled HTML message (error, warning, or success) for the UI.
    """
    global REQUESTED_MODELS
    global USERS_TO_SUBMISSION_DATES
    if not REQUESTED_MODELS:
        REQUESTED_MODELS, USERS_TO_SUBMISSION_DATES = already_submitted_models(EVAL_REQUESTS_PATH)

    user_name = ""
    model_path = model
    if "/" in model:
        user_name = model.split("/")[0]
        model_path = model.split("/")[1]

    # UI precision choices may carry a suffix (e.g. "float16 (recommended)").
    precision = precision.split(" ")[0]
    current_time = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")

    if model_type is None or model_type == "":
        return styled_error("Please select a model type.")

    # Does the model actually exist?
    if revision == "":
        revision = "main"

    # Is the model on the hub?
    if weight_type in ["Delta", "Adapter"]:
        base_model_on_hub, error, _ = is_model_on_hub(model_name=base_model, revision=revision, token=TOKEN, test_tokenizer=True)
        if not base_model_on_hub:
            return styled_error(f'Base model "{base_model}" {error}')

    if not weight_type == "Adapter":
        model_on_hub, error, _ = is_model_on_hub(model_name=model, revision=revision, token=TOKEN, test_tokenizer=True)
        if not model_on_hub:
            return styled_error(f'Model "{model}" {error}')

    # Is the model info correctly filled?
    try:
        model_info = API.model_info(repo_id=model, revision=revision)
    except Exception:
        return styled_error("Could not get your model information. Please fill it up properly.")

    model_size = get_model_size(model_info=model_info, precision=precision)

    # Were the model card and license filled?
    try:
        # renamed from `license` — don't shadow the builtin
        model_license = model_info.cardData["license"]
    except Exception:
        return styled_error("Please select a license for your model")

    modelcard_OK, error_msg = check_model_card(model)
    if not modelcard_OK:
        return styled_error(error_msg)

    # Seems good, creating the eval
    print("Adding new eval")

    eval_entry = {
        "model": model,
        "base_model": base_model,
        "revision": revision,
        "precision": precision,
        "weight_type": weight_type,
        "status": "PENDING",
        "submitted_time": current_time,
        "model_type": model_type,
        "likes": model_info.likes,
        "params": model_size,
        "license": model_license,
        "private": False,
    }

    # Check for duplicate submission
    if f"{model}_{revision}_{precision}" in REQUESTED_MODELS:
        return styled_warning("This model has been already submitted.")

    print("Creating eval file")
    OUT_DIR = f"{EVAL_REQUESTS_PATH}/{user_name}"
    os.makedirs(OUT_DIR, exist_ok=True)
    out_path = f"{OUT_DIR}/{model_path}_eval_request_False_{precision}_{weight_type}.json"

    with open(out_path, "w") as f:
        f.write(json.dumps(eval_entry))

    print("Uploading eval file")
    API.upload_file(
        path_or_fileobj=out_path,
        # Path relative to the local queue root; more robust than splitting
        # on the literal "eval-queue/" substring.
        path_in_repo=os.path.relpath(out_path, EVAL_REQUESTS_PATH),
        repo_id=QUEUE_REPO,
        repo_type="dataset",
        commit_message=f"Add {model} to eval queue",
    )

    # Remove the local file
    os.remove(out_path)

    return styled_message(
        "Your request has been submitted to the evaluation queue!\nPlease wait for up to an hour for the model to show in the PENDING list."
    )