# ============================================================================ # BENCHMARK: geolip-captionbert-8192 vs Individual BERTs # # Loads model from: AbstractPhil/geolip-captionbert-8192 # # Tests: # 1. STS-B — Spearman correlation with human similarity judgments # 2. SICK-R — Compositional/syntactic similarity # 3. MRPC — Paraphrase detection (cosine threshold) # 4. Caption retrieval — self-retrieval on CC12M subset # # Compares against all 5 consensus teachers + sentence-transformers baseline # ============================================================================ import os import json import gc import torch import torch.nn as nn import torch.nn.functional as F import numpy as np from scipy.stats import spearmanr, pearsonr from sklearn.metrics import accuracy_score, f1_score from tqdm import tqdm DEVICE = "cuda" if torch.cuda.is_available() else "cpu" print("=" * 65) print("BENCHMARK: geolip-captionbert-8192") print("=" * 65) print(f" Device: {DEVICE}") # ══════════════════════════════════════════════════════════════════ # MODEL: CaptionEncoder (must match HF repo) # ══════════════════════════════════════════════════════════════════ class CaptionEncoder(nn.Module): def __init__(self, vocab_size=30522, max_len=8192, d_model=384, n_heads=6, n_layers=6, d_ff=1536, output_dim=768, dropout=0.1, pad_token_id=0): super().__init__() self.pad_token_id = pad_token_id self.d_model = d_model self.max_len = max_len self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_token_id) self.pos_emb = nn.Embedding(max_len, d_model) self.emb_norm = nn.LayerNorm(d_model) self.emb_drop = nn.Dropout(dropout) encoder_layer = nn.TransformerEncoderLayer( d_model=d_model, nhead=n_heads, dim_feedforward=d_ff, dropout=dropout, activation="gelu", batch_first=True, norm_first=True) self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers) self.output_proj = nn.Sequential( nn.Linear(d_model, d_model), nn.GELU(), nn.LayerNorm(d_model), nn.Linear(d_model, output_dim)) def forward(self, input_ids, attention_mask=None): B, L = input_ids.shape positions = torch.arange(L, device=input_ids.device).unsqueeze(0) x = self.token_emb(input_ids) + self.pos_emb(positions) x = self.emb_drop(self.emb_norm(x)) if attention_mask is not None: kpm = ~attention_mask.bool() else: kpm = (input_ids == self.pad_token_id) x = self.encoder(x, src_key_padding_mask=kpm) if attention_mask is not None: mask = attention_mask.unsqueeze(-1).float() else: mask = (~kpm).unsqueeze(-1).float() pooled = (x * mask).sum(1) / mask.sum(1).clamp(min=1) return F.normalize(self.output_proj(pooled), dim=-1) # ══════════════════════════════════════════════════════════════════ # LOAD BENCHMARKS # ══════════════════════════════════════════════════════════════════ def load_stsb(): from datasets import load_dataset ds = load_dataset("mteb/stsbenchmark-sts", split="test") pairs = [{"sent1": r["sentence1"], "sent2": r["sentence2"], "score": r["score"]} for r in ds] print(f" STS-B test: {len(pairs)} pairs") return pairs def load_sick(): from datasets import load_dataset ds = load_dataset("mteb/sickr-sts", split="test") pairs = [{"sent1": r["sentence1"], "sent2": r["sentence2"], "score": r["score"]} for r in ds] print(f" SICK-R test: {len(pairs)} pairs") return pairs def load_mrpc(): from datasets import load_dataset ds = load_dataset("glue", "mrpc", split="test") pairs = [{"sent1": r["sentence1"], "sent2": r["sentence2"], "label": r["label"]} for r in ds] print(f" MRPC test: {len(pairs)} pairs") return pairs def load_caption_retrieval(n=5000): from datasets import load_dataset print(f" Loading CC12M captions for retrieval (n={n})...") ds = load_dataset("CaptionEmporium/conceptual-captions-cc12m-llavanext", split="train", streaming=True) captions = [] for row in ds: cap = row.get("caption_llava", "") if isinstance(cap, str) and len(cap) > 50: captions.append(cap) if len(captions) >= n: break # Use last 1000 as query, rest as corpus queries = captions[-1000:] corpus = captions[:-1000] print(f" Corpus: {len(corpus)}, Queries: {len(queries)}") return corpus, queries # ══════════════════════════════════════════════════════════════════ # ENCODING # ══════════════════════════════════════════════════════════════════ @torch.no_grad() def encode_hf(model, tokenizer, texts, batch_size=128, max_len=512): all_emb = [] for i in range(0, len(texts), batch_size): batch = texts[i:i+batch_size] inputs = tokenizer(batch, max_length=max_len, padding=True, truncation=True, return_tensors="pt").to(DEVICE) out = model(**inputs) mask = inputs.attention_mask.unsqueeze(-1).float() pooled = (out.last_hidden_state * mask).sum(1) / mask.sum(1).clamp(min=1) all_emb.append(F.normalize(pooled, dim=-1).cpu()) return torch.cat(all_emb) @torch.no_grad() def encode_student(model, tokenizer, texts, batch_size=128, max_len=512): all_emb = [] for i in range(0, len(texts), batch_size): batch = texts[i:i+batch_size] inputs = tokenizer(batch, max_length=max_len, padding="max_length", truncation=True, return_tensors="pt").to(DEVICE) emb = model(inputs["input_ids"], inputs["attention_mask"]) all_emb.append(emb.cpu()) return torch.cat(all_emb) # ══════════════════════════════════════════════════════════════════ # EVALUATION METRICS # ══════════════════════════════════════════════════════════════════ def eval_sts(pairs, emb1, emb2): cosines = F.cosine_similarity(emb1, emb2, dim=-1).numpy() gold = np.array([p["score"] for p in pairs]) return { "spearman": float(spearmanr(cosines, gold).statistic), "pearson": float(pearsonr(cosines, gold).statistic), "cos_mean": float(cosines.mean()), } def eval_mrpc(pairs, emb1, emb2): cosines = F.cosine_similarity(emb1, emb2, dim=-1).numpy() labels = np.array([p["label"] for p in pairs]) # Find optimal threshold best_f1, best_thresh = 0, 0.5 for thresh in np.arange(0.5, 1.0, 0.01): preds = (cosines > thresh).astype(int) f1 = f1_score(labels, preds, zero_division=0) if f1 > best_f1: best_f1 = f1 best_thresh = thresh preds = (cosines > best_thresh).astype(int) return { "f1": float(best_f1), "accuracy": float(accuracy_score(labels, preds)), "threshold": float(best_thresh), } def eval_retrieval(query_emb, corpus_emb, k_vals=(1, 5, 10)): # Query embeddings should retrieve themselves from corpus+query pool sim = query_emb @ corpus_emb.T results = {} N = query_emb.shape[0] for k in k_vals: topk = sim.topk(min(k, corpus_emb.shape[0]), dim=1).indices # No ground truth matching — measure diversity/spread results[f"mean_top{k}_cos"] = sim.topk(k, dim=1).values.mean().item() # Self-similarity self_sim = query_emb @ query_emb.T self_sim.fill_diagonal_(0) results["self_cos_mean"] = self_sim.mean().item() results["self_cos_max"] = self_sim.max().item() return results # ══════════════════════════════════════════════════════════════════ # MAIN # ══════════════════════════════════════════════════════════════════ def run(): from transformers import AutoModel, AutoTokenizer from huggingface_hub import hf_hub_download # ── Load benchmarks ── print(f"\n{'='*65}") print("LOADING BENCHMARKS") print(f"{'='*65}") stsb = load_stsb() sick = load_sick() mrpc = load_mrpc() ret_corpus, ret_queries = load_caption_retrieval(5000) stsb_s1 = [p["sent1"] for p in stsb] stsb_s2 = [p["sent2"] for p in stsb] sick_s1 = [p["sent1"] for p in sick] sick_s2 = [p["sent2"] for p in sick] mrpc_s1 = [p["sent1"] for p in mrpc] mrpc_s2 = [p["sent2"] for p in mrpc] results = {} # ── Load student from HuggingFace ── print(f"\n{'='*65}") print("LOADING: geolip-captionbert-8192") print(f"{'='*65}") repo_id = "AbstractPhil/geolip-captionbert-8192" ckpt_path = hf_hub_download(repo_id=repo_id, filename="best_model.pt") print(f" Downloaded: {ckpt_path}") student_tok = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased") student = CaptionEncoder( vocab_size=student_tok.vocab_size, max_len=8192, d_model=384, n_heads=6, n_layers=6, d_ff=1536, output_dim=768, dropout=0.0, pad_token_id=student_tok.pad_token_id).to(DEVICE) student.load_state_dict( torch.load(ckpt_path, weights_only=True, map_location=DEVICE)) student.eval() n_params = sum(p.numel() for p in student.parameters()) print(f" Parameters: {n_params:,}") # Encode print(" Encoding STS-B...") s_stsb1 = encode_student(student, student_tok, stsb_s1) s_stsb2 = encode_student(student, student_tok, stsb_s2) print(" Encoding SICK-R...") s_sick1 = encode_student(student, student_tok, sick_s1) s_sick2 = encode_student(student, student_tok, sick_s2) print(" Encoding MRPC...") s_mrpc1 = encode_student(student, student_tok, mrpc_s1) s_mrpc2 = encode_student(student, student_tok, mrpc_s2) print(" Encoding captions...") s_corpus = encode_student(student, student_tok, ret_corpus) s_queries = encode_student(student, student_tok, ret_queries) r_stsb = eval_sts(stsb, s_stsb1, s_stsb2) r_sick = eval_sts(sick, s_sick1, s_sick2) r_mrpc = eval_mrpc(mrpc, s_mrpc1, s_mrpc2) r_ret = eval_retrieval(s_queries, s_corpus) results["captionbert"] = { "stsb": r_stsb, "sick": r_sick, "mrpc": r_mrpc, "retrieval": r_ret, "params": n_params, } print(f" STS-B: spearman={r_stsb['spearman']:.4f} pearson={r_stsb['pearson']:.4f}") print(f" SICK-R: spearman={r_sick['spearman']:.4f} pearson={r_sick['pearson']:.4f}") print(f" MRPC: f1={r_mrpc['f1']:.4f} acc={r_mrpc['accuracy']:.4f} thresh={r_mrpc['threshold']:.2f}") print(f" Caption self-cos: mean={r_ret['self_cos_mean']:.4f} max={r_ret['self_cos_max']:.4f}") del student gc.collect() torch.cuda.empty_cache() # ── Evaluate teachers ── teachers = [ ("google-bert/bert-base-uncased", "bert-base"), ("answerdotai/ModernBERT-base", "modern-bert"), ("FacebookAI/roberta-base", "roberta"), ("albert/albert-base-v2", "albert"), ("distilbert/distilbert-base-uncased", "distilbert"), ] for model_name, short in teachers: print(f"\n{'='*65}") print(f"EVALUATING: {short}") print(f"{'='*65}") model = AutoModel.from_pretrained(model_name).to(DEVICE).eval() tokenizer = AutoTokenizer.from_pretrained(model_name) n_p = sum(p.numel() for p in model.parameters()) print(f" Parameters: {n_p:,}") print(" Encoding STS-B...") e1 = encode_hf(model, tokenizer, stsb_s1) e2 = encode_hf(model, tokenizer, stsb_s2) r_stsb = eval_sts(stsb, e1, e2) print(" Encoding SICK-R...") e1 = encode_hf(model, tokenizer, sick_s1) e2 = encode_hf(model, tokenizer, sick_s2) r_sick = eval_sts(sick, e1, e2) print(" Encoding MRPC...") e1 = encode_hf(model, tokenizer, mrpc_s1) e2 = encode_hf(model, tokenizer, mrpc_s2) r_mrpc = eval_mrpc(mrpc, e1, e2) print(" Encoding captions...") eq = encode_hf(model, tokenizer, ret_queries) ec = encode_hf(model, tokenizer, ret_corpus) r_ret = eval_retrieval(eq, ec) results[short] = { "stsb": r_stsb, "sick": r_sick, "mrpc": r_mrpc, "retrieval": r_ret, "params": n_p, } print(f" STS-B: spearman={r_stsb['spearman']:.4f}") print(f" SICK-R: spearman={r_sick['spearman']:.4f}") print(f" MRPC: f1={r_mrpc['f1']:.4f}") del model, tokenizer gc.collect() torch.cuda.empty_cache() # ══════════════════════════════════════════════════════════════ # SUMMARY # ══════════════════════════════════════════════════════════════ print(f"\n{'='*65}") print("FULL BENCHMARK SUMMARY") print(f"{'='*65}") print(f"\n {'Model':<20} {'Params':>10} {'STS-B ρ':>9} {'SICK-R ρ':>9} {'MRPC F1':>9}") print(f" {'-'*57}") sorted_r = sorted(results.items(), key=lambda x: x[1]["stsb"]["spearman"], reverse=True) for name, r in sorted_r: marker = " ★" if name == "captionbert" else "" print(f" {name:<20} {r['params']:>10,} " f"{r['stsb']['spearman']:>9.4f} " f"{r['sick']['spearman']:>9.4f} " f"{r['mrpc']['f1']:>9.4f}{marker}") # Detailed captionbert results cb = results["captionbert"] print(f"\n geolip-captionbert-8192 detailed:") print(f" STS-B: spearman={cb['stsb']['spearman']:.4f} pearson={cb['stsb']['pearson']:.4f} mean_cos={cb['stsb']['cos_mean']:.4f}") print(f" SICK-R: spearman={cb['sick']['spearman']:.4f} pearson={cb['sick']['pearson']:.4f} mean_cos={cb['sick']['cos_mean']:.4f}") print(f" MRPC: f1={cb['mrpc']['f1']:.4f} acc={cb['mrpc']['accuracy']:.4f} threshold={cb['mrpc']['threshold']:.2f}") print(f" Caption retrieval:") for k, v in cb["retrieval"].items(): print(f" {k}: {v:.4f}") # Rankings print(f"\n Rankings:") for bench in ["stsb", "sick"]: ranked = sorted(results.items(), key=lambda x: x[1][bench]["spearman"], reverse=True) pos = next(i for i, (n, _) in enumerate(ranked) if n == "captionbert") + 1 print(f" {bench.upper()}: #{pos}/{len(ranked)}") ranked_mrpc = sorted(results.items(), key=lambda x: x[1]["mrpc"]["f1"], reverse=True) pos = next(i for i, (n, _) in enumerate(ranked_mrpc) if n == "captionbert") + 1 print(f" MRPC: #{pos}/{len(ranked_mrpc)}") # vs best teacher best_name = max((n for n in results if n != "captionbert"), key=lambda n: results[n]["stsb"]["spearman"]) best_stsb = results[best_name]["stsb"]["spearman"] student_stsb = results["captionbert"]["stsb"]["spearman"] print(f"\n vs Best teacher ({best_name}):") print(f" STS-B gap: {student_stsb - best_stsb:+.4f}") print(f" Param ratio: {results[best_name]['params'] / results['captionbert']['params']:.1f}×") # Save save_path = "benchmark_captionbert_8192.json" with open(save_path, "w") as f: json.dump(results, f, indent=2, default=str) print(f"\n Saved to {save_path}") print(f"\n{'='*65}") print("DONE") print(f"{'='*65}") if __name__ == "__main__": run()