| | import gradio as gr |
| | import json |
| | import os |
| | import numpy as np |
| | from cryptography.fernet import Fernet |
| | from collections import defaultdict |
| | from sklearn.metrics import ndcg_score |
| |
|
def load_and_decrypt_qrel(secret_key):
    """Load the encrypted relevance judgments and index them for lookup.

    Reads ``data/answer.enc``, decrypts it with a Fernet cipher built from
    ``secret_key`` and parses the plaintext as JSON. The JSON is iterated as a
    mapping from dataset name to a list of records, each providing the keys
    ``query_id``, ``corpus_id`` and ``score``.

    Args:
        secret_key: Fernet key as text; it is encoded to bytes before use.

    Returns:
        A nested ``defaultdict`` where
        ``qrel_dict[dataset][query_id][corpus_id]`` is that record's score.

    Raises:
        ValueError: If reading, decrypting, parsing or indexing the file fails
            for any reason. The triggering exception is attached as
            ``__cause__`` so the root cause stays visible in tracebacks.
    """
    try:
        with open("data/answer.enc", "rb") as enc_file:
            encrypted_data = enc_file.read()
        cipher = Fernet(secret_key.encode())
        decrypted_data = cipher.decrypt(encrypted_data).decode("utf-8")
        raw_data = json.loads(decrypted_data)

        qrel_dict = defaultdict(lambda: defaultdict(dict))
        for dataset, records in raw_data.items():
            for item in records:
                qid, cid, score = item["query_id"], item["corpus_id"], item["score"]
                qrel_dict[dataset][qid][cid] = score
        return qrel_dict
    except Exception as e:
        # Callers surface str(e) to the user, so the message text is kept
        # as-is; chaining only adds the original error for debugging.
        raise ValueError(f"β Failed to decrypt answer file: {str(e)}") from e
| |
|
def recall_at_k(corpus_top_100_list, relevant_ids, k=1):
    """Return 1 if any of the first ``k`` ranked ids is relevant, otherwise 0.

    Args:
        corpus_top_100_list: Ranked ids, best first.
        relevant_ids: Container of ids considered relevant.
        k: Number of leading ranked ids to inspect.
    """
    for candidate in corpus_top_100_list[:k]:
        if candidate in relevant_ids:
            return 1
    return 0
| |
|
def ndcg_at_k(corpus_top_100_list, rel_dict, k):
    """Compute NDCG@k of a ranked list against graded relevance judgments.

    Gains are linear in the relevance grade and discounted by
    ``log2(rank + 1)``. Only documents that actually appear in
    ``corpus_top_100_list`` contribute to the DCG: a relevant document that
    was not retrieved earns nothing, even when the submitted list is shorter
    than ``k``. The ideal DCG is built from the ``k`` largest grades in
    ``rel_dict``.

    Args:
        corpus_top_100_list: Ranked document ids, best first. Repeated ids are
            collapsed onto their first (best) position.
        rel_dict: Mapping from document id to relevance grade. Grades are
            expected to be non-negative; ids missing from the mapping count
            as grade 0.
        k: Rank cutoff. ``None`` means no cutoff.

    Returns:
        The NDCG value as a float. Returns 0.0 when ``k`` is not positive,
        when nothing was retrieved, or when the ideal DCG is not positive
        (e.g. no relevant documents are known for the query).
    """
    if k is not None and k <= 0:
        return 0.0

    def dcg(gains):
        # Rank r (1-based) is discounted by log2(r + 1).
        gains = np.asarray(gains, dtype=float)
        discounts = np.log2(np.arange(2, gains.size + 2))
        return float(np.sum(gains / discounts))

    # Best achievable ordering: the highest grades first, cut at k.
    ideal_dcg = dcg(sorted(rel_dict.values(), reverse=True)[:k])
    if ideal_dcg <= 0:
        return 0.0

    # dict.fromkeys de-duplicates while preserving the submitted order.
    ranked = list(dict.fromkeys(corpus_top_100_list))[:k]
    return dcg([rel_dict.get(doc, 0) for doc in ranked]) / ideal_dcg
| |
|
def _mean_percent(values):
    """Return the mean of ``values`` as a percentage rounded to 2 decimals (0.0 if empty)."""
    if not values:
        # np.mean([]) is NaN, which would serialise as invalid JSON.
        return 0.0
    return round(float(np.mean(values)) * 100, 2)


def evaluate(pred_data, qrel_dict):
    """Score submitted rankings against the relevance judgments.

    Datasets in ``pred_data`` that are absent from ``qrel_dict`` are skipped.
    Metrics are averaged over the queries present in the submission only; a
    submitted query with no judgments is scored against an empty judgment set.

    Args:
        pred_data: Mapping from dataset name to a list of records. Each record
            provides ``query_id`` and ``corpus_top_100_list``; the latter is a
            comma-separated string of ranked ids (a sequence of ids is also
            accepted). Blank entries are ignored.
        qrel_dict: Mapping ``dataset -> query_id -> corpus_id -> score``.

    Returns:
        Mapping from dataset name to a dict with ``Recall@1``, ``NDCG@10`` and
        ``NDCG@100``, each a percentage rounded to two decimals. A dataset
        submitted with no queries gets 0.0 for every metric.
    """
    results = {}
    for dataset, queries in pred_data.items():
        if dataset not in qrel_dict:
            continue

        recall_1, ndcg_10, ndcg_100 = [], [], []
        for item in queries:
            qid = item["query_id"]
            raw_ranking = item["corpus_top_100_list"]
            if isinstance(raw_ranking, str):
                raw_ranking = raw_ranking.split(",")
            stripped = (str(entry).strip() for entry in raw_ranking)
            corpus_top_100_list = [entry for entry in stripped if entry]

            rel_dict = qrel_dict[dataset].get(qid, {})
            # A set keeps the membership test in recall_at_k constant-time.
            relevant_ids = {cid for cid, score in rel_dict.items() if score > 0}

            recall_1.append(recall_at_k(corpus_top_100_list, relevant_ids, 1))
            ndcg_10.append(ndcg_at_k(corpus_top_100_list, rel_dict, 10))
            ndcg_100.append(ndcg_at_k(corpus_top_100_list, rel_dict, 100))

        results[dataset] = {
            "Recall@1": _mean_percent(recall_1),
            "NDCG@10": _mean_percent(ndcg_10),
            "NDCG@100": _mean_percent(ndcg_100),
        }

    return results
| |
|
def process_json(file):
    """Evaluate an uploaded prediction file and return the result as text.

    Args:
        file: The uploaded file, either as a filesystem path or as an object
            exposing the path through a ``name`` attribute.
            NOTE(review): which of the two Gradio passes depends on the
            installed version / ``gr.File`` ``type`` setting — confirm.

    Returns:
        A JSON string with the metrics on success, otherwise a human-readable
        error message. This function reports every failure through its return
        value rather than by raising.
    """
    # pathlib paths also have ``.name`` (the basename), so only fall back to
    # the attribute for objects that are not themselves path-like.
    if isinstance(file, (str, bytes, os.PathLike)):
        path = file
    else:
        path = getattr(file, "name", file)

    try:
        # ``with`` guarantees the handle is closed even when parsing fails.
        with open(path, encoding="utf-8") as handle:
            pred_data = json.load(handle)
    except Exception as e:
        return f"β Invalid JSON format: {str(e)}"

    secret_key = os.getenv("SECRET_KEY")
    if not secret_key:
        return "β SECRET_KEY environment variable not set. Please configure it in your Hugging Face Space."

    try:
        qrel_dict = load_and_decrypt_qrel(secret_key)
    except Exception as e:
        return str(e)

    try:
        metrics = evaluate(pred_data, qrel_dict)
        return json.dumps(metrics, indent=2)
    except Exception as e:
        return f"β Error during evaluation: {str(e)}"
| |
|
def main_gradio():
    """Build the upload-and-evaluate Gradio interface and launch it.

    The interface takes one uploaded file, passes it to ``process_json`` and
    shows the returned text. It is launched with ``share=True``.
    """
    # Sample submission rendered inside the description as preformatted HTML.
    example_json_html = "".join(
        [
            '<pre><code>{<br>',
            ' "Google_WIT": [<br>',
            ' {"query_id": "1", "corpus_top_100_list": "5, 2, 8, ..."},<br>',
            ' {"query_id": "2", "corpus_top_100_list": "90, 13, 3, ..."}<br>',
            ' ],<br>',
            ' "MSCOCO": [<br>',
            ' {"query_id": "1", "corpus_top_100_list": "122, 35, 22, ..."}<br>',
            ' ],<br>',
            ' "OVEN": [<br>',
            ' {"query_id": "1", "corpus_top_100_list": "11, 15, 22, ..."}<br>',
            ' ],<br>',
            ' "VisualNews": [<br>',
            ' {"query_id": "1", "corpus_top_100_list": "101, 35, 77, ..."}<br>',
            ' ]<br>',
            '}</code></pre>',
        ]
    )

    intro_html = (
        "Please upload your model's retrieval result on MixBench (in JSON format) to automatically evaluate its performance.<br><br>"
        "For each subset (e.g., <code>MSCOCO</code>, <code>Google_WIT</code>, <code>VisualNews</code>, <code>OVEN</code>), "
        "we compute:<br>"
    )
    metrics_html = (
        "- <strong>Recall@1</strong><br>"
        "- <strong>NDCG@10</strong><br>"
        "- <strong>NDCG@100</strong><br><br>"
    )
    format_heading_html = "Expected input JSON format:<br><br>"
    footer_html = (
        "<br>To find valid query IDs, see the "
        "<a href='https://huggingface.co/datasets/mixed-modality-search/MixBench2025/viewer/Google_WIT/mixed_corpus' target='_blank'>MixBench2025 dataset viewer</a>."
    )
    description_html = (
        intro_html + metrics_html + format_heading_html + example_json_html + footer_html
    )

    upload_input = gr.File(label="Upload Retrieval Result (JSON)")
    result_output = gr.Textbox(label="Evaluation Results")
    demo = gr.Interface(
        fn=process_json,
        inputs=upload_input,
        outputs=result_output,
        title="π Automated Evaluation of MixBench",
        description=description_html,
    )
    demo.launch(share=True)
| |
|
# Start the Gradio app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main_gradio()
| |
|