| from __future__ import annotations |
|
|
| import json |
| import os |
| import time |
| from collections import defaultdict |
| from dataclasses import dataclass |
| from functools import lru_cache |
| from pathlib import Path |
| from typing import Any |
|
|
# Pick a writable temp directory for Gradio BEFORE gradio is imported (the
# import below reads GRADIO_TEMP_DIR). Candidates are tried in order: a folder
# next to this file, the current working directory, then /tmp/gradio.
if "GRADIO_TEMP_DIR" not in os.environ:
    for candidate in (
        Path(__file__).resolve().parent / ".gradio_tmp",
        Path.cwd() / ".gradio_tmp",
        Path("/tmp") / "gradio",
    ):
        try:
            candidate.mkdir(parents=True, exist_ok=True)
            # Probe an actual write: mkdir succeeding does not guarantee the
            # directory is writable (e.g. read-only mounts).
            probe = candidate / ".write_probe"
            probe.write_text("ok", encoding="utf-8")
            probe.unlink()
            os.environ["GRADIO_TEMP_DIR"] = str(candidate)
            break
        except OSError:
            continue  # fall through to the next candidate
|
|
| import gradio as gr |
| import pandas as pd |
| from huggingface_hub import hf_hub_download |
|
|
|
|
# Ground-truth source defaults; each can be overridden by an environment
# variable of the same name prefixed with MMOU_GT_ (see
# materialize_ground_truth_file). A local path takes precedence over the Hub.
DEFAULT_GT_LOCAL_PATH = ""
DEFAULT_GT_REPO_ID = "nvidia/mmou-gt"
DEFAULT_GT_FILENAME = "MMOU.json"
DEFAULT_GT_REPO_TYPE = "dataset"
DEFAULT_GT_TOKEN_ENV = "HF_TOKEN"


# Display order for the domain breakdown table (unknown domains sort last).
DOMAINS_ORDER = [
    "Sports",
    "Travel",
    "Video Games",
    "Daily Life",
    "Academic Lectures",
    "Film",
    "Pranks",
    "Music",
    "Animation",
    "News",
]
# Duration buckets in minutes; labels must match duration_bucket's outputs,
# plus the synthetic "Overall" row added in evaluate_submission.
DURATION_BUCKET_ORDER = ["< 5", "5–10", "10–20", "20–30", "> 30", "Overall"]
# Candidate ground-truth record keys, tried in order, for each extracted field.
GT_LETTER_KEYS = (
    "correct_option_letter",
    "correct_answer_letter",
    "label",
    "gold_label",
    "answer_letter",
)
GT_DOMAIN_KEYS = ("domain", "category")
# Duration values are treated as seconds (evaluate_submission divides by 60).
GT_DURATION_KEYS = ("video_duration", "video_duration_sec", "duration", "duration_sec")
GT_SKILL_KEYS = ("question_type", "skills", "skill", "question_types")
# Valid multiple-choice answer letters.
OPTION_LETTERS = set("ABCDEFGHIJ")
|
|
# Markdown rendered at the top of the app.
APP_INTRO = """
# MMOU Evaluator

Upload a `.json` or `.jsonl` file where each entry contains `question_id` and `answer`.
"""


# Submission-format help shown beneath the summary panel.
FORMAT_GUIDE = """
### Submission Format

Each entry must contain:

- `question_id`
- `answer`

`answer` must be a single letter from `A` to `J`. Letter matching is case-insensitive. Extra keys are ignored.
Rows with empty or `null` answers are ignored.

Example JSON:

```json
[
  {"question_id": "54aaef4d-2c22-476e-a7e7-37efabde2520", "answer": "C"},
  {"question_id": "a7f8790d-7828-4ece-a63a-a5d13edf9026", "answer": "B"}
]
```

Example JSONL:

```json
{"question_id": "54aaef4d-2c22-476e-a7e7-37efabde2520", "answer": "C"}
{"question_id": "a7f8790d-7828-4ece-a63a-a5d13edf9026", "answer": "B"}
```
"""


# Initial/reset contents of the status and summary panels.
READY_STATUS_MARKDOWN = "### Ready\nUpload a prediction file and click `Evaluate`."
EMPTY_SUMMARY_MARKDOWN = """
### Summary

Run an evaluation to populate the aggregate summary.
"""


# Custom CSS: width-capped, centred container with a uniform base font size.
LAYOUT_CSS = """
.gradio-container {
  max-width: 1100px !important;
  margin: 0 auto !important;
  padding-left: 1rem !important;
  padding-right: 1rem !important;
  font-size: 16px !important;
}

.gradio-container .prose,
.gradio-container .gr-markdown,
.gradio-container .gr-dataframe,
.gradio-container label,
.gradio-container button,
.gradio-container input,
.gradio-container textarea {
  font-size: 1rem !important;
}
"""
|
|
|
|
@dataclass(frozen=True)
class GroundTruthEntry:
    """One scored question parsed from the ground-truth file."""

    # Gold answer letter, upper-case, one of A–J.
    correct_letter: str
    # Domain/category label; "Unknown" when the source record had none.
    domain: str
    # Video length in seconds, or None when the source record had none.
    video_duration_sec: float | None
    # Deduplicated skill labels (may be empty).
    skills: tuple[str, ...]
|
|
|
|
def stringify(value: Any) -> str:
    """Coerce an arbitrary JSON-ish value into a display string.

    Strings are whitespace-trimmed, ``None`` becomes the empty string,
    scalars go through ``str``, and anything else falls back to a JSON dump.
    """
    if isinstance(value, str):
        return value.strip()
    if value is None:
        return ""
    if isinstance(value, (bool, int, float)):
        return str(value)
    return json.dumps(value, ensure_ascii=True)
|
|
|
|
def coerce_float(value: Any) -> float | None:
    """Best-effort conversion of *value* to ``float``; ``None`` when impossible."""
    if value is None or value == "":
        return None
    if isinstance(value, str):
        try:
            return float(value.strip())
        except ValueError:
            return None
    if isinstance(value, (int, float)):
        return float(value)
    return None
|
|
|
|
def first_present(record: dict[str, Any], keys: tuple[str, ...]) -> Any:
    """Return the value of the first key whose value is not None, "", or []."""
    for key in keys:
        candidate = record.get(key)
        if candidate not in (None, "", []):
            return candidate
    return None
|
|
|
|
def parse_skill_list(value: Any) -> tuple[str, ...]:
    """Normalise a skill field (scalar, list, or None) into a deduplicated tuple."""
    if value is None:
        raw: list[Any] = []
    elif isinstance(value, list):
        raw = value
    else:
        raw = [value]
    # Dict keys preserve insertion order, giving order-stable deduplication.
    ordered: dict[str, None] = {}
    for entry in raw:
        label = stringify(entry).strip().strip("\"'")
        if label:
            ordered.setdefault(label, None)
    return tuple(ordered)
|
|
|
|
def safe_pct(correct: int, total: int) -> float:
    """Percentage of *correct* over *total*, returning 0.0 for an empty total."""
    if not total:
        return 0.0
    return 100.0 * correct / total
|
|
|
|
def duration_bucket(minutes: float) -> str:
    """Map a duration in minutes onto its reporting bucket label."""
    # Upper bounds are exclusive: exactly 5 minutes lands in "5–10", etc.
    thresholds = ((5, "< 5"), (10, "5–10"), (20, "10–20"), (30, "20–30"))
    for upper, label in thresholds:
        if minutes < upper:
            return label
    return "> 30"
|
|
|
|
def normalize_answer(value: Any) -> str:
    """Validate a submitted answer and return its canonical upper-case letter.

    Empty values yield "" (callers skip such rows); any other value must be a
    single letter A–J or a ValueError is raised. The letter set contains only
    single characters, so membership alone also rejects multi-char strings.
    """
    letter = stringify(value).upper()
    if letter == "":
        return ""
    if letter in OPTION_LETTERS:
        return letter
    raise ValueError("Each `answer` must be a single letter from A to J.")
|
|
|
|
def load_records(path: str | Path, *, allow_data_key: bool = False) -> tuple[list[dict[str, Any]], str]:
    """Load records from a JSON or JSONL file.

    Args:
        path: File to read; a ``.jsonl``/``.ndjson`` suffix selects
            line-delimited parsing, anything else is parsed as one JSON value.
        allow_data_key: Also accept a ``{"data": [...]}`` wrapper object
            (used for the ground-truth file).

    Returns:
        ``(records, fmt)`` where ``fmt`` is ``"json"`` or ``"jsonl"``.

    Raises:
        ValueError: On malformed JSON (with the offending line number for
            JSONL input) or when the payload is not a list of objects.
    """
    file_path = Path(path)
    suffix = file_path.suffix.lower()

    if suffix in {".jsonl", ".ndjson"}:
        records: list[dict[str, Any]] = []
        with file_path.open("r", encoding="utf-8") as handle:
            for line_number, line in enumerate(handle, start=1):
                if not line.strip():
                    continue  # tolerate blank lines between records
                try:
                    record = json.loads(line)
                except json.JSONDecodeError as exc:
                    # Include the line number so users can locate the bad row;
                    # JSONDecodeError is a ValueError subclass, so callers
                    # catching ValueError are unaffected by the re-raise.
                    raise ValueError(
                        f"Line {line_number} in JSONL is not valid JSON: {exc.msg}"
                    ) from exc
                if not isinstance(record, dict):
                    raise ValueError(f"Line {line_number} in JSONL must be an object.")
                records.append(record)
        return records, "jsonl"

    with file_path.open("r", encoding="utf-8") as handle:
        payload = json.load(handle)

    if isinstance(payload, list):
        records = payload
    elif allow_data_key and isinstance(payload, dict) and isinstance(payload.get("data"), list):
        records = payload["data"]
    else:
        raise ValueError("JSON file must contain a list of objects.")

    if not all(isinstance(item, dict) for item in records):
        raise ValueError("JSON file must contain only objects.")

    return records, "json"
|
|
|
|
def materialize_ground_truth_file() -> Path:
    """Locate the private ground-truth file, preferring local over the Hub.

    Resolution order:
      1. ``MMOU_GT_PATH`` — a mounted local file (must exist).
      2. ``MMOU_GT_REPO_ID`` / ``MMOU_GT_FILENAME`` — downloaded via
         huggingface_hub, authenticating with the token named by
         ``MMOU_GT_TOKEN_ENV`` (falling back to ``HF_TOKEN``).

    Raises when neither source is configured.
    """
    configured_path = os.getenv("MMOU_GT_PATH", DEFAULT_GT_LOCAL_PATH).strip()
    if configured_path:
        candidate = Path(configured_path)
        if candidate.exists():
            return candidate
        raise FileNotFoundError(
            "MMOU_GT_PATH is set, but the file does not exist. "
            "Update the configured path or mount the private file correctly."
        )

    repo_id = os.getenv("MMOU_GT_REPO_ID", DEFAULT_GT_REPO_ID).strip()
    filename = os.getenv("MMOU_GT_FILENAME", DEFAULT_GT_FILENAME).strip()
    if not (repo_id and filename):
        raise RuntimeError(
            "Ground truth is not configured. Set MMOU_GT_PATH or "
            "MMOU_GT_REPO_ID/MMOU_GT_FILENAME before launching the app."
        )

    repo_type = os.getenv("MMOU_GT_REPO_TYPE", DEFAULT_GT_REPO_TYPE).strip() or "dataset"
    token_env = os.getenv("MMOU_GT_TOKEN_ENV", DEFAULT_GT_TOKEN_ENV).strip() or "HF_TOKEN"
    token = os.getenv(token_env) or os.getenv("HF_TOKEN", "")
    downloaded = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        repo_type=repo_type,
        token=token or None,
    )
    return Path(downloaded)
|
|
|
|
@lru_cache(maxsize=1)
def load_ground_truth() -> dict[str, GroundTruthEntry]:
    """Parse the ground-truth file into a question_id -> entry map (cached).

    Records without a usable question_id or a recognisable gold letter are
    skipped; raises if nothing usable remains.
    """
    records, _ = load_records(materialize_ground_truth_file(), allow_data_key=True)
    entries: dict[str, GroundTruthEntry] = {}

    for record in records:
        question_id = stringify(record.get("question_id"))
        if not question_id:
            continue

        # Take the first candidate key that yields a valid answer letter.
        correct_letter = ""
        for key in GT_LETTER_KEYS:
            candidate = stringify(record.get(key)).upper()
            if candidate in OPTION_LETTERS:
                correct_letter = candidate
                break
        if not correct_letter:
            continue  # unscorable without a gold letter

        entries[question_id] = GroundTruthEntry(
            correct_letter=correct_letter,
            domain=stringify(first_present(record, GT_DOMAIN_KEYS)) or "Unknown",
            video_duration_sec=coerce_float(first_present(record, GT_DURATION_KEYS)),
            skills=parse_skill_list(first_present(record, GT_SKILL_KEYS)),
        )

    if not entries:
        raise RuntimeError("No usable ground-truth question IDs were found.")

    return entries
|
|
|
|
def build_prediction_map(records: list[dict[str, Any]]) -> tuple[dict[str, str], int, int]:
    """Turn raw submission rows into a question_id -> letter map.

    Returns the map plus the number of duplicate IDs (the last occurrence
    wins) and the number of rows skipped for having an empty/null answer.
    A missing question_id is a hard error, reported with its row number.
    """
    predictions: dict[str, str] = {}
    duplicates = 0
    skipped_empty_answers = 0

    for row_number, record in enumerate(records, start=1):
        question_id = stringify(record.get("question_id"))
        if not question_id:
            raise ValueError(f"Row {row_number} is missing `question_id`.")
        answer = normalize_answer(record.get("answer"))
        if answer == "":
            skipped_empty_answers += 1
            continue
        if question_id in predictions:
            duplicates += 1
        predictions[question_id] = answer

    return predictions, duplicates, skipped_empty_answers
|
|
|
|
def bump(stats: dict[str, dict[str, int]], keys: list[str], field: str) -> None:
    """Increment ``field`` by one in every stats bucket named in ``keys``."""
    for bucket in keys:
        stats[bucket][field] += 1
|
|
|
|
def make_breakdown_dataframe(
    stats: dict[str, dict[str, int]],
    label: str,
    ordered_labels: list[str] | None = None,
) -> pd.DataFrame:
    """Build an accuracy/coverage table from per-bucket counters.

    Rows follow ``ordered_labels`` when given (unknown names sort last,
    alphabetically); otherwise they are ordered by answered accuracy and
    then bucket size, both descending.
    """
    columns = [
        label,
        "Official Accuracy (%)",
        "Answered Accuracy (%)",
        "Coverage (%)",
        "Correct",
        "Answered",
        "Total",
    ]

    rows = []
    for name, counts in stats.items():
        rows.append(
            {
                label: name,
                "Official Accuracy (%)": round(safe_pct(counts["correct"], counts["total"]), 2),
                "Answered Accuracy (%)": round(safe_pct(counts["correct"], counts["answered"]), 2),
                "Coverage (%)": round(safe_pct(counts["answered"], counts["total"]), 2),
                "Correct": counts["correct"],
                "Answered": counts["answered"],
                "Total": counts["total"],
            }
        )

    if not rows:
        # Preserve the column headers even when there is nothing to show.
        return pd.DataFrame(columns=columns)

    frame = pd.DataFrame(rows)
    if ordered_labels:
        rank = {name: position for position, name in enumerate(ordered_labels)}
        frame["_rank"] = frame[label].map(lambda name: rank.get(name, len(rank)))
        frame = frame.sort_values(["_rank", label]).drop(columns="_rank")
    else:
        frame = frame.sort_values(["Answered Accuracy (%)", "Total"], ascending=[False, False])
    return frame.reset_index(drop=True)
|
|
|
|
def build_metrics_markdown(summary: dict[str, Any]) -> str:
    """Format the aggregate metrics dict as a Markdown bullet list."""
    lines = ["### Metrics"]
    lines.append(
        f"- Official accuracy: `{summary['official_accuracy_pct']:.2f}%` "
        f"(`{summary['correct']} / {summary['total_ground_truth']}`)"
    )
    lines.append(
        f"- Answered accuracy: `{summary['answered_accuracy_pct']:.2f}%` "
        f"(`{summary['correct']} / {summary['answered_predictions']}`)"
    )
    lines.append(f"- Coverage: `{summary['coverage_pct']:.2f}%`")
    lines.append(f"- Matched IDs: `{summary['matched_prediction_ids']}`")
    lines.append(f"- Missing IDs: `{summary['missing_prediction_ids']}`")
    lines.append(f"- Extra IDs: `{summary['extra_prediction_ids']}`")
    lines.append(f"- Duplicate IDs: `{summary['duplicate_prediction_ids']}`")
    return "\n".join(lines)
|
|
|
|
def build_summary_markdown(domain_df: pd.DataFrame, duration_df: pd.DataFrame, skill_df: pd.DataFrame) -> str:
    """Summarise the breakdown tables: best domain and duration bucket by
    answered accuracy, plus the weakest skill bucket."""
    accuracy_column = "Answered Accuracy (%)"

    def top_row(frame: pd.DataFrame, ascending: bool) -> pd.Series:
        # Ties on accuracy are broken by bucket size, largest first.
        return frame.sort_values([accuracy_column, "Total"], ascending=[ascending, False]).iloc[0]

    best_domain = "n/a"
    if not domain_df.empty:
        winner = top_row(domain_df, ascending=False)
        best_domain = f"{winner['Domain']} ({winner[accuracy_column]:.2f}%)"

    best_duration = "n/a"
    if not duration_df.empty:
        # The synthetic "Overall" row would always win; exclude it.
        buckets = duration_df[duration_df["Duration Bucket"] != "Overall"]
        if not buckets.empty:
            winner = top_row(buckets, ascending=False)
            best_duration = f"{winner['Duration Bucket']} ({winner[accuracy_column]:.2f}%)"

    lowest_skill = "n/a"
    if not skill_df.empty:
        # Prefer reasonably populated skills; fall back to everything.
        sizable = skill_df[skill_df["Total"] >= 10]
        pool = skill_df if sizable.empty else sizable
        weakest = top_row(pool, ascending=True)
        lowest_skill = f"{weakest['Skill']} ({weakest[accuracy_column]:.2f}%)"

    return "\n".join(
        [
            "### Summary",
            f"- Best domain by answered accuracy: `{best_domain}`",
            f"- Best duration bucket by answered accuracy: `{best_duration}`",
            f"- Lowest skill bucket by answered accuracy: `{lowest_skill}`",
        ]
    )
|
|
|
|
def empty_result(status: str) -> tuple[str, str, str, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Outputs for a blank/failed run: a status message plus empty panels and tables."""
    empty_tables = (pd.DataFrame(), pd.DataFrame(), pd.DataFrame())
    return (status, "", EMPTY_SUMMARY_MARKDOWN, *empty_tables)
|
|
|
|
def evaluate_submission(
    prediction_file: str | None,
) -> tuple[str, str, str, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Score an uploaded prediction file against the cached ground truth.

    Returns ``(status_md, metrics_md, summary_md, domain_df, duration_df,
    skill_df)`` for the Gradio outputs. This is the click-handler boundary:
    any failure is reported through the status panel instead of raising.
    """
    if not prediction_file:
        return empty_result(
            "### Upload required\nPlease upload a `.json` or `.jsonl` prediction file before evaluating."
        )


    started_at = time.time()


    try:
        ground_truth = load_ground_truth()
        records, file_format = load_records(prediction_file)
        if not records:
            raise ValueError("No valid prediction records were found in the uploaded file.")


        predictions, duplicate_prediction_ids, skipped_empty_answers = build_prediction_map(records)
        # Per-bucket counters (correct/answered/total) for each grouping axis.
        domain_stats: dict[str, dict[str, int]] = defaultdict(lambda: {"correct": 0, "answered": 0, "total": 0})
        duration_stats: dict[str, dict[str, int]] = defaultdict(lambda: {"correct": 0, "answered": 0, "total": 0})
        skill_stats: dict[str, dict[str, int]] = defaultdict(lambda: {"correct": 0, "answered": 0, "total": 0})


        correct = 0
        answered = 0
        gt_ids = set(ground_truth)
        pred_ids = set(predictions)


        # Iterate over ground truth (not predictions) so unanswered questions
        # still count toward "total" in every breakdown.
        for question_id, gt in ground_truth.items():
            # Duration bucketing only applies when a duration is known (seconds -> minutes).
            duration_key = duration_bucket(gt.video_duration_sec / 60.0) if gt.video_duration_sec is not None else None
            scopes = [
                (domain_stats, [gt.domain]),
                (duration_stats, [duration_key] if duration_key else []),
                (skill_stats, list(gt.skills)),
            ]


            for stats, keys in scopes:
                bump(stats, keys, "total")


            answer = predictions.get(question_id)
            if not answer:
                continue


            answered += 1
            for stats, keys in scopes:
                bump(stats, keys, "answered")


            if answer == gt.correct_letter:
                correct += 1
                for stats, keys in scopes:
                    bump(stats, keys, "correct")


        total_ground_truth = len(ground_truth)
        # Synthetic aggregate row so the duration tab also shows overall numbers.
        duration_stats["Overall"] = {"total": total_ground_truth, "answered": answered, "correct": correct}


        summary = {
            "correct": correct,
            "answered_predictions": answered,
            "total_ground_truth": total_ground_truth,
            # Official accuracy penalises unanswered questions; answered accuracy does not.
            "official_accuracy_pct": safe_pct(correct, total_ground_truth),
            "answered_accuracy_pct": safe_pct(correct, answered),
            "coverage_pct": safe_pct(answered, total_ground_truth),
            "matched_prediction_ids": len(pred_ids & gt_ids),
            "missing_prediction_ids": total_ground_truth - len(pred_ids & gt_ids),
            "extra_prediction_ids": len(pred_ids - gt_ids),
            "duplicate_prediction_ids": duplicate_prediction_ids,
        }


        domain_df = make_breakdown_dataframe(domain_stats, "Domain", ordered_labels=DOMAINS_ORDER)
        duration_df = make_breakdown_dataframe(
            duration_stats,
            "Duration Bucket",
            ordered_labels=DURATION_BUCKET_ORDER,
        )
        skill_df = make_breakdown_dataframe(skill_stats, "Skill")


        status_markdown = (
            "### Evaluation complete\n"
            f"- Parsed file format: `{file_format}`\n"
            f"- Uploaded rows: `{len(records)}`\n"
            f"- Skipped empty answers: `{skipped_empty_answers}`\n"
            f"- Evaluation time: `{time.time() - started_at:.2f}s`"
        )
        return (
            status_markdown,
            build_metrics_markdown(summary),
            build_summary_markdown(domain_df, duration_df, skill_df),
            domain_df,
            duration_df,
            skill_df,
        )


    except Exception as exc:
        # UI boundary: surface the failure in the status panel rather than crashing.
        return empty_result(f"### Evaluation failed\n`{type(exc).__name__}: {exc}`")
|
|
|
|
def clear_outputs() -> tuple[None, str, str, str, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Reset the file input and every output panel/table to its initial state."""
    empty_tables = (pd.DataFrame(), pd.DataFrame(), pd.DataFrame())
    return (None, READY_STATUS_MARKDOWN, "", EMPTY_SUMMARY_MARKDOWN, *empty_tables)
|
|
|
|
# UI definition. Theme and custom CSS are passed to the gr.Blocks constructor —
# that is where Gradio accepts them (Blocks.launch() has no theme/css kwargs).
with gr.Blocks(
    title="MMOU Evaluator",
    fill_width=False,
    theme=gr.themes.Default(),
    css=LAYOUT_CSS,
) as demo:
    gr.Markdown(APP_INTRO)


    # type="filepath" hands the click handler a filesystem path string.
    prediction_file = gr.File(label="Upload prediction file", file_types=[".json", ".jsonl"], type="filepath")


    with gr.Row():
        evaluate_button = gr.Button("Evaluate", variant="primary")
        clear_button = gr.Button("Clear")


    status_markdown = gr.Markdown(READY_STATUS_MARKDOWN)
    metrics_markdown = gr.Markdown("")
    summary_markdown = gr.Markdown(EMPTY_SUMMARY_MARKDOWN)
    gr.Markdown(FORMAT_GUIDE)


    with gr.Tabs():
        with gr.Tab("Domain Breakdown"):
            domain_dataframe = gr.Dataframe(label="Domain breakdown", interactive=False, wrap=True)
        with gr.Tab("Duration Breakdown"):
            duration_dataframe = gr.Dataframe(label="Duration breakdown", interactive=False, wrap=True)
        with gr.Tab("Skill Breakdown"):
            skill_dataframe = gr.Dataframe(label="Skill breakdown", interactive=False, wrap=True)


    # Outputs are positional: they must match evaluate_submission's return tuple.
    evaluate_button.click(
        fn=evaluate_submission,
        inputs=[prediction_file],
        outputs=[
            status_markdown,
            metrics_markdown,
            summary_markdown,
            domain_dataframe,
            duration_dataframe,
            skill_dataframe,
        ],
    )
    # clear_outputs additionally resets the file input itself (first output).
    clear_button.click(
        fn=clear_outputs,
        outputs=[
            prediction_file,
            status_markdown,
            metrics_markdown,
            summary_markdown,
            domain_dataframe,
            duration_dataframe,
            skill_dataframe,
        ],
    )
|
|
|
|
if __name__ == "__main__":
    # Bug fix: Blocks.launch() does not accept `theme`/`css` keyword arguments
    # (they are gr.Blocks constructor parameters), so passing them here raised
    # a TypeError at startup. Launch with defaults.
    demo.launch()
|
|