| import json |
| import os |
| import time |
| import glob |
| from dataclasses import dataclass |
| from typing import List, Dict, Any, Optional |
|
|
| import pandas as pd |
| from huggingface_hub import snapshot_download |
| from loguru import logger |
|
|
| from competitions.enums import ErrorMessage |
| from competitions.enums import SubmissionStatus |
| from competitions.info import CompetitionInfo |
| from competitions.utils import ( |
| submission_api, |
| image_upload_api, |
| batch_job_api, |
| error_log_api, |
| DIRECT_UPLOAD_URL_REFRESH_THRESHOLD_SECONDS, |
| ) |
|
|
|
|
@dataclass
class JobRunner:
    """Poll a competition's submissions and advance them through their lifecycle.

    Every cycle of ``run`` downloads all ``submission_info/*.json`` files from
    the competition dataset repo and then handles submissions by status:

    * ``WAIT_IMAGE_UPLOAD``: promote to ``PENDING`` once the presigned S3 upload
      is reported complete; otherwise re-issue the presigned upload URL when it
      is missing or close to expiring.
    * ``PENDING``: create a batch job for the oldest pending submission and mark
      it ``PROCESSING`` (``FAILED`` if job creation raises).
    * ``PROCESSING``: poll the batch job state and mark the submission
      ``SUCCESS`` or ``FAILED`` once the job has finished.
    """

    competition_id: str
    token: str
    output_path: str  # not read by any method of this class

    def __post_init__(self):
        # Resolve the competition settings once and cache them as attributes.
        self.competition_info = CompetitionInfo(competition_id=self.competition_id, autotrain_token=self.token)
        self.competition_id = self.competition_info.competition_id
        self.competition_type = self.competition_info.competition_type
        self.metric = self.competition_info.metric
        self.submission_id_col = self.competition_info.submission_id_col
        self.submission_cols = self.competition_info.submission_cols
        self.submission_rows = self.competition_info.submission_rows
        self.time_limit = self.competition_info.time_limit
        self.dataset = self.competition_info.dataset
        self.submission_filenames = self.competition_info.submission_filenames

    def _get_all_submissions(self) -> List[Dict[str, Any]]:
        """Download every team's submission info and flatten it to one dict per submission.

        Each ``submission_info/*.json`` file holds a team ``id`` and a list of
        ``submissions``. The returned dicts carry the team id together with the
        per-submission fields used by the processing steps. ``batch_job_id``,
        ``presigned_url`` and ``expired_at`` default to ``""`` when absent.
        """
        snapshot_dir = snapshot_download(
            repo_id=self.competition_id,
            allow_patterns="submission_info/*.json",
            token=self.token,
            repo_type="dataset",
        )
        json_paths = glob.glob(os.path.join(snapshot_dir, "submission_info/*.json"))
        all_submissions = []
        for json_path in json_paths:
            with open(json_path, "r", encoding="utf-8") as f:
                team_info = json.load(f)
            team_id = team_info["id"]
            for sub in team_info["submissions"]:
                all_submissions.append(
                    {
                        "team_id": team_id,
                        "submission_id": sub["submission_id"],
                        "datetime": sub["datetime"],
                        "status": sub["status"],
                        "submission_repo": sub["submission_repo"],
                        "hardware": sub["hardware"],
                        "batch_job_id": sub.get("batch_job_id", ""),
                        "presigned_url": sub.get("presigned_url", ""),
                        "expired_at": sub.get("expired_at", ""),
                    }
                )
        return all_submissions

    def _get_subs(self, submissions: List[Dict[str, Any]], status: SubmissionStatus) -> Optional[pd.DataFrame]:
        """Return the submissions whose status equals ``status``, oldest first.

        Returns None when nothing matches. Otherwise returns a DataFrame whose
        ``datetime`` column is parsed with ``pd.to_datetime``, sorted ascending
        by that column, with a fresh 0..n-1 index.
        """
        filtered = [sub for sub in submissions if sub["status"] == status.value]
        if not filtered:
            return None
        logger.info(f"Found {len(filtered)} {status.name.lower()} submissions.")
        df = pd.DataFrame(filtered)
        df["datetime"] = pd.to_datetime(df["datetime"])
        return df.sort_values("datetime").reset_index(drop=True)

    def _process_pending_submission(self, submission: Optional[pd.DataFrame]):
        """Start a batch job for the oldest pending submission (one per cycle).

        The submission is first marked PROCESSING. Then a batch job is created
        and its id is stored as ``batch_job_id``. If either of those two steps
        raises, the submission is marked FAILED with
        ``ErrorMessage.FAILED_CREATE_JOB``.
        """
        if submission is None:
            return
        first_pending_sub = submission.iloc[0]
        team_id = first_pending_sub["team_id"]
        submission_id = first_pending_sub["submission_id"]
        # Claim the submission before creating the job so that a later cycle
        # no longer sees it among the PENDING submissions.
        submission_api.update_submission_status(
            team_id=team_id,
            submission_id=submission_id,
            status=SubmissionStatus.PROCESSING.value,
        )
        try:
            job_id = batch_job_api.create_job(team_id=team_id, submission_id=submission_id)
            submission_api.update_submission_data(
                team_id=team_id,
                submission_id=submission_id,
                data={
                    "status": SubmissionStatus.PROCESSING.value,
                    "batch_job_id": job_id,
                },
            )
        except Exception as e:
            # logger.exception keeps the traceback, which a bare error message loses.
            logger.exception(f"Failed to process {submission_id}: {e}")
            submission_api.update_submission_data(
                team_id=team_id,
                submission_id=submission_id,
                data={
                    "status": SubmissionStatus.FAILED.value,
                    "error_message": ErrorMessage.FAILED_CREATE_JOB.value,
                },
            )

    @staticmethod
    def _upload_url_needs_refresh(expired_at: Any) -> bool:
        """Return True when a submission's presigned upload URL should be re-issued.

        A missing expiry (None/NaN/blank or unparseable-to-NaT value) always
        needs a refresh. Otherwise the value is parsed with ``pd.Timestamp``;
        naive timestamps are treated as UTC. A refresh is needed once "now" is
        within ``DIRECT_UPLOAD_URL_REFRESH_THRESHOLD_SECONDS`` of the expiry,
        or past it.
        """
        if pd.isna(expired_at) or not str(expired_at).strip():
            return True
        expires = pd.Timestamp(expired_at)
        if pd.isna(expires):
            return True
        if expires.tz is None:
            expires = expires.tz_localize("UTC")
        seconds_past_expiry = (pd.Timestamp.now(tz=expires.tz) - expires).total_seconds()
        return seconds_past_expiry >= -DIRECT_UPLOAD_URL_REFRESH_THRESHOLD_SECONDS

    def _process_wait_image_upload_submission(self, submission: Optional[pd.DataFrame]):
        """Check every submission that is waiting for its image upload.

        Submissions whose presigned S3 upload is complete are moved to PENDING.
        For the others, the presigned upload URL is re-created and stored when
        ``_upload_url_needs_refresh`` says so.
        """
        if submission is None:
            return
        for _, sub in submission.iterrows():
            team_id = sub["team_id"]
            submission_id = sub["submission_id"]
            uploaded = image_upload_api.is_presigned_s3_upload_completed(team_id=team_id, submission_id=submission_id)
            if uploaded:
                submission_api.update_submission_status(
                    team_id=team_id,
                    submission_id=submission_id,
                    status=SubmissionStatus.PENDING.value,
                )
                # Move on to the next submission. Returning here would skip the
                # upload check and URL refresh for the remaining submissions.
                continue

            if self._upload_url_needs_refresh(sub.get("expired_at")):
                refreshed_upload = image_upload_api.create_presigned_s3_upload(
                    team_id=team_id,
                    submission_id=submission_id,
                )
                submission_api.update_submission_data(
                    team_id=team_id,
                    submission_id=submission_id,
                    data=refreshed_upload,
                )
                logger.info(f"Refreshed presigned upload URL for submission {submission_id} of team {team_id}.")
            logger.info(f"Submission {submission_id} is still waiting for image upload.")

    def _process_processing_submission(self, submission: Optional[pd.DataFrame]):
        """Poll the batch job of each PROCESSING submission and record finished jobs.

        In-progress job states are only logged. ``FAILED`` saves the job's
        error log and marks the submission FAILED with
        ``ErrorMessage.RUNTIME_ERROR``. ``SUCCEEDED`` marks it SUCCESS with
        empty ``public_score``/``score`` dicts. Any other state is logged as a
        warning and left untouched.
        """
        if submission is None:
            return
        # NOTE(review): these names look like AWS Batch job states; confirm
        # against batch_job_api.get_job_state.
        in_progress_states = ("SUBMITTED", "PENDING", "RUNNABLE", "STARTING", "RUNNING")
        for _, sub in submission.iterrows():
            team_id = sub["team_id"]
            submission_id = sub["submission_id"]
            job_state = batch_job_api.get_job_state(sub.get("batch_job_id", ""))
            if job_state is None:
                logger.error(f"Failed to get job state for submission {submission_id} of team {team_id}.")
                continue
            job_status = job_state["status"]
            if job_status in in_progress_states:
                logger.info(
                    f"Submission {submission_id} of team {team_id} is still being processed. "
                    f"Current job state: {job_status}."
                )
            elif job_status == "FAILED":
                error_log = job_state.get("error_log") or "job failed"
                logger.error(f"Batch job for submission {submission_id} of team {team_id} has failed. ")
                error_log_api.save_error_log(submission_id, error_log)
                submission_api.update_submission_data(
                    team_id=team_id,
                    submission_id=submission_id,
                    data={
                        "status": SubmissionStatus.FAILED.value,
                        "error_message": ErrorMessage.RUNTIME_ERROR.value,
                    },
                )
            elif job_status == "SUCCEEDED":
                logger.info(f"Batch job for submission {submission_id} of team {team_id} has succeeded.")
                submission_api.update_submission_data(
                    team_id=team_id,
                    submission_id=submission_id,
                    data={
                        "status": SubmissionStatus.SUCCESS.value,
                        "public_score": {},
                        "score": {},
                    },
                )
            else:
                # Previously fell through silently; surface it so a stuck
                # submission is visible in the logs.
                logger.warning(
                    f"Submission {submission_id} of team {team_id} has unexpected job state: {job_status}."
                )

    def _run_once(self):
        """Run one polling cycle over a single snapshot of all submissions."""
        all_submissions = self._get_all_submissions()

        wait_image_upload_submissions = self._get_subs(all_submissions, SubmissionStatus.WAIT_IMAGE_UPLOAD)
        self._process_wait_image_upload_submission(wait_image_upload_submissions)

        pending_submissions = self._get_subs(all_submissions, SubmissionStatus.PENDING)
        self._process_pending_submission(pending_submissions)

        processing_submissions = self._get_subs(all_submissions, SubmissionStatus.PROCESSING)
        self._process_processing_submission(processing_submissions)

    def run(self):
        """Poll forever, sleeping 5 seconds before each cycle.

        An exception raised during a cycle (for example a failed hub download
        or API call) is logged with its traceback, and the loop continues with
        the next cycle instead of terminating the runner.
        """
        while True:
            time.sleep(5)
            try:
                self._run_once()
            except Exception:
                logger.exception("Job runner cycle failed; retrying in the next cycle.")
|
|