import json
import logging
from pathlib import Path
from typing import Any, Dict

import numpy as np

logger = logging.getLogger(__name__)

GT_DIR = Path("/data") / "gt"
K_VALUES = [1, 5, 10, 20, 50, 100]

class Evaluator:
    """Scores retrieval submissions against per-album ground-truth files."""

    def __init__(self, gt_dir: str | Path | None = None):
        self.gt_dir = Path(gt_dir) if gt_dir else GT_DIR
        self._gt_cache: Dict[str, list] = {}

    def _load_gt(self, album_id: str) -> list:
        """Load (and cache) the ground-truth entries for one album."""
        if album_id in self._gt_cache:
            return self._gt_cache[album_id]
        gt_file = self.gt_dir / f"album{album_id}_test_answer.json"
        if not gt_file.exists():
            raise FileNotFoundError(f"Ground truth file not found: {gt_file}")
        with open(gt_file, "r", encoding="utf-8") as f:
            data = json.load(f)
        self._gt_cache[album_id] = data
        return data
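
    # Each ground-truth file is assumed (from how _evaluate_album reads it) to
    # be a JSON array of objects shaped roughly like:
    #   {"query_en": "...", "ground_truth": ["<item id>", ...], "Source": "..."}
    # Field names beyond these three are not relied on here.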

    def validate_json_format(self, data: Any) -> list[str]:
        """Return a list of human-readable schema errors; empty means valid."""
        errors = []
        if not isinstance(data, list):
            return ["Root must be a JSON array"]
        if len(data) == 0:
            return ["Submission is empty"]
        for i, item in enumerate(data):
            if not isinstance(item, dict):
                errors.append(f"Item #{i} must be an object")
                continue
            if "album_id" not in item or str(item["album_id"]) not in ["1", "2", "3"]:
                errors.append(f"Item #{i} 'album_id' must be '1', '2', or '3'")
            if "query_en" not in item or not isinstance(item["query_en"], str):
                errors.append(f"Item #{i} 'query_en' must be a string")
            if (
                "pred" not in item
                or not isinstance(item["pred"], list)
                or not all(isinstance(x, str) for x in item["pred"])
            ):
                errors.append(f"Item #{i} 'pred' must be a list of strings")
        return errors
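
    # A submission item that passes the checks above looks like this (values
    # are illustrative only, not drawn from any real album):
    #   {"album_id": "1", "query_en": "people at the beach",
    #    "pred": ["photo_012.jpg", "photo_345.jpg"]}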

    def _dcg_at_k(self, r, k):
        """DCG@k = sum over the first k scores of r_i / log2(i + 1), 1-indexed."""
        r = np.asarray(r, dtype=float)[:k]
        if r.size:
            return np.sum(r / np.log2(np.arange(2, r.size + 2)))
        return 0.0

    def _ndcg_at_k(self, r, k):
        """NDCG@k = DCG@k / IDCG@k, where IDCG sorts r in descending order."""
        dcg_max = self._dcg_at_k(sorted(r, reverse=True), k)
        if not dcg_max:
            return 0.0
        return self._dcg_at_k(r, k) / dcg_max
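
    # Worked example for the two helpers above, with relevance r = [1, 0, 1]:
    #   DCG@3  = 1/log2(2) + 0/log2(3) + 1/log2(4) = 1.0 + 0.0 + 0.5 = 1.5
    #   IDCG@3 = 1/log2(2) + 1/log2(3) ≈ 1.6309
    #   NDCG@3 = 1.5 / 1.6309 ≈ 0.92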

    def _recall_at_k(self, ground_truth, predictions, k):
        """Fraction of ground-truth items found in the top-k predictions."""
        if len(ground_truth) == 0:
            return 0.0
        hits = len(set(ground_truth) & set(predictions[:k]))
        return hits / len(ground_truth)
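
    # Example: with ground_truth = ["a", "b", "c"] and
    # predictions = ["a", "x", "c", "b"]:
    #   Recall@2 = |{"a"}| / 3 ≈ 0.33 and Recall@4 = 3 / 3 = 1.0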

    def _evaluate_album(self, album_submissions: dict, album_id: str) -> dict:
        """Evaluate one album: returns global metrics, per-source metrics,
        and coverage counters for the queries submitted for that album."""
        gt_data = self._load_gt(album_id)
        gt_map = {item["query_en"]: item for item in gt_data}

        metrics_accum = {f"Recall@{k}": [] for k in K_VALUES}
        metrics_accum.update({f"NDCG@{k}": [] for k in K_VALUES})
        metrics_accum["Recall"] = []
        metrics_accum["NDCG"] = []
        source_accum = {}

        empty_gt_queries = 0
        evaluated_queries = 0
        extraneous_queries = 0

        for q, pred in album_submissions.items():
            if q not in gt_map:
                extraneous_queries += 1
                continue
            gt_item = gt_map[q]
            gt_answers = gt_item.get("ground_truth", [])
            source = gt_item.get("Source")
            evaluated_queries += 1
            if not gt_answers:
                empty_gt_queries += 1
                continue

            # Binary relevance of each prediction; the ideal ranking puts all
            # ground-truth answers first, so IDCG uses a run of 1s.
            r = [1 if p in gt_answers else 0 for p in pred]
            dcg_r = [1.0] * len(gt_answers)

            m = {}
            for k in K_VALUES:
                m[f"Recall@{k}"] = self._recall_at_k(gt_answers, pred, k)
                idcg = self._dcg_at_k(dcg_r, k)
                m[f"NDCG@{k}"] = self._dcg_at_k(r, k) / idcg if idcg > 0 else 0.0
                metrics_accum[f"Recall@{k}"].append(m[f"Recall@{k}"])
                metrics_accum[f"NDCG@{k}"].append(m[f"NDCG@{k}"])

            # Un-truncated variants over the full prediction list.
            m["Recall"] = sum(r) / len(gt_answers)
            idcg_all = self._dcg_at_k(dcg_r, len(gt_answers))
            m["NDCG"] = self._dcg_at_k(r, len(r)) / idcg_all if idcg_all > 0 else 0.0
            metrics_accum["Recall"].append(m["Recall"])
            metrics_accum["NDCG"].append(m["NDCG"])

            if source is not None:
                if source not in source_accum:
                    source_accum[source] = {f"Recall@{_k}": [] for _k in K_VALUES}
                    source_accum[source].update({f"NDCG@{_k}": [] for _k in K_VALUES})
                    source_accum[source]["Recall"] = []
                    source_accum[source]["NDCG"] = []
                for k in K_VALUES:
                    source_accum[source][f"Recall@{k}"].append(m[f"Recall@{k}"])
                    source_accum[source][f"NDCG@{k}"].append(m[f"NDCG@{k}"])
                source_accum[source]["Recall"].append(m["Recall"])
                source_accum[source]["NDCG"].append(m["NDCG"])

        global_metrics = {
            k: float(np.mean(v)) if v else 0.0 for k, v in metrics_accum.items()
        }
        return {
            "global_metrics": global_metrics,
            "source_metrics": {
                src: {k: float(np.mean(v)) if v else 0.0 for k, v in m_dict.items()}
                for src, m_dict in source_accum.items()
            },
            "empty_gt_ratio": (
                empty_gt_queries / evaluated_queries if evaluated_queries > 0 else 0.0
            ),
            "evaluated_queries": evaluated_queries,
            "total_gt_queries": len(gt_data),
            "is_partial": evaluated_queries < len(gt_data),
            "extraneous_queries": extraneous_queries,
        }

    def evaluate(self, submission_data: list) -> dict:
        """Group a validated submission by album, score each album, and
        macro-average the per-album metrics."""
        albums = {}
        for item in submission_data:
            a_id = str(item["album_id"])
            if a_id not in albums:
                albums[a_id] = {}
            albums[a_id][item["query_en"]] = item["pred"]
        if not albums:
            raise ValueError("No valid albums found in submission.")

        # Evaluate each album separately.
        per_album = {}
        for a_id in sorted(albums.keys()):
            per_album[a_id] = self._evaluate_album(albums[a_id], a_id)

        # Macro-average each metric across the submitted albums.
        first_metrics = next(iter(per_album.values()))["global_metrics"]
        avg_metrics = {}
        for metric_key in first_metrics:
            values = [
                alb["global_metrics"][metric_key]
                for alb in per_album.values()
                if metric_key in alb["global_metrics"]
            ]
            avg_metrics[metric_key] = float(np.mean(values)) if values else 0.0

        total_evaluated = sum(alb["evaluated_queries"] for alb in per_album.values())
        total_gt = sum(alb["total_gt_queries"] for alb in per_album.values())
        total_extraneous = sum(
            alb.get("extraneous_queries", 0) for alb in per_album.values()
        )

        # total_gt only counts albums that were actually submitted, so an
        # entirely missing album must be flagged as partial separately.
        missing_albums = [a for a in ["1", "2", "3"] if a not in albums]
        result = {
            "per_album": per_album,
            "global_metrics": avg_metrics,
            "evaluated_queries": total_evaluated,
            "total_gt_queries": total_gt,
            "is_partial": total_evaluated < total_gt or bool(missing_albums),
            "albums": sorted(albums.keys()),
            "extraneous_queries": total_extraneous,
        }

        # Build warning / notice messages.
        msgs = []
        if total_extraneous > 0:
            msgs.append(
                f"{total_extraneous} extraneous queries were ignored (not in the "
                f"current ground truth). This may be caused by an outdated "
                f"test.json or by extra queries. "
                f"Valid queries: {total_evaluated}/{total_gt}."
            )
        if result["is_partial"]:
            missing_queries = total_gt - total_evaluated
            parts = []
            if missing_albums:
                parts.append(f"Missing albums: {', '.join(missing_albums)}")
            if missing_queries > 0:
                parts.append(
                    f"Missing {missing_queries} queries "
                    f"({total_evaluated}/{total_gt} submitted)"
                )
            msgs.append(
                "Submission incomplete. "
                + "; ".join(parts)
                + ". Only full submissions are eligible for leaderboard ranking."
            )
        if msgs:
            result["warning"] = " ".join(msgs)
        return result
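

if __name__ == "__main__":
    # Minimal usage sketch. "submission.json" is a hypothetical path and this
    # entry point is not part of the class contract above; real callers should
    # substitute their own submission file and ground-truth directory.
    logging.basicConfig(level=logging.INFO)
    evaluator = Evaluator()
    with open("submission.json", "r", encoding="utf-8") as f:
        submission = json.load(f)
    format_errors = evaluator.validate_json_format(submission)
    if format_errors:
        raise SystemExit("Invalid submission: " + "; ".join(format_errors))
    report = evaluator.evaluate(submission)
    print(json.dumps(report["global_metrics"], indent=2))
    if "warning" in report:
        logger.warning(report["warning"])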