"""Load and aggregate embedding-benchmark results from HuggingFace Hub.

Fetches per-model ``results.json`` files from repos tagged
``ArmBench-TextEmbed``, caches them locally, and exposes them as pandas
DataFrames for a leaderboard UI.
"""

import json
import os
import struct
from typing import Dict, List

import pandas as pd
import requests
from huggingface_hub import HfApi, hf_hub_download

# Required metrics for embedding evaluation
REQUIRED_METRICS = [
    "mteb_avg",
    "sts_spearman",
    "retrieval_top20",
    "msmarco_top10",
]


def format_params(num_params):
    """Format parameter count as human-readable string (e.g. ``1.3B``, ``350M``)."""
    if num_params >= 1e9:
        return f"{num_params / 1e9:.1f}B"
    else:
        return f"{num_params / 1e6:.0f}M"


def get_model_url(model_name):
    """Get the model URL from HuggingFace."""
    return f"https://huggingface.co/{model_name}"


def get_model_size(model_name):
    """Fetch model size from HuggingFace API.

    Returns a human-readable size string (via :func:`format_params`) or
    ``None`` when the size cannot be determined.
    """
    try:
        url = f"https://huggingface.co/api/models/{model_name}"
        response = requests.get(url, timeout=10)
        if response.status_code == 200:
            data = response.json()
            # Get safetensors size first, fallback to general parameters
            safetensors = data.get("safetensors")
            if safetensors and "total" in safetensors:
                num_params = safetensors["total"]
                return format_params(num_params)
            num_params = data.get("num_parameters")
            if num_params:
                return format_params(num_params)
        # Fallback: read actual param count from safetensors header
        num_params = get_params_from_safetensors(model_name)
        if num_params:
            return format_params(num_params)
        return None
    except Exception as e:
        print(f"Error fetching size for {model_name}: {e}")
        return None


def get_params_from_safetensors(model_name):
    """Read safetensors header to get actual parameter count.

    Uses HTTP Range requests so only the JSON header of each
    ``*.safetensors`` file is downloaded, never the weights themselves.
    Returns the total parameter count or ``None`` on any failure.
    """
    try:
        tree_url = f"https://huggingface.co/api/models/{model_name}/tree/main"
        resp = requests.get(tree_url, timeout=10)
        if resp.status_code != 200:
            return None
        files = resp.json()
        safetensor_files = [f for f in files if f.get("path", "").endswith(".safetensors")]
        if not safetensor_files:
            return None
        total_params = 0
        for sf in safetensor_files:
            file_url = f"https://huggingface.co/{model_name}/resolve/main/{sf['path']}"
            # Get header size (first 8 bytes: little-endian uint64 per the
            # safetensors format spec)
            headers = {"Range": "bytes=0-7"}
            resp = requests.get(file_url, headers=headers, timeout=10, allow_redirects=True)
            if resp.status_code != 206 or len(resp.content) < 8:
                return None  # Likely gated model
            header_size = struct.unpack("<Q", resp.content)[0]
            # NOTE(review): the remainder of this function was corrupted in the
            # pasted source; reconstructed from the safetensors format spec
            # (JSON header of {tensor_name: {"dtype", "shape", "data_offsets"}})
            # — confirm against the original file.
            headers = {"Range": f"bytes=8-{7 + header_size}"}
            resp = requests.get(file_url, headers=headers, timeout=10, allow_redirects=True)
            if resp.status_code not in (200, 206):
                return None
            header = json.loads(resp.content)
            for tensor_name, info in header.items():
                if tensor_name == "__metadata__":
                    continue  # metadata entry carries no tensor shape
                count = 1
                for dim in info.get("shape", []):
                    count *= dim
                total_params += count
        return total_params if total_params else None
    except Exception:
        return None


class LeaderboardData:
    """Cache of per-model benchmark results backed by a local JSON file.

    NOTE(review): the original class header, docstring and ``__init__`` were
    lost in the corrupted paste; the name and constructor below are
    reconstructed from how the methods use ``self.api``,
    ``self.model_infos_path`` and ``self.model_infos`` — confirm against the
    original file.
    """

    def __init__(self, model_infos_path: str = "model_infos.json"):
        self.api = HfApi()
        self.model_infos_path = model_infos_path
        self.model_infos = self._load_model_infos()

    def _load_model_infos(self) -> List:
        """Load cached model infos from disk; empty list when no cache exists."""
        if os.path.exists(self.model_infos_path):
            with open(self.model_infos_path) as f:
                return json.load(f)
        return []

    def _save_model_infos(self):
        """Persist the in-memory model infos to the local cache file."""
        print("Saving model infos")
        with open(self.model_infos_path, "w") as f:
            json.dump(self.model_infos, f, indent=4)

    def get_embedding_benchmark_data(self) -> pd.DataFrame:
        """Fetch embedding benchmark results from HuggingFace models with ArmBench-TextEmbed tag."""
        # Try to fetch new models from HuggingFace, but gracefully handle network errors
        try:
            models = self.api.list_models(filter="ArmBench-TextEmbed")
            # Names already cached locally — skip re-downloading their results
            model_names = {model["model_name"] for model in self.model_infos}
            repositories = [model.modelId for model in models]
            for repo_id in repositories:
                try:
                    files = [f for f in self.api.list_repo_files(repo_id) if f == "results.json"]
                    if not files:
                        continue
                    model_name = repo_id
                    if model_name not in model_names:
                        result_path = hf_hub_download(repo_id, filename="results.json")
                        with open(result_path) as f:
                            results = json.load(f)
                        # Build model entry with metadata
                        entry = {"model_name": model_name, "results": results}
                        # Add model_url if not in results
                        if "model_url" not in results:
                            entry["model_url"] = get_model_url(model_name)
                        # Add model_size if not in results
                        if "model_size" not in results:
                            model_size = get_model_size(model_name)
                            if model_size:
                                entry["model_size"] = model_size
                        self.model_infos.append(entry)
                except Exception as e:
                    print(f"Error loading {repo_id} - {e}")
                    continue
            self._save_model_infos()
        except Exception as e:
            print(f"Failed to fetch from HuggingFace: {e}. Using local data.")
        # Build dataframe from results
        data = []
        for model in self.model_infos:
            model_name = model["model_name"]
            results = model.get("results", {})
            row = {"model_name": model_name}
            # Extract model metadata
            if "model_url" in model:
                row["model_url"] = model["model_url"]
            if "model_size" in model:
                row["model_size"] = model["model_size"]
            # Extract key metrics (insertion order matters for column order)
            for metric in (
                "mteb_avg",
                "sts_spearman",
                "retrieval_top20",
                "retrieval_translit_top20",
                "msmarco_top10",
                "msmarco_translit_top10",
            ):
                if metric in results:
                    row[metric] = results[metric]
            # Only add if at least one metric is present
            if len(row) > 1:
                data.append(row)
        return pd.DataFrame(data)

    def get_detailed_results(self) -> Dict:
        """Get all detailed results for MTEB, MS MARCO, STS, Retrieval, and translit benchmarks.

        Returns a dict of DataFrames keyed by benchmark name; a benchmark with
        no data maps to an empty DataFrame.
        """
        # Output key -> key of the per-benchmark detail dict inside results
        sections = {
            "mteb": "mteb_detailed",
            "msmarco": "msmarco_detailed",
            "sts": "sts_detailed",
            "retrieval": "retrieval_detailed",
            "retrieval_translit": "retrieval_translit_detailed",
            "msmarco_translit": "msmarco_translit_detailed",
        }
        collected: Dict[str, list] = {name: [] for name in sections}
        for model in self.model_infos:
            model_name = model["model_name"]
            results = model.get("results", {})
            for name, detail_key in sections.items():
                if detail_key in results:
                    collected[name].append({"model_name": model_name, **results[detail_key]})
        return {
            name: pd.DataFrame(rows) if rows else pd.DataFrame()
            for name, rows in collected.items()
        }