| """Evaluation runner with UI updates.""" | |
| import hashlib | |
| import json | |
| import logging | |
| import time | |
| import traceback | |
| from datetime import datetime | |
| from typing import Any, Dict, List, Optional | |
| import numpy as np | |
| import streamlit as st | |
| import torch | |
| from qdrant_client.models import FieldCondition, Filter, MatchValue | |
| from visual_rag import VisualEmbedder | |
| from visual_rag.retrieval import MultiVectorRetriever | |
| from benchmarks.vidore_tatdqa_test.dataset_loader import load_vidore_beir_dataset | |
| from benchmarks.vidore_tatdqa_test.metrics import ndcg_at_k, mrr_at_k, recall_at_k | |
| from demo.qdrant_utils import get_qdrant_credentials | |
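
# Map config strings to torch dtypes for query embedding. bfloat16 is only
# natively supported on newer accelerators (e.g. Ampere+ GPUs), so float16 is
# the conservative default used elsewhere in this module.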
TORCH_DTYPE_MAP = {
    "float16": torch.float16,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")


def _stable_uuid(text: str) -> str:
    """Generate a stable UUID from text (same as benchmark script)."""
    hex_str = hashlib.sha256(text.encode("utf-8")).hexdigest()[:32]
    return f"{hex_str[:8]}-{hex_str[8:12]}-{hex_str[12:16]}-{hex_str[16:20]}-{hex_str[20:32]}"


def _union_point_id(*, dataset_name: str, source_doc_id: str, union_namespace: Optional[str]) -> str:
    """Generate union point ID (same as benchmark script)."""
    ns = f"{union_namespace}::{dataset_name}" if union_namespace else dataset_name
    return _stable_uuid(f"{ns}::{source_doc_id}")
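
# With union_namespace set to the Qdrant collection name (as the callers below
# pass it), IDs are scoped "<collection>::<dataset>::<source_doc_id>", so two
# datasets that reuse the same source_doc_id cannot collide in one collection:
#   _union_point_id(dataset_name="vidore/tatdqa_test", source_doc_id="p17",
#                   union_namespace="my_union")
#   # -> _stable_uuid("my_union::vidore/tatdqa_test::p17")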


def _remap_qrels_to_union_ids(
    qrels: Dict[str, Dict[str, int]],
    corpus: List[Any],
    dataset_name: str,
    collection_name: str,
) -> Dict[str, Dict[str, int]]:
    """Remap qrels doc_ids from original format to union_doc_id format (matching benchmark)."""
    id_map: Dict[str, str] = {}
    for doc in corpus:
        source_doc_id = str((doc.payload or {}).get("source_doc_id") or doc.doc_id)
        id_map[str(doc.doc_id)] = _union_point_id(
            dataset_name=dataset_name,
            source_doc_id=source_doc_id,
            union_namespace=collection_name,
        )
    remapped: Dict[str, Dict[str, int]] = {}
    for qid, rels in qrels.items():
        out_rels: Dict[str, int] = {}
        for did, score in rels.items():
            mapped = id_map.get(str(did))
            if mapped:
                out_rels[mapped] = int(score)
        if out_rels:
            remapped[qid] = out_rels
    return remapped
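
# Sketch of the transformation (IDs shortened for readability):
#   {"q1": {"doc_7": 1}}  ->  {"q1": {"3f2a…-…": 1}}
# Queries whose relevant docs all fall outside the loaded corpus are dropped,
# which is why the caller logs "Remapped N -> M queries with valid rels".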


def get_doc_id_from_result(r: Dict[str, Any], use_original: bool = True) -> str:
    """Extract document ID from search result.

    Args:
        r: Search result dict with 'id' and 'payload'.
        use_original: If True, prefer original doc_id for matching with qrels.
            If False, prefer union_doc_id (Qdrant point ID).
    """
    payload = r.get("payload", {})
    if use_original:
        doc_id = (
            payload.get("doc_id")
            or payload.get("source_doc_id")
            or payload.get("corpus-id")
            or payload.get("union_doc_id")
            or str(r.get("id", ""))
        )
    else:
        doc_id = (
            payload.get("union_doc_id")
            or str(r.get("id", ""))
            or payload.get("doc_id")
        )
    return str(doc_id)
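
# Note: in union collections built like the benchmark's, the Qdrant point ID
# *is* the union_doc_id, so use_original=False typically reduces to str(r["id"]).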


def run_evaluation_with_ui(config: Dict[str, Any]):
    st.divider()
    print("=" * 60)
    print("[EVAL] Starting evaluation via UI")
    print("=" * 60)

    url, api_key = get_qdrant_credentials()
    if not url:
        st.error("QDRANT_URL not configured")
        return

    datasets = config.get("datasets", [])
    collection = config["collection"]
    model = config.get("model", "vidore/colpali-v1.3")
    mode = config.get("mode", "single_full")
    top_k = config.get("top_k", 100)
    prefetch_k = config.get("prefetch_k", 256)
    stage1_mode = config.get("stage1_mode", "tokens_vs_tiles")
    stage1_k = config.get("stage1_k", 1000)
    stage2_k = config.get("stage2_k", 300)
    prefer_grpc = config.get("prefer_grpc", True)
    torch_dtype = config.get("torch_dtype", "float16")
    evaluation_scope = config.get("evaluation_scope", "union")

    print("[EVAL] " + "─" * 51)
    print(f"[EVAL] Collection: {collection}")
    print(f"[EVAL] Model: {model}")
    print(f"[EVAL] Mode: {mode}, Scope: {evaluation_scope}")
    print(f"[EVAL] Datasets: {datasets}")
    print(f"[EVAL] Query embedding dtype: {torch_dtype} (vectors already indexed)")
    print("[EVAL] " + "─" * 51)
    phase1_container = st.container()
    phase2_container = st.container()
    phase3_container = st.container()
    results_container = st.container()

    try:
        with phase1_container:
            st.markdown("##### 🤖 Phase 1: Loading Model")
            model_status = st.empty()
            model_status.info(f"Loading `{model.split('/')[-1]}`...")
            print(f"[EVAL] Loading embedder: {model}")

            torch_dtype_obj = TORCH_DTYPE_MAP.get(torch_dtype, torch.float16)
            qdrant_dtype = config.get("qdrant_vector_dtype", "float16")
            output_dtype_obj = np.float16 if qdrant_dtype == "float16" else np.float32
            embedder = VisualEmbedder(
                model_name=model,
                torch_dtype=torch_dtype_obj,
                output_dtype=output_dtype_obj,
            )
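            # Load weights eagerly (via the private helper) so model download and
            # initialization cost lands in Phase 1 rather than on the first query.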
            embedder._load_model()
            print(f"[EVAL] Embedder loaded (torch_dtype={torch_dtype}, output_dtype={qdrant_dtype})")
            model_status.success(f"✅ Model `{model.split('/')[-1]}` loaded")

            retriever_status = st.empty()
            retriever_status.info(f"Connecting to collection `{collection}`...")
            print(f"[EVAL] Connecting to Qdrant collection: {collection}")
            retriever = MultiVectorRetriever(
                collection_name=collection,
                model_name=model,
                qdrant_url=url,
                qdrant_api_key=api_key,
                prefer_grpc=prefer_grpc,
                embedder=embedder,
            )
            print("[EVAL] Connected to Qdrant")
            retriever_status.success(f"✅ Connected to `{collection}`")

        with phase2_container:
            st.markdown("##### 📚 Phase 2: Loading Datasets")
            dataset_data = {}
            total_queries = 0
            max_queries_per_ds = config.get("max_queries")

            for ds_name in datasets:
                ds_status = st.empty()
                ds_short = ds_name.split("/")[-1]
                ds_status.info(f"Loading `{ds_short}`...")
                print(f"[EVAL] Loading dataset: {ds_name}")
                corpus, queries, qrels = load_vidore_beir_dataset(ds_name)

                print(f"[EVAL] Remapping qrels to union_doc_id format for collection={collection}")
                remapped_qrels = _remap_qrels_to_union_ids(qrels, corpus, ds_name, collection)
                print(f"[EVAL] Remapped {len(qrels)} -> {len(remapped_qrels)} queries with valid rels")

                if evaluation_scope == "per_dataset" and max_queries_per_ds:
                    queries = queries[:max_queries_per_ds]
                dataset_data[ds_name] = {
                    "queries": queries,
                    "qrels": remapped_qrels,
                    "num_docs": len(corpus),
                }
                total_queries += len(queries)
                print(f"[EVAL] Loaded {ds_name}: {len(corpus)} docs, {len(queries)} queries")
                ds_status.success(f"✅ `{ds_short}`: {len(corpus)} docs, {len(queries)} queries")

            if evaluation_scope == "union" and max_queries_per_ds and max_queries_per_ds < total_queries:
                total_queries = max_queries_per_ds
                print(f"[EVAL] Will limit to {total_queries} total queries (union mode)")

            embed_status = st.empty()
            embed_status.info("Embedding queries...")

        with phase3_container:
            st.markdown("##### 🎯 Phase 3: Running Evaluation")
            metrics_collectors = {
                "ndcg@5": [], "ndcg@10": [],
                "recall@5": [], "recall@10": [],
                "mrr@5": [], "mrr@10": [],
            }
            latencies = []
            log_lines = []
            metrics_by_dataset = {}

            if evaluation_scope == "per_dataset":
                overall_progress = st.progress(0.0)
                datasets_done = 0
                for ds_name, ds_info in dataset_data.items():
                    ds_short = ds_name.split("/")[-1]
                    st.markdown(f"**Evaluating `{ds_short}`**")
                    queries = ds_info["queries"]
                    qrels = ds_info["qrels"]
                    if not queries:
                        continue
                    print(f"[EVAL] Embedding {len(queries)} queries for {ds_short}...")
                    query_texts = [q.text for q in queries]
                    query_embeddings = embedder.embed_queries(query_texts, show_progress=False)
                    print(f"[EVAL] Queries embedded for {ds_short}")

                    ds_filter = Filter(
                        must=[FieldCondition(key="dataset", match=MatchValue(value=ds_name))]
                    )
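                    # Payload filter: restrict the search to points whose "dataset"
                    # payload field equals ds_name, so per-dataset metrics are not
                    # diluted by the other datasets sharing the union collection.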
| print(f"[EVAL] Using filter: dataset={ds_name}") | |
| progress_bar = st.progress(0.0) | |
| eval_status = st.empty() | |
| log_area = st.empty() | |
| ds_metrics = {"ndcg@5": [], "ndcg@10": [], "recall@5": [], "recall@10": [], "mrr@5": [], "mrr@10": []} | |
| ds_latencies = [] | |
| ds_log_lines = [] | |
| eval_status.info(f"Evaluating {len(queries)} queries...") | |
| print(f"[EVAL] Starting per-dataset evaluation: {ds_short}, {len(queries)} queries") | |
| for i, (q, qemb) in enumerate(zip(queries, query_embeddings)): | |
| start = time.time() | |
| if isinstance(qemb, torch.Tensor): | |
| qemb_np = qemb.detach().cpu().numpy() | |
| else: | |
| qemb_np = qemb.numpy() if hasattr(qemb, 'numpy') else np.array(qemb) | |
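                        # Always request at least 100 candidates, presumably so the
                        # ranking is deep enough regardless of the configured top_k.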
                        results = retriever.search_embedded(
                            query_embedding=qemb_np,
                            top_k=max(100, top_k),
                            mode=mode,
                            prefetch_k=prefetch_k,
                            stage1_mode=stage1_mode,
                            stage1_k=stage1_k,
                            stage2_k=stage2_k,
                            filter_obj=ds_filter,
                        )
                        ds_latencies.append((time.time() - start) * 1000)
                        latencies.append(ds_latencies[-1])
                        ranking = [str(r["id"]) for r in results]
                        rels = qrels.get(q.query_id, {})

                        if i == 0:
                            print(f"[EVAL] First query for {ds_short} - query_id: {q.query_id}")
                            print(f"[EVAL] Top 3 retrieved doc_ids: {ranking[:3]}")
                            print(f"[EVAL] Expected doc_ids (qrels): {list(rels.keys())[:3]}")
                            print(f"[EVAL] qrels has {len(qrels)} queries, this query in qrels: {q.query_id in qrels}")
                            if results:
                                r0 = results[0]
                                print(f"[EVAL] Sample result payload keys: {list(r0.get('payload', {}).keys())}")
                                p = r0.get("payload", {})
                                print(f"[EVAL] Sample payload doc_id={p.get('doc_id')}, union_doc_id={p.get('union_doc_id')}, source_doc_id={p.get('source_doc_id')}")
                            has_match = any(rid in rels for rid in ranking[:10])
                            print(f"[EVAL] Any match in top-10? {has_match}")

                        for k_name, k_val in [("ndcg@5", 5), ("ndcg@10", 10)]:
                            ds_metrics[k_name].append(ndcg_at_k(ranking, rels, k=k_val))
                        for k_name, k_val in [("recall@5", 5), ("recall@10", 10)]:
                            ds_metrics[k_name].append(recall_at_k(ranking, rels, k=k_val))
                        for k_name, k_val in [("mrr@5", 5), ("mrr@10", 10)]:
                            ds_metrics[k_name].append(mrr_at_k(ranking, rels, k=k_val))

                        progress = (i + 1) / len(queries)
                        progress_bar.progress(progress)
                        eval_status.info(f"🎯 {i+1}/{len(queries)} ({int(progress*100)}%) – latency: {np.mean(ds_latencies):.0f}ms")

                        log_interval = max(5, len(queries) // 10)
                        if (i + 1) % log_interval == 0 and i > 0:
                            cur_ndcg = np.mean(ds_metrics["ndcg@10"])
                            cur_lat = np.mean(ds_latencies[1:]) if len(ds_latencies) > 1 else ds_latencies[0]
                            ds_log_lines.append(f"[{i+1}/{len(queries)}] NDCG@10={cur_ndcg:.4f}, lat={cur_lat:.0f}ms")
                            log_area.code("\n".join(ds_log_lines[-5:]), language="text")
                            print(f"[EVAL] {ds_short} {i+1}/{len(queries)}: NDCG@10={cur_ndcg:.4f}, lat={cur_lat:.0f}ms")

                    progress_bar.progress(1.0)
                    ds_final = {k: float(np.mean(v)) for k, v in ds_metrics.items()}
                    ds_final["avg_latency_ms"] = float(np.mean(ds_latencies))
                    ds_final["num_queries"] = len(queries)
                    metrics_by_dataset[ds_name] = ds_final
                    for k, v in ds_metrics.items():
                        metrics_collectors[k].extend(v)
                    eval_status.success(f"✅ `{ds_short}`: NDCG@10={ds_final['ndcg@10']:.4f}, latency={ds_final['avg_latency_ms']:.0f}ms")
                    print(f"[EVAL] {ds_short} DONE: NDCG@10={ds_final['ndcg@10']:.4f}")
                    datasets_done += 1
                    overall_progress.progress(datasets_done / len(datasets))

                overall_progress.progress(1.0)
                embed_status.success("✅ All queries embedded")
                total_queries = sum(d["num_queries"] for d in metrics_by_dataset.values())
            else:
                all_queries = []
                all_qrels = {}
                for ds_name, ds_info in dataset_data.items():
                    all_queries.extend(ds_info["queries"])
                    for qid, rels in ds_info["qrels"].items():
                        all_qrels[qid] = rels

                sample_qrel_keys = list(all_qrels.keys())[:3]
                sample_doc_ids = []
                for qid in sample_qrel_keys:
                    sample_doc_ids.extend(list(all_qrels[qid].keys())[:2])
                print(f"[EVAL] all_qrels built: {len(all_qrels)} queries")
                print(f"[EVAL] Sample qrel query_ids: {sample_qrel_keys}")
                print(f"[EVAL] Sample qrel doc_ids: {sample_doc_ids[:5]}")

                max_q = config.get("max_queries")
                if max_q and max_q < len(all_queries):
                    all_queries = all_queries[:max_q]
                total_queries = len(all_queries)

                print(f"[EVAL] Embedding {total_queries} queries (union mode)...")
                query_texts = [q.text for q in all_queries]
                query_embeddings = embedder.embed_queries(query_texts, show_progress=False)
                print("[EVAL] Queries embedded")
                embed_status.success(f"✅ {total_queries} queries embedded")

                progress_bar = st.progress(0.0)
                eval_status = st.empty()
                log_area = st.empty()
                eval_status.info(f"Evaluating {total_queries} queries in `{mode}` mode...")
                print(f"[EVAL] Starting union evaluation: {total_queries} queries, mode={mode}")

                for i, (q, qemb) in enumerate(zip(all_queries, query_embeddings)):
                    start = time.time()
                    if isinstance(qemb, torch.Tensor):
                        qemb_np = qemb.detach().cpu().numpy()
                    else:
                        qemb_np = qemb.numpy() if hasattr(qemb, "numpy") else np.array(qemb)
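                    # Union scope: no dataset filter here, so every query competes
                    # against the combined corpus of all indexed datasets.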
                    results = retriever.search_embedded(
                        query_embedding=qemb_np,
                        top_k=max(100, top_k),
                        mode=mode,
                        prefetch_k=prefetch_k,
                        stage1_mode=stage1_mode,
                        stage1_k=stage1_k,
                        stage2_k=stage2_k,
                    )
                    latencies.append((time.time() - start) * 1000)
                    ranking = [str(r["id"]) for r in results]
                    rels = all_qrels.get(q.query_id, {})

                    if i == 0:
                        print(f"[EVAL] First query - query_id: {q.query_id}")
                        print(f"[EVAL] Top 3 retrieved doc_ids: {ranking[:3]}")
                        print(f"[EVAL] Expected doc_ids (qrels): {list(rels.keys())[:3]}")
                        print(f"[EVAL] all_qrels has {len(all_qrels)} queries, this query in qrels: {q.query_id in all_qrels}")
                        if results:
                            r0 = results[0]
                            print(f"[EVAL] Sample result payload keys: {list(r0.get('payload', {}).keys())}")
                            p = r0.get("payload", {})
                            print(f"[EVAL] Sample payload doc_id={p.get('doc_id')}, union_doc_id={p.get('union_doc_id')}, source_doc_id={p.get('source_doc_id')}")
                        has_match = any(rid in rels for rid in ranking[:10])
                        print(f"[EVAL] Any match in top-10? {has_match}")

                    metrics_collectors["ndcg@5"].append(ndcg_at_k(ranking, rels, k=5))
                    metrics_collectors["ndcg@10"].append(ndcg_at_k(ranking, rels, k=10))
                    metrics_collectors["recall@5"].append(recall_at_k(ranking, rels, k=5))
                    metrics_collectors["recall@10"].append(recall_at_k(ranking, rels, k=10))
                    metrics_collectors["mrr@5"].append(mrr_at_k(ranking, rels, k=5))
                    metrics_collectors["mrr@10"].append(mrr_at_k(ranking, rels, k=10))

                    progress = (i + 1) / total_queries
                    progress_bar.progress(progress)
                    eval_status.info(f"🎯 {i+1}/{total_queries} ({int(progress*100)}%) – latency: {np.mean(latencies):.0f}ms")

                    log_interval = max(10, total_queries // 10)
                    if (i + 1) % log_interval == 0 and i > 0:
                        cur_ndcg = np.mean(metrics_collectors["ndcg@10"])
                        cur_lat = np.mean(latencies[1:]) if len(latencies) > 1 else latencies[0]
                        log_lines.append(f"[{i+1}/{total_queries}] NDCG@10={cur_ndcg:.4f}, lat={cur_lat:.0f}ms")
                        log_area.code("\n".join(log_lines[-10:]), language="text")
                        print(f"[EVAL] Progress {i+1}/{total_queries}: NDCG@10={cur_ndcg:.4f}, lat={cur_lat:.0f}ms")

                progress_bar.progress(1.0)
                eval_status.success(f"✅ Evaluation complete! ({total_queries} queries)")

        with results_container:
            st.markdown("##### 📊 Results")
            p95_latency = float(np.percentile(latencies, 95))
            eval_time_s = sum(latencies) / 1000
            qps = total_queries / eval_time_s if eval_time_s > 0 else 0
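            # Note: queries run sequentially, so QPS here is single-stream
            # throughput (num_queries / total search time), not concurrent load.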

            final_metrics = {
                "ndcg@5": float(np.mean(metrics_collectors["ndcg@5"])),
                "ndcg@10": float(np.mean(metrics_collectors["ndcg@10"])),
                "recall@5": float(np.mean(metrics_collectors["recall@5"])),
                "recall@10": float(np.mean(metrics_collectors["recall@10"])),
                "mrr@5": float(np.mean(metrics_collectors["mrr@5"])),
                "mrr@10": float(np.mean(metrics_collectors["mrr@10"])),
                "avg_latency_ms": float(np.mean(latencies)),
                "p95_latency_ms": p95_latency,
                "qps": qps,
                "eval_time_s": eval_time_s,
                "num_queries": total_queries,
            }

            print("=" * 60)
            print("[EVAL] FINAL RESULTS:")
            print(f"[EVAL] NDCG@5: {final_metrics['ndcg@5']:.4f}")
            print(f"[EVAL] NDCG@10: {final_metrics['ndcg@10']:.4f}")
            print(f"[EVAL] Recall@5: {final_metrics['recall@5']:.4f}")
            print(f"[EVAL] Recall@10: {final_metrics['recall@10']:.4f}")
            print(f"[EVAL] MRR@5: {final_metrics['mrr@5']:.4f}")
            print(f"[EVAL] MRR@10: {final_metrics['mrr@10']:.4f}")
            print(f"[EVAL] Avg Latency: {final_metrics['avg_latency_ms']:.1f}ms")
            print(f"[EVAL] P95 Latency: {final_metrics['p95_latency_ms']:.1f}ms")
            print(f"[EVAL] QPS: {final_metrics['qps']:.2f}")
            print(f"[EVAL] Queries: {final_metrics['num_queries']}")
            print("=" * 60)

            st.markdown("**Retrieval Metrics**")
            c1, c2, c3 = st.columns(3)
            with c1:
                st.metric("NDCG@5", f"{final_metrics['ndcg@5']:.4f}")
                st.metric("NDCG@10", f"{final_metrics['ndcg@10']:.4f}")
            with c2:
                st.metric("Recall@5", f"{final_metrics['recall@5']:.4f}")
                st.metric("Recall@10", f"{final_metrics['recall@10']:.4f}")
            with c3:
                st.metric("MRR@5", f"{final_metrics['mrr@5']:.4f}")
                st.metric("MRR@10", f"{final_metrics['mrr@10']:.4f}")

            st.markdown("**Performance**")
            c4, c5, c6, c7 = st.columns(4)
            c4.metric("Avg Latency", f"{final_metrics['avg_latency_ms']:.0f}ms")
            c5.metric("P95 Latency", f"{final_metrics['p95_latency_ms']:.0f}ms")
            c6.metric("QPS", f"{final_metrics['qps']:.2f}")
            c7.metric("Eval Time", f"{final_metrics['eval_time_s']:.1f}s")

            with st.expander("📄 Full Results JSON"):
                st.json(final_metrics)

            detailed_report = {
                "generated_at": datetime.now().isoformat(),
                "config": {
                    "collection": collection,
                    "model": model,
                    "datasets": datasets,
                    "mode": mode,
                    "top_k": top_k,
                    "evaluation_scope": evaluation_scope,
                    "prefer_grpc": prefer_grpc,
                    "torch_dtype": torch_dtype,
                    "max_queries": config.get("max_queries"),
                },
                "retrieval_metrics": {
                    "ndcg@5": final_metrics["ndcg@5"],
                    "ndcg@10": final_metrics["ndcg@10"],
                    "recall@5": final_metrics["recall@5"],
                    "recall@10": final_metrics["recall@10"],
                    "mrr@5": final_metrics["mrr@5"],
                    "mrr@10": final_metrics["mrr@10"],
                },
                "performance": {
                    "avg_latency_ms": final_metrics["avg_latency_ms"],
                    "p95_latency_ms": final_metrics["p95_latency_ms"],
                    "qps": final_metrics["qps"],
                    "eval_time_s": final_metrics["eval_time_s"],
                    "num_queries": final_metrics["num_queries"],
                },
            }
| if mode == "two_stage": | |
| detailed_report["config"]["stage1_mode"] = stage1_mode | |
| detailed_report["config"]["prefetch_k"] = prefetch_k | |
| elif mode == "three_stage": | |
| detailed_report["config"]["stage1_k"] = stage1_k | |
| detailed_report["config"]["stage2_k"] = stage2_k | |
| if evaluation_scope == "per_dataset" and metrics_by_dataset: | |
| detailed_report["metrics_by_dataset"] = metrics_by_dataset | |
| st.markdown("**Per-Dataset Results**") | |
| for ds_name, ds_metrics in metrics_by_dataset.items(): | |
| ds_short = ds_name.split("/")[-1] | |
| with st.expander(f"π {ds_short}"): | |
| dc1, dc2, dc3, dc4 = st.columns(4) | |
| dc1.metric("NDCG@10", f"{ds_metrics['ndcg@10']:.4f}") | |
| dc2.metric("Recall@10", f"{ds_metrics['recall@10']:.4f}") | |
| dc3.metric("MRR@10", f"{ds_metrics['mrr@10']:.4f}") | |
| dc4.metric("Latency", f"{ds_metrics['avg_latency_ms']:.0f}ms") | |
| report_json = json.dumps(detailed_report, indent=2) | |
| timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") | |
| filename = f"eval_report__{collection}__{mode}__{timestamp}.json" | |
| st.download_button( | |
| label="π₯ Download Detailed Report", | |
| data=report_json, | |
| file_name=filename, | |
| mime="application/json", | |
| use_container_width=True, | |
| ) | |
| st.session_state["last_eval_metrics"] = final_metrics | |

    except Exception as e:
        error_msg = str(e)
        if "not configured in this collection" in error_msg:
            vector_name = error_msg.split("name ")[-1].split(" is")[0] if "name " in error_msg else "unknown"
            st.error(f"❌ **Collection Mismatch**: Vector `{vector_name}` not found in collection `{collection}`")
            st.warning(f"""
**The selected mode `{mode}` requires vectors that don't exist in this collection.**

**Suggestions:**
- Try `single_full` mode (works with basic collections)
- Use a collection indexed with the Visual RAG Toolkit
- Check that the collection has the required vector types for `{mode}` mode
""")
        else:
            st.error(f"❌ Error: {e}")
        with st.expander("🔍 Full Error Details"):
            st.code(traceback.format_exc(), language="text")
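

# Minimal usage sketch (hypothetical values; in the demo this function is
# invoked by the Streamlit app with a config assembled from sidebar widgets):
#
#     run_evaluation_with_ui({
#         "collection": "my_union_collection",   # assumed collection name
#         "datasets": ["vidore/tatdqa_test"],    # dataset name is illustrative
#         "model": "vidore/colpali-v1.3",
#         "mode": "single_full",
#         "evaluation_scope": "union",
#         "max_queries": 50,
#     })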