| import io |
| import json |
| import os |
| import shlex |
| import subprocess |
| import threading |
| import uuid |
| import base64 |
| import glob |
| from typing import Optional, Dict, Any, List |
| from datetime import datetime, timezone, timedelta |
|
|
| import boto3 |
| import requests |
| import pandas as pd |
| import jwt |
| from botocore.exceptions import ClientError |
| from fastapi import Request |
| from huggingface_hub import HfApi, hf_hub_download |
| from loguru import logger |
| from cryptography.hazmat.primitives.ciphers.aead import AESGCM |
| from cachetools import cached, TTLCache |
|
|
| from competitions.enums import SubmissionStatus |
| from competitions.params import EvalParams |
|
|
| from . import HF_URL |
|
|
|
|
# Optional fallback Hugging Face token, tried by ``user_authentication`` when a
# request carries no bearer token.
USER_TOKEN = os.environ.get("USER_TOKEN")

# Default lifetime (seconds) of presigned multipart part-upload URLs: 1 hour.
IMAGE_MAX_EXPIRATION_SECONDS = 60 * 60
# Default lifetime (seconds) of a presigned single-object upload URL: 3 hours.
DIRECT_UPLOAD_URL_EXPIRATION_SECONDS = 3 * 60 * 60
# 30 minutes in seconds. Not referenced in this file section; presumably
# callers refresh an upload URL once less than this remains -- TODO confirm.
DIRECT_UPLOAD_URL_REFRESH_THRESHOLD_SECONDS = 30 * 60

# CloudWatch log group queried for AWS Batch job log streams.
AWS_BATCH_LOG_GROUP = os.environ.get("AWS_BATCH_LOG_GROUP", "/aws/batch/job")
# Number of most recent log events fetched for a failed Batch job.
AWS_BATCH_ERROR_LOG_LINES = 20
# Upper bound on the length of the joined CloudWatch log excerpt.
AWS_BATCH_ERROR_LOG_CHARS = 4000
# Fresh token -> user-info entries; a hit skips the Hub round trip (10 minutes).
TOKEN_INFO_CACHE = TTLCache(maxsize=1024, ttl=10 * 60)
# Last known-good user info per token (24 hours). It is only read by
# ``token_information`` when the Hub is temporarily unavailable.
TOKEN_INFO_STALE_CACHE = TTLCache(maxsize=1024, ttl=24 * 60 * 60)
# Held around every read and write of the two caches above.
TOKEN_INFO_CACHE_LOCK = threading.Lock()
|
|
|
|
class InvalidTokenError(Exception):
    """Raised when the Hugging Face Hub rejects a token; retrying will not help."""
|
|
|
|
class TemporaryTokenVerificationError(Exception):
    """Raised when a token cannot be verified right now (Hub unreachable, 429 or 5xx).

    The token itself may still be valid, so callers may retry or fall back to
    previously cached user information.
    """
|
|
|
|
class ImageUploadApi:
    """S3 helper for uploading and locating submission images.

    Every submission maps to a single object, keyed ``{team_id}__{submission_id}``,
    in the configured bucket. Uploads go through either one presigned PUT URL or
    an S3 multipart upload with presigned per-part URLs. ``get_s3_file_path`` is
    what ``BatchJobApi`` passes to evaluation jobs as ``DOCKER_IMAGE_URL``.
    """

    def __init__(self, bucket: Optional[str] = None, expires_in: int = IMAGE_MAX_EXPIRATION_SECONDS):
        # bucket: explicit bucket name. When None, S3_IMAGE_BUCKET is read from
        #   the environment on every call, so it may be set after construction.
        # expires_in: lifetime in seconds of presigned multipart part-upload URLs.
        self.bucket = bucket
        self.expires_in = expires_in
        self.s3_client = boto3.client("s3", region_name=os.environ.get("AWS_REGION", "us-east-1"))

    def _get_bucket(self) -> str:
        """Return the target bucket name; raises KeyError if none is configured."""
        return self.bucket or os.environ["S3_IMAGE_BUCKET"]

    def get_object_key(self, team_id: str, submission_id: str) -> str:
        """Return the S3 object key used for a submission."""
        return f"{team_id}__{submission_id}"

    def get_s3_file_path(self, team_id: str, submission_id: str) -> str:
        """Return the ``s3://bucket/key`` URI of a submission's object."""
        return f"s3://{self._get_bucket()}/{self.get_object_key(team_id, submission_id)}"

    def _build_expired_at(self, expires_in: int) -> str:
        """Return the UTC ISO-8601 timestamp ``expires_in`` seconds from now."""
        return (datetime.now(timezone.utc) + timedelta(seconds=expires_in)).isoformat()

    def create_presigned_s3_upload(
        self,
        team_id: str,
        submission_id: str,
        expires_in: int = DIRECT_UPLOAD_URL_EXPIRATION_SECONDS,
    ) -> Dict[str, str]:
        """Create a presigned single-request PUT URL for a submission.

        Returns:
            ``{"presigned_url": ..., "expired_at": ...}``. ``expired_at`` is the
            UTC ISO-8601 expiry, computed right after the URL is signed.
        """
        return {
            "presigned_url": self.s3_client.generate_presigned_url(
                "put_object",
                Params={
                    "Bucket": self._get_bucket(),
                    "Key": self.get_object_key(team_id, submission_id),
                },
                ExpiresIn=expires_in,
            ),
            "expired_at": self._build_expired_at(expires_in),
        }

    def create_multipart_upload(
        self,
        team_id: str,
        submission_id: str,
        content_type: Optional[str] = None,
    ) -> Dict[str, str]:
        """Start an S3 multipart upload for a submission.

        Returns:
            ``{"key", "uploadId", "bucket"}`` as reported by S3. The bucket falls
            back to the configured one when S3 omits it.
        """
        params: Dict[str, Any] = {
            "Bucket": self._get_bucket(),
            "Key": self.get_object_key(team_id, submission_id),
        }
        if content_type:
            params["ContentType"] = content_type

        response = self.s3_client.create_multipart_upload(**params)
        return {
            "key": response["Key"],
            "uploadId": response["UploadId"],
            "bucket": response.get("Bucket", self._get_bucket()),
        }

    def generate_presigned_s3_part_upload_url(
        self,
        team_id: str,
        submission_id: str,
        upload_id: str,
        part_number: int,
    ) -> str:
        """Return a presigned URL for uploading one part; valid for ``self.expires_in`` seconds."""
        return self.s3_client.generate_presigned_url(
            "upload_part",
            Params={
                "Bucket": self._get_bucket(),
                "Key": self.get_object_key(team_id, submission_id),
                "UploadId": upload_id,
                "PartNumber": part_number,
            },
            ExpiresIn=self.expires_in,
        )

    def list_multipart_upload_parts(
        self,
        team_id: str,
        submission_id: str,
        upload_id: str,
    ) -> List[Dict[str, Any]]:
        """Return every uploaded part of a multipart upload, following pagination.

        Raises:
            RuntimeError: S3 reported a truncated page without advancing the
                part-number marker. Re-requesting the same page would loop
                forever, and returning a partial list could let a caller
                complete the upload with missing parts.
        """
        bucket = self._get_bucket()
        key = self.get_object_key(team_id, submission_id)
        parts: List[Dict[str, Any]] = []
        part_number_marker = 0

        while True:
            response = self.s3_client.list_parts(
                Bucket=bucket,
                Key=key,
                UploadId=upload_id,
                PartNumberMarker=part_number_marker,
            )
            parts.extend(response.get("Parts", []))

            if not response.get("IsTruncated"):
                return parts

            next_marker = response.get("NextPartNumberMarker")
            if next_marker is None or next_marker == part_number_marker:
                raise RuntimeError(
                    f"list_parts pagination did not advance for upload {upload_id} "
                    f"(marker={part_number_marker})"
                )
            part_number_marker = next_marker

    def complete_multipart_upload(
        self,
        team_id: str,
        submission_id: str,
        upload_id: str,
        parts: List[Dict[str, Any]],
    ) -> str:
        """Complete a multipart upload and return the object's location.

        ``parts`` is sorted by ``PartNumber`` and forwarded to S3 otherwise
        unchanged. NOTE(review): S3 expects CompletedPart entries (PartNumber,
        ETag, optional checksums). Raw ``list_parts`` entries also carry keys
        such as ``Size``, so confirm callers strip them.

        Returns:
            S3's ``Location`` when present, else the ``s3://`` path.
        """
        sorted_parts = sorted(parts, key=lambda part: part["PartNumber"])
        response = self.s3_client.complete_multipart_upload(
            Bucket=self._get_bucket(),
            Key=self.get_object_key(team_id, submission_id),
            UploadId=upload_id,
            MultipartUpload={"Parts": sorted_parts},
        )
        return response.get("Location") or self.get_s3_file_path(team_id, submission_id)

    def abort_multipart_upload(
        self,
        team_id: str,
        submission_id: str,
        upload_id: str,
    ) -> None:
        """Abort a multipart upload."""
        self.s3_client.abort_multipart_upload(
            Bucket=self._get_bucket(),
            Key=self.get_object_key(team_id, submission_id),
            UploadId=upload_id,
        )

    def is_presigned_s3_upload_completed(self, team_id: str, submission_id: str) -> bool:
        """Return True if the submission's object exists in the bucket.

        A 404/NoSuchKey answer means "not uploaded yet". Any other client error
        (e.g. access denied) is re-raised.
        """
        try:
            self.s3_client.head_object(Bucket=self._get_bucket(), Key=self.get_object_key(team_id, submission_id))
        except ClientError as exc:
            error_code = exc.response.get("Error", {}).get("Code")
            if error_code in {"404", "NoSuchKey"}:
                return False
            raise
        return True
|
|
|
|
# Process-wide S3 helper. With no explicit bucket, S3_IMAGE_BUCKET is read from
# the environment on each call.
image_upload_api = ImageUploadApi(bucket=None)
|
|
|
|
class BatchJobApi:
    """Submits submission-evaluation jobs to AWS Batch and reports their state.

    Failed jobs get an ``error_log`` built from the tail of the job's CloudWatch
    log stream plus the failure reasons that AWS Batch reports.
    """

    def __init__(self, job_definition: str, job_queue: str):
        region = os.environ.get("AWS_REGION", "us-east-1")
        self.batch_client = boto3.client("batch", region_name=region)
        self.logs_client = boto3.client("logs", region_name=region)
        self.job_definition = job_definition
        self.job_queue = job_queue

    def _get_image_url(self, team_id: str, submission_id: str) -> str:
        """Return the S3 URI of the submission's uploaded image."""
        return image_upload_api.get_s3_file_path(team_id, submission_id)

    def _get_job_name(self, team_id: str, submission_id: str) -> str:
        """Return the Batch job name; it is also passed to the container as TASK_ID."""
        return f"comp-{team_id}__{submission_id}"

    def _fetch_cloudwatch_error_log(self, job: Dict[str, Any]) -> str:
        """Return the most recent lines of the job's CloudWatch log stream as one string.

        Up to ``AWS_BATCH_ERROR_LOG_LINES`` non-blank events are joined with " | ".
        If the result is longer than ``AWS_BATCH_ERROR_LOG_CHARS``, only its final
        characters are kept.

        Returns "" when the job has no log stream, the stream has no non-blank
        events, or the logs cannot be read. Read failures are logged, never raised,
        so job-state reporting is not blocked by log retrieval.
        """
        from botocore.exceptions import BotoCoreError

        log_stream_name = job.get("container", {}).get("logStreamName")
        if not log_stream_name:
            return ""

        try:
            response = self.logs_client.get_log_events(
                logGroupName=AWS_BATCH_LOG_GROUP,
                logStreamName=log_stream_name,
                # startFromHead=False asks for the newest events of the stream.
                startFromHead=False,
                limit=AWS_BATCH_ERROR_LOG_LINES,
            )
        except (ClientError, BotoCoreError) as exc:
            # BotoCoreError covers transport failures (e.g. endpoint unreachable)
            # that are not ClientError subclasses.
            logger.warning(f"Failed to fetch CloudWatch logs for batch job {job.get('jobId')}: {exc}")
            return ""

        stripped = (event.get("message", "").strip() for event in response.get("events", []))
        log_messages = [message for message in stripped if message]
        if not log_messages:
            return ""

        error_log = " | ".join(log_messages[-AWS_BATCH_ERROR_LOG_LINES:])
        # The failure details (traceback, final error line) sit at the end of the
        # stream, so when trimming keep the tail rather than the head.
        overflow = len(error_log) - AWS_BATCH_ERROR_LOG_CHARS
        return error_log[overflow:] if overflow > 0 else error_log

    def _extract_error_log(self, job: Dict[str, Any]) -> str:
        """Combine every available failure explanation into one " | "-joined string.

        Sources, in order: the CloudWatch log tail, the job's ``statusReason``,
        the container ``reason``, and the latest attempt's container ``reason``.
        Duplicates are dropped (first occurrence wins). Returns "" when no source
        has anything.
        """
        attempts = job.get("attempts") or []
        candidates = [
            self._fetch_cloudwatch_error_log(job),
            job.get("statusReason"),
            job.get("container", {}).get("reason"),
            attempts[-1].get("container", {}).get("reason") if attempts else None,
        ]
        # dict.fromkeys de-duplicates while preserving first-seen order.
        return " | ".join(dict.fromkeys(message for message in candidates if message))

    def get_job_state(self, job_id: str) -> Optional[Dict[str, Any]]:
        """Return a summary of a Batch job, or None if AWS does not know the id.

        ``error_log`` is populated only for jobs in the FAILED state.
        """
        response = self.batch_client.describe_jobs(jobs=[job_id])
        if not response["jobs"]:
            return None
        job = response["jobs"][0]
        return {
            "jobId": job["jobId"],
            "status": job["status"],
            "createdAt": job["createdAt"],
            "startedAt": job.get("startedAt"),
            "stoppedAt": job.get("stoppedAt"),
            "error_log": self._extract_error_log(job) if job["status"] == "FAILED" else "",
        }

    def create_job(self, team_id: str, submission_id: str) -> str:
        """Submit an evaluation job for a submission and return its Batch job id.

        The container receives the submission's S3 image URI as DOCKER_IMAGE_URL
        and the job name as TASK_ID.
        """
        response = self.batch_client.submit_job(**{
            "jobName": self._get_job_name(team_id, submission_id),
            "jobDefinition": self.job_definition,
            "jobQueue": self.job_queue,
            "dependsOn": [],
            "arrayProperties": {},
            "parameters": {},
            "containerOverrides": {
                "resourceRequirements": [],
                "environment": [
                    {
                        "name": "DOCKER_IMAGE_URL",
                        "value": self._get_image_url(team_id, submission_id),
                    },
                    {
                        "name": "TASK_ID",
                        "value": self._get_job_name(team_id, submission_id),
                    },
                ],
            },
        })
        return response["jobId"]
|
|
|
|
# Process-wide AWS Batch helper. Either setting is None when its environment
# variable is unset; nothing here validates them.
batch_job_api = BatchJobApi(
    job_definition=os.environ.get("AWS_BATCH_JOB_DEFINITION"),
    job_queue=os.environ.get("AWS_BATCH_JOB_QUEUE"),
)
|
|
|
|
def _fetch_token_information(token):
    """Resolve a token to ``{"id", "name", "orgs"}`` by asking the Hugging Face Hub.

    OAuth tokens (``hf_oauth...``) are checked against ``/oauth/userinfo`` and
    every other token against ``/api/whoami-v2``. Tokens starting with ``hf_``
    are sent as a bearer header; anything else is sent as the ``token`` cookie.

    Raises:
        TemporaryTokenVerificationError: the Hub could not be reached, or it
            answered 429/5xx. The token may still be valid.
        InvalidTokenError: the Hub answered with any other non-200 status.
    """
    is_oauth = token.startswith("hf_oauth")
    api_url = HF_URL + ("/oauth/userinfo" if is_oauth else "/api/whoami-v2")
    headers = {}
    cookies = {}
    if token.startswith("hf_"):
        headers["Authorization"] = f"Bearer {token}"
    else:
        cookies = {"token": token}

    try:
        response = requests.get(
            api_url,
            headers=headers,
            cookies=cookies,
            timeout=3,
        )
    except (requests.Timeout, requests.ConnectionError, ConnectionError) as err:
        # requests.ConnectionError (DNS failure, refused connection, ...) does
        # not subclass the builtin ConnectionError, so both are listed. Without
        # it, a Hub outage would surface as an unexpected exception instead of a
        # retryable verification failure.
        logger.error(f"Failed to request whoami-v2 - {repr(err)}")
        raise TemporaryTokenVerificationError("Hugging Face Hub is unreachable, please try again later.") from err

    if response.status_code != 200:
        logger.error(f"Failed to request whoami-v2 - {response.status_code}")
        if response.status_code in {429, 500, 502, 503, 504}:
            raise TemporaryTokenVerificationError(
                f"Hugging Face token verification is temporarily unavailable ({response.status_code})."
            )
        raise InvalidTokenError("Invalid token.")

    resp = response.json()
    # A payload without an "orgs" entry is treated as "member of no organization".
    orgs = resp.get("orgs") or []

    if is_oauth:
        return {
            "id": resp["sub"],
            "name": resp["preferred_username"],
            "orgs": [org["preferred_username"] for org in orgs],
        }
    return {
        "id": resp["id"],
        "name": resp["name"],
        "orgs": [org["name"] for org in orgs],
    }
|
|
|
|
def token_information(token):
    """Return user info for ``token``, with a short-lived cache and a stale fallback.

    Lookup order:
      1. A fresh cache hit is returned immediately.
      2. Otherwise the Hub is queried. On success, both caches are refreshed.
      3. On a temporary verification failure, the last known-good info from the
         stale cache is copied back into the fresh cache and returned. With no
         stale entry, the error propagates.
      4. On an invalid token, the token is evicted from both caches and the
         error propagates.

    Raises:
        TemporaryTokenVerificationError: the Hub is unavailable and nothing is cached.
        InvalidTokenError: the Hub rejected the token.
    """
    with TOKEN_INFO_CACHE_LOCK:
        hit = TOKEN_INFO_CACHE.get(token)
    if hit is not None:
        return hit

    try:
        fetched = _fetch_token_information(token)
    except TemporaryTokenVerificationError:
        with TOKEN_INFO_CACHE_LOCK:
            fallback = TOKEN_INFO_STALE_CACHE.get(token)
        if fallback is None:
            raise
        logger.warning("Using cached token information after temporary token verification failure.")
        with TOKEN_INFO_CACHE_LOCK:
            TOKEN_INFO_CACHE[token] = fallback
        return fallback
    except InvalidTokenError:
        with TOKEN_INFO_CACHE_LOCK:
            for cache in (TOKEN_INFO_CACHE, TOKEN_INFO_STALE_CACHE):
                cache.pop(token, None)
        raise

    with TOKEN_INFO_CACHE_LOCK:
        TOKEN_INFO_CACHE[token] = fetched
        TOKEN_INFO_STALE_CACHE[token] = fetched
    return fetched
|
|
|
|
def user_authentication(request: Request):
    """Return a verified Hugging Face token for ``request``, or None.

    Sources are tried in priority order. The first one present decides the
    outcome; later sources are not consulted if it fails verification.
      1. ``Authorization: Bearer <token>`` header.
      2. The ``USER_TOKEN`` environment fallback.
      3. The OAuth access token stored in the session under ``oauth_info``.

    For the session token, a definitively invalid token (or a missing one) is
    removed from the session. A temporary Hub outage still lets the stored
    token through.
    """

    def _verified_or_none(candidate):
        # Return ``candidate`` if it verifies; otherwise log the reason and return None.
        try:
            token_information(token=candidate)
        except Exception as exc:
            logger.error(f"Failed to verify token: {exc}")
            return None
        return candidate

    header = request.headers.get("Authorization")
    bearer_token = header.split(" ")[1] if header and header.startswith("Bearer ") else None

    if bearer_token:
        return _verified_or_none(bearer_token)

    if USER_TOKEN is not None:
        return _verified_or_none(USER_TOKEN)

    if "oauth_info" not in request.session:
        return None

    access_token = request.session["oauth_info"].get("access_token")
    if not access_token:
        request.session.pop("oauth_info", None)
        return None

    try:
        token_information(token=access_token)
    except InvalidTokenError as exc:
        request.session.pop("oauth_info", None)
        logger.error(f"Failed to verify token: {exc}")
        return None
    except TemporaryTokenVerificationError as exc:
        logger.warning(f"Token verification temporarily unavailable, reusing session token: {exc}")
        return access_token
    except Exception as exc:
        logger.error(f"Failed to verify token: {exc}")
        return None
    return access_token
|
|
|
|
def user_authentication_dep(token, return_raw=False):
    """Look up a token on the Hub directly, bypassing the token caches.

    Unlike ``token_information``, a rejection is not raised: a JSON response
    containing an ``"error"`` key is returned as-is.

    Args:
        token: Hugging Face token (``hf_...`` / ``hf_oauth...``) or a session
            cookie value.
        return_raw: when True, return the decoded JSON response unmodified.

    Returns:
        ``{"id", "name", "orgs"}`` on success, the raw JSON when ``return_raw``
        is set, or the error payload when it contains ``"error"``.
        NOTE(review): a non-200 response without an "error" key will raise
        KeyError below -- confirm the Hub always includes it.

    Raises:
        Exception: the Hub could not be reached.
    """
    is_oauth = token.startswith("hf_oauth")
    api_url = HF_URL + ("/oauth/userinfo" if is_oauth else "/api/whoami-v2")
    headers = {}
    cookies = {}
    if token.startswith("hf_"):
        headers["Authorization"] = f"Bearer {token}"
    else:
        cookies = {"token": token}

    try:
        response = requests.get(
            api_url,
            headers=headers,
            cookies=cookies,
            timeout=3,
        )
    except (requests.Timeout, requests.ConnectionError, ConnectionError) as err:
        # requests.ConnectionError does not subclass the builtin ConnectionError,
        # so both must be listed for connection failures to be reported here.
        logger.error(f"Failed to request whoami-v2 - {repr(err)}")
        raise Exception("Hugging Face Hub is unreachable, please try again later.") from err

    resp = response.json()
    if return_raw or "error" in resp:
        return resp

    # A payload without an "orgs" entry is treated as "member of no organization".
    orgs = resp.get("orgs") or []
    if is_oauth:
        return {
            "id": resp["sub"],
            "name": resp["preferred_username"],
            "orgs": [org["preferred_username"] for org in orgs],
        }
    return {
        "id": resp["id"],
        "name": resp["name"],
        "orgs": [org["name"] for org in orgs],
    }
|
|
|
|
def make_clickable_user(user_id):
    """Return an HTML link to ``user_id``'s Hugging Face profile, opening in a new tab.

    The user id is HTML-escaped before it is placed in the markup, so a value
    containing quotes or angle brackets cannot break out of the ``href``
    attribute or inject tags. Ids without &, <, >, " or ' render unchanged.
    """
    import html

    link = html.escape("https://huggingface.co/" + user_id, quote=True)
    label = html.escape(user_id, quote=True)
    return f'<a target="_blank" href="{link}">{label}</a>'
|
|
|
|
def run_evaluation(params, local=False, wait=False):
    """Launch ``competitions.evaluate`` in a child process.

    Args:
        params: JSON text describing ``EvalParams``. A doubly-encoded JSON string
            (JSON whose value is itself a JSON string) is also accepted.
        local: when False, the output directory is forced to ``/tmp/model``.
        wait: block until the child process exits.

    Returns:
        PID of the spawned process.
    """
    params = json.loads(params)
    if isinstance(params, str):
        params = json.loads(params)
    params = EvalParams(**params)
    if not local:
        params.output_path = "/tmp/model"
    params.save(output_dir=params.output_path)

    # argv is passed to Popen as a list, with no shell and no re-splitting, so a
    # config path containing spaces or quote characters survives intact.
    cmd = [
        "python",
        "-m",
        "competitions.evaluate",
        "--config",
        str(os.path.join(params.output_path, "params.json")),
    ]
    logger.info(cmd)
    process = subprocess.Popen(cmd, env=os.environ.copy())
    if wait:
        process.wait()
    return process.pid
|
|
|
|
def pause_space(params):
    """Pause the Space this process runs in, if it is a competition Space.

    Acts only when SPACE_ID is set and its repo name (the part after the last
    "/") starts with "comp-"; otherwise does nothing. ``params.token`` is used
    to authenticate with the Hub.
    """
    space_id = os.environ.get("SPACE_ID")
    if space_id is None:
        return
    if not space_id.split("/")[-1].startswith("comp-"):
        return
    logger.info("Pausing space...")
    HfApi(token=params.token).pause_space(repo_id=space_id)
|
|
|
|
def delete_space(params):
    """Delete the Space this process runs in, if it is a competition Space.

    Acts only when SPACE_ID is set and its repo name (the part after the last
    "/") starts with "comp-"; otherwise does nothing. ``params.token`` is used
    to authenticate with the Hub.
    """
    space_id = os.environ.get("SPACE_ID")
    if space_id is None:
        return
    if not space_id.split("/")[-1].startswith("comp-"):
        return
    logger.info("Deleting space...")
    HfApi(token=params.token).delete_repo(repo_id=space_id, repo_type="space")
|
|
|
|
def uninstall_requirements(requirements_fname):
    """Uninstall the packages marked for removal in a requirements file.

    A line of the form ``-package`` marks ``package`` for removal. Lines that
    start with ``--`` are pip options (e.g. ``--extra-index-url``), not removal
    markers, and are ignored. This matches ``install_requirements``, which
    forwards them to ``pip install``.

    Marked entries are written to ``uninstall.txt`` in the current directory
    and removed with ``pip uninstall -y``. Does nothing if the file does not
    exist or marks no packages.
    """
    if not os.path.exists(requirements_fname):
        return

    with open(requirements_fname, "r", encoding="utf-8") as f:
        uninstall_list = [
            line[1:]
            for line in f
            if line.startswith("-") and not line.startswith("--")
        ]

    if not uninstall_list:
        # Nothing to remove; avoid spawning pip with an empty requirements file.
        return

    with open("uninstall.txt", "w", encoding="utf-8") as f:
        f.writelines(uninstall_list)

    subprocess.run(["pip", "uninstall", "-r", "uninstall.txt", "-y"], check=False)
    logger.info("Requirements uninstalled.")
|
|
|
|
def install_requirements(requirements_fname):
    """Install the packages listed in a requirements file with pip.

    Lines starting with a single ``-`` are uninstall markers and are dropped.
    ``--`` option lines and ordinary requirement lines are copied to
    ``install.txt`` in the current directory, which is then passed to
    ``pip install -r``. Logs and returns if the file does not exist.
    """
    if not os.path.exists(requirements_fname):
        logger.info("No requirements.txt found. Skipping requirements installation.")
        return

    with open(requirements_fname, "r", encoding="utf-8") as src:
        kept = [
            entry
            for entry in src
            if entry.startswith("--") or not entry.startswith("-")
        ]

    with open("install.txt", "w", encoding="utf-8") as dst:
        dst.writelines(kept)

    proc = subprocess.Popen(["pip", "install", "-r", "install.txt"])
    proc.wait()
    logger.info("Requirements installed.")
|
|
|
|
def is_user_admin(user_token, competition_organization):
    """Return True when the token's user is a member of ``competition_organization``."""
    orgs = token_information(token=user_token).get("orgs", [])
    return any(org == competition_organization for org in orgs)
|
|
|
|
class TeamAlreadyExistsError(Exception):
    """Signals an attempt to create a team that already exists.

    NOTE(review): not raised anywhere in this file section -- confirm whether
    "exists" refers to the user's existing team or to a duplicate team name.
    """
|
|
class TeamFileApi:
    """Reads and writes team-membership JSON files in the competition dataset repo.

    Files used (all in the ``competition_id`` dataset):
      * ``user_team.json``: user id -> team id.
      * ``teams.json``: team id -> team record (``id``, ``name``, ``members``,
        ``leader``, ``other_data``).
      * ``team_id_whitelist.json`` and ``team_submission_limit.json``: read-only,
        cached for 10 minutes.

    ``self._lock`` serializes read-modify-write cycles within this process only.
    Writers in other processes are not coordinated.
    """

    def __init__(self, hf_token: str, competition_id: str):
        self.hf_token = hf_token
        self.competition_id = competition_id
        self._lock = threading.Lock()

    def _download_json(self, filename: str) -> Any:
        """Download ``filename`` from the competition dataset and parse it as JSON."""
        path = hf_hub_download(
            repo_id=self.competition_id,
            filename=filename,
            token=self.hf_token,
            repo_type="dataset",
        )
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    def _upload_json(self, filename: str, data: Any) -> None:
        """Serialize ``data`` as indented UTF-8 JSON and upload it as ``filename``."""
        buffer = io.BytesIO(json.dumps(data, indent=4).encode("utf-8"))
        HfApi(token=self.hf_token).upload_file(
            path_or_fileobj=buffer,
            path_in_repo=filename,
            repo_id=self.competition_id,
            repo_type="dataset",
        )

    def _get_all_team_metadata(self) -> Dict[str, Dict[str, Any]]:
        """Return the full ``teams.json`` mapping."""
        return self._download_json("teams.json")

    def _get_team_info(self, user_id: str) -> Optional[Dict[str, Any]]:
        """Return the team record of ``user_id``, or None if the user has no team."""
        user_team = self._download_json("user_team.json")
        if user_id not in user_team:
            return None
        return self._get_all_team_metadata()[user_team[user_id]]

    def _require_team_id(self, user_id: str) -> str:
        """Return the id of ``user_id``'s team; raise ValueError if the user has none."""
        team_info = self._get_team_info(user_id)
        if team_info is None:
            raise ValueError(f"User {user_id} is not a member of any team.")
        return team_info["id"]

    def _create_team(self, user_id: str, team_name: str, other_data: Dict[str, Any]) -> str:
        """Create a new team led by ``user_id`` and return its generated id.

        NOTE(review): an existing user -> team mapping is overwritten without a
        check, so the user's previous team record still lists them as a member.
        """
        with self._lock:
            user_team = self._download_json("user_team.json")
            team_metadata = self._get_all_team_metadata()

            team_id = str(uuid.uuid4())
            user_team[user_id] = team_id
            team_metadata[team_id] = {
                "id": team_id,
                "name": team_name,
                "members": [user_id],
                "leader": user_id,
                "other_data": other_data,
            }

            self._upload_json("user_team.json", user_team)
            self._upload_json("teams.json", team_metadata)
        return team_id

    def create_team(self, user_token: str, team_name: str, other_data: Dict[str, Any]) -> str:
        """Create a team led by the token's user and return the new team id."""
        user_info = token_information(token=user_token)
        return self._create_team(user_info["id"], team_name, other_data)

    def update_team(self, user_token: str, team_name: str, other_data: Dict[str, Any]) -> None:
        """Replace the name and ``other_data`` of the token user's team.

        Raises:
            ValueError: the user does not belong to any team.
        """
        user_id = token_information(token=user_token)["id"]
        team_id = self._require_team_id(user_id)

        with self._lock:
            team_metadata = self._get_all_team_metadata()
            team_metadata[team_id] = {
                **team_metadata[team_id],
                "name": team_name,
                "other_data": other_data,
            }
            self._upload_json("teams.json", team_metadata)

    def get_team_info(self, user_token: str) -> Optional[Dict[str, Any]]:
        """Return the token user's team record, or None if they have no team."""
        user_info = token_information(token=user_token)
        return self._get_team_info(user_info["id"])

    def get_all_team_info(self) -> Dict[str, Dict[str, Any]]:
        """Return every team record, keyed by team id."""
        return self._get_all_team_metadata()

    def update_team_name(self, user_token, new_team_name):
        """Rename the token user's team and return the new name.

        Raises:
            ValueError: the user does not belong to any team.
        """
        user_id = token_information(token=user_token)["id"]
        team_id = self._require_team_id(user_id)

        with self._lock:
            team_metadata = self._get_all_team_metadata()
            team_metadata[team_id]["name"] = new_team_name
            self._upload_json("teams.json", team_metadata)
        return new_team_name

    @cached(cache=TTLCache(maxsize=1, ttl=600))
    def get_team_white_list(self) -> List[str]:
        """Return the contents of ``team_id_whitelist.json`` (cached for 10 minutes)."""
        return self._download_json("team_id_whitelist.json")

    @cached(cache=TTLCache(maxsize=1, ttl=600))
    def get_team_submission_limit(self):
        """Return the contents of ``team_submission_limit.json`` (cached for 10 minutes)."""
        return self._download_json("team_submission_limit.json")
|
|
|
|
# Process-wide team-file helper. Both settings are None if their environment
# variables are unset.
team_file_api = TeamFileApi(
    hf_token=os.getenv("HF_TOKEN"),
    competition_id=os.getenv("COMPETITION_ID"),
)
|
|
|
|
class UserTokenApi:
    """Stores each team's user token, AES-GCM encrypted, in the competition dataset.

    Tokens are kept at ``team_user_tokens/{team_id}`` as
    base64(nonce || ciphertext), with a fresh random 12-byte nonce per write
    and no associated data.
    """

    def __init__(self, hf_token: str, key_base64: str, competition_id: str):
        self.hf_token = hf_token
        # Raw AES key. Decoded eagerly, so a missing or malformed value fails here.
        self.key = base64.b64decode(key_base64)
        self.competition_id = competition_id

    def _encrypt(self, text: str) -> str:
        """Encrypt ``text`` and return base64(nonce || ciphertext)."""
        cipher = AESGCM(self.key)
        nonce = os.urandom(12)
        sealed = cipher.encrypt(nonce, text.encode(), None)
        return base64.b64encode(nonce + sealed).decode()

    def _decrypt(self, encrypted_text: str) -> str:
        """Invert ``_encrypt``; raises if the data was tampered with or the key differs."""
        cipher = AESGCM(self.key)
        blob = base64.b64decode(encrypted_text)
        nonce, ciphertext = blob[:12], blob[12:]
        return cipher.decrypt(nonce, ciphertext, None).decode()

    def put(self, team_id: str, user_token: str):
        """Encrypt ``user_token`` and upload it as ``team_user_tokens/{team_id}``."""
        payload = io.BytesIO(self._encrypt(user_token).encode())
        HfApi(token=self.hf_token).upload_file(
            path_or_fileobj=payload,
            path_in_repo=f"team_user_tokens/{team_id}",
            repo_id=self.competition_id,
            repo_type="dataset",
        )

    def get(self, team_id: str) -> Optional[str]:
        """Return the team's decrypted user token, or None if it cannot be downloaded."""
        try:
            stored_path = hf_hub_download(
                repo_id=self.competition_id,
                filename=f"team_user_tokens/{team_id}",
                token=self.hf_token,
                repo_type="dataset",
            )
        except Exception as e:
            logger.error(f"Failed to download user token - {e}")
            return None

        with open(stored_path, "r", encoding="utf-8") as f:
            ciphertext_b64 = f.read()
        return self._decrypt(ciphertext_b64)
|
|
|
|
# Process-wide encrypted-token store. UserTokenApi decodes the key eagerly, so
# this line fails at import time when USER_TOKEN_KEY_BASE64 is unset.
user_token_api = UserTokenApi(
    hf_token=os.getenv("HF_TOKEN"),
    key_base64=os.getenv("USER_TOKEN_KEY_BASE64"),
    competition_id=os.getenv("COMPETITION_ID"),
)
|
|
|
|
class SubmissionApi:
    """Reads and updates per-team submission records in the competition dataset.

    Each team's record lives at ``submission_info/{team_id}.json`` and holds a
    ``"submissions"`` list of dicts keyed by ``"submission_id"``.
    """

    def __init__(self, hf_token: str, competition_id: str):
        self.hf_token = hf_token
        self.competition_id = competition_id
        self.api = HfApi(token=hf_token)

    def exists_submission_info(self, team_id: str) -> bool:
        """
        Check if submission info exists for a given team ID.
        Args:
            team_id (str): The team ID.
        Returns:
            bool: True if submission info exists, False otherwise.
        """
        return self.api.file_exists(
            repo_id=self.competition_id,
            filename=f"submission_info/{team_id}.json",
            repo_type="dataset",
        )

    def download_submission_info(self, team_id: str) -> Dict[str, Any]:
        """
        Download the submission info from Hugging Face Hub.
        Args:
            team_id (str): The team ID.
        Returns:
            Dict[str, Any]: The submission info.
        """
        submission_info_path = self.api.hf_hub_download(
            repo_id=self.competition_id,
            filename=f"submission_info/{team_id}.json",
            repo_type="dataset",
        )
        # Decode as UTF-8, the encoding upload_submission_info writes, instead of
        # relying on the platform's locale default.
        with open(submission_info_path, "r", encoding="utf-8") as f:
            return json.load(f)

    def upload_submission_info(self, team_id: str, user_submission_info: Dict[str, Any]):
        """Upload a team's full submission record as indented UTF-8 JSON."""
        payload = json.dumps(user_submission_info, indent=4).encode("utf-8")
        self.api.upload_file(
            path_or_fileobj=io.BytesIO(payload),
            path_in_repo=f"submission_info/{team_id}.json",
            repo_id=self.competition_id,
            repo_type="dataset",
        )

    def update_submission_data(self, team_id: str, submission_id: str, data: Dict[str, Any]):
        """Merge ``data`` into one submission and upload the team's record.

        When no submission matches ``submission_id``, a warning is logged and
        nothing is uploaded, since the record would be unchanged.
        """
        user_submission_info = self.download_submission_info(team_id)
        submission = next(
            (s for s in user_submission_info["submissions"] if s["submission_id"] == submission_id),
            None,
        )
        if submission is None:
            logger.warning(f"Submission {submission_id} not found for team {team_id}; nothing updated.")
            return
        submission.update(data)
        self.upload_submission_info(team_id, user_submission_info)

    def update_submission_status(self, team_id: str, submission_id: str, status: int):
        """Set the ``status`` field of one submission."""
        self.update_submission_data(team_id, submission_id, {"status": status})

    def count_by_status(self, team_id: str, status_list: List[SubmissionStatus]) -> int:
        """Count the team's submissions whose status is in ``status_list``."""
        user_submission_info = self.download_submission_info(team_id)
        return sum(
            1
            for submission in user_submission_info["submissions"]
            if SubmissionStatus(submission["status"]) in status_list
        )
|
|
|
|
# Process-wide submission-record helper.
submission_api = SubmissionApi(
    hf_token=os.getenv("HF_TOKEN"),
    competition_id=os.getenv("COMPETITION_ID"),
)
|
|
|
|
class LeaderboardApi:
    """Builds leaderboard data for the competition from the dataset repo."""

    def __init__(self, hf_token: str, competition_id: str):
        self.hf_token = hf_token
        self.competition_id = competition_id
        self.api = HfApi(token=hf_token)

    @cached(cache=TTLCache(maxsize=1, ttl=300))
    def get_leaderboard(self) -> pd.DataFrame:
        """
        Get the leaderboard for the competition.
        Returns:
            pd.DataFrame: The leaderboard as a DataFrame.

        NOTE(review): currently always returns an empty frame with the
        leaderboard columns; ``_get_all_scores`` is not consulted. Confirm this
        placeholder is intended.
        """
        leaderboard_columns = ["team_id", "team_name", "psnr", "ssim", "lpips", "score"]
        return pd.DataFrame(columns=leaderboard_columns)

    def _get_all_scores(self) -> List[Dict[str, Any]]:
        """Return one row per successful submission across all teams.

        Each row carries ``team_id``, ``team_name`` (looked up in ``teams.json``)
        and the ``psnr``/``ssim``/``lpips``/``score`` values from the submission's
        ``score`` dict.
        """
        teams_path = self.api.hf_hub_download(
            repo_id=self.competition_id,
            filename="teams.json",
            repo_type="dataset",
        )
        with open(teams_path, "r", encoding="utf-8") as f:
            teams = json.load(f)

        snapshot_dir = self.api.snapshot_download(
            repo_id=self.competition_id,
            allow_patterns="submission_info/*.json",
            repo_type="dataset",
        )
        record_paths = glob.glob(os.path.join(snapshot_dir, "submission_info/*.json"))

        rows: List[Dict[str, Any]] = []
        for record_path in record_paths:
            with open(record_path, "r", encoding="utf-8") as f:
                record = json.load(f)
            team_id = record["id"]
            successful = (
                sub for sub in record["submissions"]
                if sub["status"] == SubmissionStatus.SUCCESS.value
            )
            for sub in successful:
                rows.append({
                    "team_id": team_id,
                    "team_name": teams[team_id]["name"],
                    "psnr": sub["score"]["psnr"],
                    "ssim": sub["score"]["ssim"],
                    "lpips": sub["score"]["lpips"],
                    "score": sub["score"]["score"],
                })
        return rows
|
|
|
|
# Process-wide leaderboard helper.
leaderboard_api = LeaderboardApi(
    hf_token=os.getenv("HF_TOKEN"),
    competition_id=os.getenv("COMPETITION_ID"),
)
|
|
|
|
class ErrorLogApi:
    """Stores submission error logs and grants time-limited read access via JWTs."""

    def __init__(self, hf_token: str, competition_id: str, encode_key: str):
        self.hf_token = hf_token
        self.competition_id = competition_id
        self.api = HfApi(token=hf_token)
        # HS256 secret used to sign and verify log-access tokens.
        self.encode_key = encode_key

    def save_error_log(self, submission_id: str, content: str):
        """Save the error log of a space to the submission (UTF-8, ``error_logs/{id}.txt``)."""
        content_buffer = io.BytesIO(content.encode("utf-8"))
        self.api.upload_file(
            path_or_fileobj=content_buffer,
            path_in_repo=f"error_logs/{submission_id}.txt",
            repo_id=self.competition_id,
            repo_type="dataset",
        )

    def generate_log_token(self, submission_id: str) -> str:
        """Return an HS256 JWT that grants read access to one submission's log for 1 hour."""
        payload = {
            "submission_id": submission_id,
            "exp": datetime.now(timezone.utc) + timedelta(hours=1),
        }
        return jwt.encode(payload, self.encode_key, algorithm="HS256")

    def get_log_by_token(self, token: str) -> str:
        """Return the error log referenced by a token from ``generate_log_token``.

        Raises:
            RuntimeError: the token is expired, malformed, wrongly signed, or
                lacks the ``exp`` / ``submission_id`` claims.
        """
        try:
            payload = jwt.decode(
                token,
                self.encode_key,
                algorithms=["HS256"],
                # Reject tokens with no expiry (which would never expire) and
                # tokens without a submission id (which would fail later with KeyError).
                options={"require": ["exp", "submission_id"]},
            )
        except jwt.ExpiredSignatureError as e:
            raise RuntimeError("Token has expired.") from e
        except jwt.InvalidTokenError as e:
            raise RuntimeError(f"Invalid token: {e}") from e
        submission_id = payload["submission_id"]

        log_file_path = self.api.hf_hub_download(
            repo_id=self.competition_id,
            filename=f"error_logs/{submission_id}.txt",
            repo_type="dataset",
        )
        # save_error_log writes UTF-8; decode the same way regardless of locale.
        with open(log_file_path, "r", encoding="utf-8") as f:
            return f.read()
|
|
|
|
# Process-wide error-log helper.
# NOTE(review): when ERROR_LOG_ENCODE_KEY is unset, log-access JWTs are signed
# with the hard-coded secret "key", so anyone can mint valid tokens. Set the
# variable in every deployment.
error_log_api = ErrorLogApi(
    hf_token=os.getenv("HF_TOKEN"),
    competition_id=os.getenv("COMPETITION_ID"),
    encode_key=os.getenv("ERROR_LOG_ENCODE_KEY", "key"),
)
|
|