import json
import logging
from pathlib import Path
from typing import Any, Dict

import numpy as np

logger = logging.getLogger(__name__)

GT_DIR = Path("/data") / "gt"
K_VALUES = [1, 5, 10, 20, 50, 100]


class Evaluator:
    """Scores retrieval submissions against per-album ground-truth files."""

    def __init__(self, gt_dir: str | Path | None = None):
        self.gt_dir = Path(gt_dir) if gt_dir else GT_DIR
        self._gt_cache: Dict[str, list] = {}

    def _load_gt(self, album_id: str) -> list:
        """Load (and cache) the ground-truth file for one album."""
        if album_id in self._gt_cache:
            return self._gt_cache[album_id]
        gt_file = self.gt_dir / f"album{album_id}_test_answer.json"
        if not gt_file.exists():
            raise FileNotFoundError(f"Ground truth file not found: {gt_file}")
        with open(gt_file, "r", encoding="utf-8") as f:
            data = json.load(f)
        self._gt_cache[album_id] = data
        return data

    def validate_json_format(self, data: Any) -> list[str]:
        """Return a list of format errors; an empty list means the submission is well-formed."""
        errors = []
        if not isinstance(data, list):
            return ["Root must be a JSON array"]
        if len(data) == 0:
            return ["Submission is empty"]
        for i, item in enumerate(data):
            if not isinstance(item, dict):
                errors.append(f"Item #{i} must be an object")
                continue
            if "album_id" not in item or str(item["album_id"]) not in ["1", "2", "3"]:
                errors.append(f"Item #{i} 'album_id' must be '1', '2', or '3'")
            if "query_en" not in item or not isinstance(item["query_en"], str):
                errors.append(f"Item #{i} 'query_en' must be a string")
            if (
                "pred" not in item
                or not isinstance(item["pred"], list)
                or not all(isinstance(x, str) for x in item["pred"])
            ):
                errors.append(f"Item #{i} 'pred' must be a list of strings")
        return errors

    def _dcg_at_k(self, r, k):
        """Discounted cumulative gain over the first k relevance scores."""
        r = np.asarray(r, dtype=float)[:k]
        if r.size:
            return np.sum(r / np.log2(np.arange(2, r.size + 2)))
        return 0.0

    def _ndcg_at_k(self, r, k):
        """NDCG normalized by the ideal reordering of ``r`` itself.

        Note: ``_evaluate_album`` does not use this helper; it normalizes by an
        ideal vector with one entry per ground-truth answer instead, so queries
        that fail to retrieve every answer are penalized.
        """
        dcg_max = self._dcg_at_k(sorted(r, reverse=True), k)
        if not dcg_max:
            return 0.0
        return self._dcg_at_k(r, k) / dcg_max

    def _recall_at_k(self, ground_truth, predictions, k):
        """Fraction of ground-truth answers found in the top-k predictions."""
        k_preds = predictions[:k]
        hits = len(set(ground_truth) & set(k_preds))
        if len(ground_truth) == 0:
            return 0.0
        return hits / len(ground_truth)

    def _evaluate_album(self, album_submissions: dict, album_id: str) -> dict:
        """Evaluate a single album."""
        gt_data = self._load_gt(album_id)
        gt_map = {item["query_en"]: item for item in gt_data}

        metrics_accum = {f"Recall@{k}": [] for k in K_VALUES}
        metrics_accum.update({f"NDCG@{k}": [] for k in K_VALUES})
        metrics_accum["Recall"] = []
        metrics_accum["NDCG"] = []
        source_accum = {}

        empty_gt_queries = 0
        evaluated_queries = 0
        extraneous_queries = 0

        for q, pred in album_submissions.items():
            # Queries not present in the current ground truth are skipped.
            if q not in gt_map:
                extraneous_queries += 1
                continue
            gt_item = gt_map[q]
            gt_answers = gt_item.get("ground_truth", [])
            source = gt_item.get("Source")
            evaluated_queries += 1
            if not gt_answers:
                empty_gt_queries += 1
                continue

            # Binary relevance per predicted item, in submitted order.
            r = [1 if p in gt_answers else 0 for p in pred]
            # Ideal relevance: every ground-truth answer retrieved first.
            dcg_r = [1.0] * len(gt_answers)

            m = {}
            for k in K_VALUES:
                m[f"Recall@{k}"] = self._recall_at_k(gt_answers, pred, k)
                idcg = self._dcg_at_k(dcg_r, k)
                ndcg = self._dcg_at_k(r, k) / idcg if idcg > 0 else 0.0
                m[f"NDCG@{k}"] = ndcg
                metrics_accum[f"Recall@{k}"].append(m[f"Recall@{k}"])
                metrics_accum[f"NDCG@{k}"].append(m[f"NDCG@{k}"])

            # Cutoff-free variants over the full prediction list.
            m["Recall"] = sum(r) / len(gt_answers)
            idcg_all = self._dcg_at_k(dcg_r, len(gt_answers))
            ndcg_all = self._dcg_at_k(r, len(r)) / idcg_all if idcg_all > 0 else 0.0
            m["NDCG"] = ndcg_all
            metrics_accum["Recall"].append(m["Recall"])
            metrics_accum["NDCG"].append(m["NDCG"])

            # Bucket per-query metrics by the ground truth's "Source" field.
            if source is not None:
                if source not in source_accum:
                    source_accum[source] = {f"Recall@{_k}": [] for _k in K_VALUES}
                    source_accum[source].update(
                        {f"NDCG@{_k}": [] for _k in K_VALUES}
                    )
                    source_accum[source]["Recall"] = []
                    source_accum[source]["NDCG"] = []
                for k in K_VALUES:
                    source_accum[source][f"Recall@{k}"].append(m[f"Recall@{k}"])
                    source_accum[source][f"NDCG@{k}"].append(m[f"NDCG@{k}"])
                source_accum[source]["Recall"].append(m["Recall"])
                source_accum[source]["NDCG"].append(m["NDCG"])

        global_metrics = {
            k: float(np.mean(v)) if v else 0.0 for k, v in metrics_accum.items()
        }
        return {
            "global_metrics": global_metrics,
            "source_metrics": {
                src: {k: float(np.mean(v)) if v else 0.0 for k, v in m_dict.items()}
                for src, m_dict in source_accum.items()
            },
            "empty_gt_ratio": (
                empty_gt_queries / evaluated_queries if evaluated_queries > 0 else 0.0
            ),
            "evaluated_queries": evaluated_queries,
            "total_gt_queries": len(gt_data),
            "is_partial": evaluated_queries < len(gt_data),
            "extraneous_queries": extraneous_queries,
        }

    def evaluate(self, submission_data: list) -> dict:
        """Evaluate a full submission: per-album metrics plus cross-album averages."""
        # Group predictions by album, keyed by query text.
        albums = {}
        for item in submission_data:
            a_id = str(item["album_id"])
            if a_id not in albums:
                albums[a_id] = {}
            albums[a_id][item["query_en"]] = item["pred"]
        if not albums:
            raise ValueError("No valid albums found in submission.")

        # Evaluate each album separately
        per_album = {}
        for a_id in sorted(albums.keys()):
            per_album[a_id] = self._evaluate_album(albums[a_id], a_id)

        # Compute averaged metrics across all albums
        avg_metrics = {}
        first_album = next(iter(per_album.values()))
        for metric_key in first_album["global_metrics"]:
            values = [
                alb["global_metrics"][metric_key]
                for alb in per_album.values()
                if metric_key in alb["global_metrics"]
            ]
            avg_metrics[metric_key] = float(np.mean(values)) if values else 0.0

        total_evaluated = sum(alb["evaluated_queries"] for alb in per_album.values())
        total_gt = sum(alb["total_gt_queries"] for alb in per_album.values())
        total_extraneous = sum(
            alb.get("extraneous_queries", 0) for alb in per_album.values()
        )

        result = {
            "per_album": per_album,
            "global_metrics": avg_metrics,
            "evaluated_queries": total_evaluated,
            "total_gt_queries": total_gt,
            "is_partial": total_evaluated < total_gt,
            "albums": sorted(albums.keys()),
            "extraneous_queries": total_extraneous,
        }

        # Build warning / notice messages
        msgs = []
        if total_extraneous > 0:
            msgs.append(
                f"{total_extraneous} extraneous queries were ignored (not in current "
                f"GT). This may be caused by an outdated test.json or extra queries. "
                f"Valid queries: {total_evaluated}/{total_gt}."
            )
        if result["is_partial"]:
            missing_albums = [a for a in ["1", "2", "3"] if a not in albums]
            missing_queries = total_gt - total_evaluated
            parts = []
            if missing_albums:
                parts.append(f"Missing albums: {', '.join(missing_albums)}")
            if missing_queries > 0:
                parts.append(
                    f"Missing {missing_queries} queries "
                    f"({total_evaluated}/{total_gt} submitted)"
                )
            msgs.append(
                "Submission incomplete. "
                + "; ".join(parts)
                + ". Only full submissions are eligible for leaderboard ranking."
            )
        if msgs:
            result["warning"] = " ".join(msgs)
        return result