Spaces:
Sleeping
Sleeping
| import re | |
| from datetime import datetime, timezone | |
| from pathlib import Path | |
| import gradio as gr | |
| import pandas as pd | |
| from huggingface_hub import CommitOperationAdd, HfApi | |
| from config import ( | |
| ACTIVATED_COL, | |
| CLINICAL_COLS, | |
| DATASET_REPO_ID, | |
| DEFAULT_LEADERBOARD_CONTEXT, | |
| HF_TOKEN, | |
| HORIZONS, | |
| METRIC_BASE_COLS, | |
| PRED_COLS, | |
| SUBMISSION_COOLDOWN_SECONDS, | |
| TECHNICAL_COLS, | |
| get_leaderboard_entry_dir, | |
| get_leaderboard_name, | |
| ) | |
| from data import ( | |
| build_submission_metrics_row, | |
| build_user_metadata_df, | |
| get_ground_truth, | |
| get_user_metadata_repo_path, | |
| get_user_metrics_repo_path, | |
| get_user_submission_history, | |
| write_user_metrics_history_file, | |
| ) | |
| from metrics import calculate_dts_error_grid, calculate_rmse, calculate_mae | |
# Upper bounds on the free-text metadata fields a submitter may provide.
_MAX_WEBSITE_LEN: int = 200
_MAX_NOTES_LEN: int = 500
# Columns used to join an uploaded submission onto the ground-truth rows.
_JOIN_KEYS: list[str] = ["id", "source_file", "date"]
# Placeholder stored when an optional text field is left blank.
_EMPTY_TEXT_VALUE: str = "N/A"
# strftime/strptime layout for UTC submission timestamps, e.g. "2024-01-31T12:00:00Z".
_TIMESTAMP_FORMAT: str = "%Y-%m-%dT%H:%M:%SZ"
| def validate_website(url: str) -> str | None: | |
| """Returns an error message, or None if the value is acceptable.""" | |
| url = (url or "").strip() | |
| if not url or url == "N/A": | |
| return None | |
| if len(url) > _MAX_WEBSITE_LEN: | |
| return f"⚠️ Website URL must be {_MAX_WEBSITE_LEN} characters or fewer." | |
| if re.search(r"""[<>"']""", url): | |
| return "⚠️ Website URL contains invalid characters (< > \" ')." | |
| if not re.match(r'^https?://', url, re.IGNORECASE): | |
| return "⚠️ Website URL must start with http:// or https://." | |
| return None | |
| def validate_notes(notes: str) -> str | None: | |
| """Returns an error message, or None if the value is acceptable.""" | |
| notes = (notes or "").strip() | |
| if not notes or notes == "N/A": | |
| return None | |
| if len(notes) > _MAX_NOTES_LEN: | |
| return f"⚠️ Notes must be {_MAX_NOTES_LEN} characters or fewer." | |
| if re.search(r'[<>]', notes): | |
| return "⚠️ Notes contain invalid characters (< or >)." | |
| return None | |
| def _normalize_optional_text(value: str | None) -> str: | |
| return (value or "").strip() or _EMPTY_TEXT_VALUE | |
def _load_submission(file_path):
    """Read the uploaded predictions file into a DataFrame.

    ``file_path`` is passed straight to ``pandas.read_parquet``, so the upload
    must be a Parquet file; read errors propagate to the caller.
    """
    submission_df = pd.read_parquet(file_path)
    return submission_df
def _validate_submission_metadata(website: str, notes: str) -> str | None:
    """Return the first validation error among the optional metadata fields, or None.

    The website is checked first; the notes are only checked when the website passes.
    """
    return validate_website(website) or validate_notes(notes)
def _validate_submission_columns(user_df: pd.DataFrame) -> str | None:
    """Check that the upload contains every join-key and prediction column.

    Returns:
        An error naming the absent columns (in required order), or ``None``.
    """
    missing_cols = []
    for col in _JOIN_KEYS + PRED_COLS:
        if col not in user_df.columns:
            missing_cols.append(col)
    if not missing_cols:
        return None
    return f"⚠️ Submission is missing columns: {', '.join(missing_cols)}."
def _merge_submission_with_ground_truth(user_df: pd.DataFrame, ground_truth_df: pd.DataFrame):
    """Align a submission with the ground truth on the join keys.

    Args:
        user_df: Uploaded predictions; expected to already contain ``_JOIN_KEYS``
            and ``PRED_COLS`` (checked by ``_validate_submission_columns``).
        ground_truth_df: Ground-truth rows for the selected leaderboard.

    Returns:
        ``(merged_df, None)`` when every ground-truth row is matched by exactly
        one submission row, otherwise ``(None, error_message)``.
    """
    n_truth = len(ground_truth_df)
    n_submitted = len(user_df)
    if n_submitted != n_truth:
        return None, (
            f"⚠️ Submission has {n_submitted:,} rows but ground truth has {n_truth:,} rows. "
            "Ensure your file is derived from the predictions template without adding or removing rows."
        )
    # Matching row counts are not enough on their own. A file that repeats one key
    # and drops another has the right length, and the inner merge would then score
    # the repeated row twice while silently skipping the dropped one.
    n_duplicates = int(user_df.duplicated(subset=_JOIN_KEYS).sum())
    if n_duplicates:
        return None, (
            f"⚠️ Submission contains {n_duplicates:,} duplicate rows (repeated id, source_file, and date). "
            "Ensure each row from the predictions template appears exactly once."
        )
    merged = pd.merge(
        ground_truth_df,
        user_df[_JOIN_KEYS + PRED_COLS],
        on=_JOIN_KEYS,
        how="inner",
    )
    # With unique submission keys each ground-truth row matches at most once,
    # so any shortfall means some ground-truth rows were not covered.
    n_matched = len(merged)
    if n_matched == 0:
        return None, "⚠️ No matching rows found between your submission and the ground truth."
    if n_matched < n_truth:
        missing = n_truth - n_matched
        return None, (
            f"⚠️ {missing:,} ground truth rows could not be matched in your submission. "
            "Check that the id, source_file, and date columns are unmodified from the template."
        )
    return merged, None
def _add_empty_horizon_scores(scores: dict[str, float], horizon: int) -> None:
    """Set every metric for ``horizon`` to NaN, mutating ``scores`` in place.

    Used when a horizon has no rows with both a target and a prediction.
    """
    scores.update({f"{name}_{horizon}": float("nan") for name in METRIC_BASE_COLS})
def _compute_scores(merged_df: pd.DataFrame) -> dict[str, float]:
    """Score the merged predictions separately for each forecast horizon.

    For each horizon ``h``, the ``target_h`` and ``pred_h`` columns are compared
    over rows where neither value is missing. Clinical metrics come from
    ``calculate_dts_error_grid``. RMSE and MAE are stored under
    ``TECHNICAL_COLS[0]`` and ``TECHNICAL_COLS[1]``. A horizon with no usable
    rows gets NaN for every metric.

    Returns:
        A flat dict keyed ``"<metric>_<horizon>"``.
    """
    scores: dict[str, float] = {}
    for horizon in HORIZONS:
        target = merged_df[f"target_{horizon}"].values
        prediction = merged_df[f"pred_{horizon}"].values
        # NOTE(review): rows with a missing prediction are dropped rather than
        # penalised, so a submitter could leave hard rows blank; confirm intent.
        usable = ~(pd.isna(target) | pd.isna(prediction))
        target, prediction = target[usable], prediction[usable]
        if not len(target):
            _add_empty_horizon_scores(scores, horizon)
            continue
        dts = calculate_dts_error_grid(prediction, target)
        scores.update({f"{name}_{horizon}": dts[name] for name in CLINICAL_COLS})
        scores[f"{TECHNICAL_COLS[0]}_{horizon}"] = calculate_rmse(prediction, target)
        scores[f"{TECHNICAL_COLS[1]}_{horizon}"] = calculate_mae(prediction, target)
    return scores
| def _get_remaining_cooldown_seconds(metrics_history_df: pd.DataFrame) -> int | None: | |
| if metrics_history_df.empty: | |
| return None | |
| last_ts = metrics_history_df.iloc[0]["Timestamp"] | |
| if last_ts in (_EMPTY_TEXT_VALUE, "", "nan"): | |
| return None | |
| try: | |
| last_dt = datetime.strptime(last_ts, _TIMESTAMP_FORMAT).replace(tzinfo=timezone.utc) | |
| except ValueError: | |
| return None | |
| elapsed = (datetime.now(timezone.utc) - last_dt).total_seconds() | |
| remaining = int(SUBMISSION_COOLDOWN_SECONDS - elapsed) | |
| return remaining if remaining > 0 else None | |
def _build_submission_summary(username: str, action: str, scores: dict[str, float]) -> str:
    """Format the success message shown after a submission is scored.

    There is one line per horizon, reporting the ``DTS_A_ZONE_PERCENT`` and
    ``RMSE`` scores to one decimal place.
    """
    per_horizon = []
    for horizon in HORIZONS:
        dts_a = scores[f"DTS_A_ZONE_PERCENT_{horizon}"]
        rmse = scores[f"RMSE_{horizon}"]
        per_horizon.append(f" {horizon}min → DTS-A: {dts_a:.1f}%, RMSE: {rmse:.1f}")
    header = f"✅ {action} {username}'s scores:\n"
    return header + "\n".join(per_horizon)
def _append_submission_to_history(
    metrics_history_df: pd.DataFrame,
    scores: dict[str, float],
    timestamp: str,
):
    """Prepend a newly scored submission to a copy of the user's history.

    The input frame is not modified. Earlier entries in the copy have
    ``ACTIVATED_COL`` set to False. The new row is built with
    ``activated=True`` and placed first, with a fresh 0..n-1 index.
    """
    previous_df = metrics_history_df.copy()
    if not previous_df.empty:
        previous_df[ACTIVATED_COL] = False
    latest_df = build_submission_metrics_row(scores, timestamp, activated=True)
    return pd.concat([latest_df, previous_df], ignore_index=True)
def _build_user_metadata_output_path(username: str, context: str) -> Path:
    """Return the local path where the user's metadata CSV is written before upload.

    Takes the file name from the repository path and places it in the
    leaderboard's local entry directory.
    """
    repo_path = Path(get_user_metadata_repo_path(username, context))
    entry_dir = Path(get_leaderboard_entry_dir(context))
    return entry_dir.joinpath(repo_path.name)
def _commit_user_metrics_history(
    username: str,
    metrics_history_df: pd.DataFrame,
    commit_message: str,
    context: str = DEFAULT_LEADERBOARD_CONTEXT,
):
    """Write the user's metrics history locally, then push it to the dataset repo.

    The file is first written into the leaderboard's local entry directory and
    then uploaded as a single commit authenticated with ``HF_TOKEN``.
    Returns nothing; upload errors propagate to the caller.
    """
    local_dir = Path(get_leaderboard_entry_dir(context))
    local_metrics_file = write_user_metrics_history_file(
        output_dir=local_dir,
        username=username,
        metrics_history_df=metrics_history_df,
    )
    hub = HfApi()
    upload = CommitOperationAdd(
        path_in_repo=get_user_metrics_repo_path(username, context),
        path_or_fileobj=str(local_metrics_file),
    )
    hub.create_commit(
        repo_id=DATASET_REPO_ID,
        repo_type="dataset",
        operations=[upload],
        commit_message=commit_message,
        token=HF_TOKEN,
    )
def set_active_submission(
    profile: gr.OAuthProfile | None,
    timestamp: str | None,
    context: str = DEFAULT_LEADERBOARD_CONTEXT,
):
    """Choose which saved submission is visible on the leaderboard, or hide them all.

    Every entry in the user's history is deactivated. When ``timestamp`` is
    non-empty, the first entry with exactly that timestamp is then re-activated.
    An empty or ``None`` timestamp leaves every entry hidden. The updated
    history is committed to the dataset repo.

    Returns:
        A status string for the UI. Failures are reported as strings, not raised.
    """
    if profile is None:
        return "⚠️ Please log in with Hugging Face to manage visible submissions."
    username = profile.username
    leaderboard_name = get_leaderboard_name(context)
    try:
        history_df = get_user_submission_history(username, context)
        if history_df.empty:
            return "⚠️ No saved submissions found."
        history_df = history_df.copy()
        history_df[ACTIVATED_COL] = False
        if not timestamp:
            status_message = f"✅ Your submissions are now hidden from the {leaderboard_name} leaderboard."
            commit_message = f"Hide {leaderboard_name} entries for {username}"
        else:
            hits = history_df.index[history_df["Timestamp"] == timestamp]
            if len(hits) == 0:
                return "⚠️ The selected submission could not be found."
            history_df.loc[hits[0], ACTIVATED_COL] = True
            status_message = f"✅ Showing submission from {timestamp} on the {leaderboard_name} leaderboard."
            commit_message = f"Activate {leaderboard_name} entry for {username}"
        _commit_user_metrics_history(username, history_df, commit_message, context)
        return status_message
    except Exception as e:
        return f"❌ An error occurred: {str(e)}"
def _push_submission_files(username: str, context: str, metadata_path: Path, metrics_path: Path) -> None:
    """Upload the user's metadata CSV and metrics history to the dataset repo in one commit."""
    api = HfApi()
    api.create_commit(
        repo_id=DATASET_REPO_ID,
        repo_type="dataset",
        operations=[
            CommitOperationAdd(
                path_in_repo=get_user_metadata_repo_path(username, context),
                path_or_fileobj=str(metadata_path),
            ),
            CommitOperationAdd(
                path_in_repo=get_user_metrics_repo_path(username, context),
                path_or_fileobj=str(metrics_path),
            ),
        ],
        commit_message=f"Update {get_leaderboard_name(context)} entry for {username}",
        token=HF_TOKEN,
    )
def evaluate_and_submit(
    profile: gr.OAuthProfile | None,
    file_path,
    website: str = "N/A",
    notes: str = "N/A",
    context: str = DEFAULT_LEADERBOARD_CONTEXT,
):
    """Score an uploaded predictions file per horizon and publish it for the user.

    Steps:
    1. Check login, the upload and the optional metadata.
    2. Enforce the submission cooldown.
    3. Align the file with the ground truth and compute scores.
    4. Write the metadata CSV and the updated metrics history locally.
    5. Push both files to the dataset repo in one commit, authenticated with ``HF_TOKEN``.

    Returns:
        A status string for the UI. Failures are reported as strings, not raised.
    """
    if profile is None:
        return "⚠️ Please log in with Hugging Face to submit your results."
    username = profile.username
    if not file_path:
        return "⚠️ Please provide a predictions file."
    website = _normalize_optional_text(website)
    notes = _normalize_optional_text(notes)
    metadata_error = _validate_submission_metadata(website, notes)
    if metadata_error:
        return metadata_error
    try:
        # Enforce the cooldown before reading, merging and scoring the upload, so
        # rate-limited requests are rejected without doing the expensive work.
        metrics_history_df = get_user_submission_history(username, context)
        remaining_cooldown = _get_remaining_cooldown_seconds(metrics_history_df)
        if remaining_cooldown is not None:
            return f"⚠️ Please wait {remaining_cooldown}s before submitting again."
        user_df = _load_submission(file_path)
        column_error = _validate_submission_columns(user_df)
        if column_error:
            return column_error
        ground_truth_df = get_ground_truth(context)
        merged_df, merge_error = _merge_submission_with_ground_truth(user_df, ground_truth_df)
        if merge_error:
            return merge_error
        scores = _compute_scores(merged_df)
        action = "Updated" if not metrics_history_df.empty else "Added"
        timestamp = datetime.now(timezone.utc).strftime(_TIMESTAMP_FORMAT)
        metadata_path = _build_user_metadata_output_path(username, context)
        metadata_path.parent.mkdir(parents=True, exist_ok=True)
        build_user_metadata_df(website, notes).to_csv(metadata_path, index=False)
        updated_metrics_history_df = _append_submission_to_history(metrics_history_df, scores, timestamp)
        metrics_path = write_user_metrics_history_file(
            output_dir=Path(get_leaderboard_entry_dir(context)),
            username=username,
            metrics_history_df=updated_metrics_history_df,
        )
        _push_submission_files(username, context, metadata_path, metrics_path)
        return _build_submission_summary(username, action, scores)
    except Exception as e:
        return f"❌ An error occurred: {str(e)}"