# ============================================================================
# BENCHMARK: geolip-captionbert-8192 vs Individual BERTs
#
# Loads model from: AbstractPhil/geolip-captionbert-8192
#
# Tests:
# 1. STS-B  - Spearman correlation with human similarity judgments
# 2. SICK-R - compositional/syntactic similarity
# 3. MRPC   - paraphrase detection (cosine threshold)
# 4. Caption retrieval - embedding-spread statistics on a CC12M subset
#    (no paired ground truth, so no recall@k is computed)
#
# Compares against the 5 consensus teachers (mean-pooled raw encoders)
# ============================================================================
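#
# Dependencies (all imported below): torch, transformers, datasets,
# huggingface_hub, scipy, scikit-learn, numpy, tqdm.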
import os
import json
import gc
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from scipy.stats import spearmanr, pearsonr
from sklearn.metrics import accuracy_score, f1_score
from tqdm import tqdm
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("=" * 65)
print("BENCHMARK: geolip-captionbert-8192")
print("=" * 65)
print(f" Device: {DEVICE}")
# ══════════════════════════════════════════════════════════════════
# MODEL: CaptionEncoder (must match HF repo)
# ══════════════════════════════════════════════════════════════════
class CaptionEncoder(nn.Module):
def __init__(self, vocab_size=30522, max_len=8192, d_model=384,
n_heads=6, n_layers=6, d_ff=1536, output_dim=768,
dropout=0.1, pad_token_id=0):
super().__init__()
self.pad_token_id = pad_token_id
self.d_model = d_model
self.max_len = max_len
self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_token_id)
self.pos_emb = nn.Embedding(max_len, d_model)
self.emb_norm = nn.LayerNorm(d_model)
self.emb_drop = nn.Dropout(dropout)
encoder_layer = nn.TransformerEncoderLayer(
d_model=d_model, nhead=n_heads, dim_feedforward=d_ff,
dropout=dropout, activation="gelu", batch_first=True,
norm_first=True)
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
self.output_proj = nn.Sequential(
nn.Linear(d_model, d_model), nn.GELU(),
nn.LayerNorm(d_model), nn.Linear(d_model, output_dim))
    def forward(self, input_ids, attention_mask=None):
        B, L = input_ids.shape
        positions = torch.arange(L, device=input_ids.device).unsqueeze(0)
        x = self.token_emb(input_ids) + self.pos_emb(positions)
        x = self.emb_drop(self.emb_norm(x))
        # Key-padding mask: True marks positions the encoder must ignore.
        if attention_mask is not None:
            kpm = ~attention_mask.bool()
        else:
            kpm = (input_ids == self.pad_token_id)
        x = self.encoder(x, src_key_padding_mask=kpm)
        # Mean-pool over non-pad tokens, then project and L2-normalize so
        # downstream cosine similarity reduces to a dot product.
        if attention_mask is not None:
            mask = attention_mask.unsqueeze(-1).float()
        else:
            mask = (~kpm).unsqueeze(-1).float()
        pooled = (x * mask).sum(1) / mask.sum(1).clamp(min=1)
        return F.normalize(self.output_proj(pooled), dim=-1)
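# Illustrative smoke test (not part of the benchmark run): with random
# weights the encoder should emit unit-norm (B, 768) embeddings. The
# CAPTIONBERT_SMOKE_TEST env-var gate is a name made up for this sketch.
if os.environ.get("CAPTIONBERT_SMOKE_TEST"):
    _enc = CaptionEncoder().eval()
    _ids = torch.randint(1, 30522, (2, 16))
    _out = _enc(_ids, torch.ones_like(_ids))
    assert _out.shape == (2, 768)
    assert torch.allclose(_out.norm(dim=-1), torch.ones(2), atol=1e-4)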
# ══════════════════════════════════════════════════════════════════
# LOAD BENCHMARKS
# ══════════════════════════════════════════════════════════════════
def load_stsb():
from datasets import load_dataset
ds = load_dataset("mteb/stsbenchmark-sts", split="test")
pairs = [{"sent1": r["sentence1"], "sent2": r["sentence2"], "score": r["score"]} for r in ds]
print(f" STS-B test: {len(pairs)} pairs")
return pairs
def load_sick():
from datasets import load_dataset
ds = load_dataset("mteb/sickr-sts", split="test")
pairs = [{"sent1": r["sentence1"], "sent2": r["sentence2"], "score": r["score"]} for r in ds]
print(f" SICK-R test: {len(pairs)} pairs")
return pairs
def load_mrpc():
from datasets import load_dataset
ds = load_dataset("glue", "mrpc", split="test")
pairs = [{"sent1": r["sentence1"], "sent2": r["sentence2"], "label": r["label"]} for r in ds]
print(f" MRPC test: {len(pairs)} pairs")
return pairs
def load_caption_retrieval(n=5000):
from datasets import load_dataset
print(f" Loading CC12M captions for retrieval (n={n})...")
ds = load_dataset("CaptionEmporium/conceptual-captions-cc12m-llavanext",
split="train", streaming=True)
captions = []
for row in ds:
cap = row.get("caption_llava", "")
if isinstance(cap, str) and len(cap) > 50:
captions.append(cap)
if len(captions) >= n:
break
# Use last 1000 as query, rest as corpus
queries = captions[-1000:]
corpus = captions[:-1000]
print(f" Corpus: {len(corpus)}, Queries: {len(queries)}")
return corpus, queries
# ══════════════════════════════════════════════════════════════════
# ENCODING
# ══════════════════════════════════════════════════════════════════
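# Both paths return L2-normalized, mask-aware mean-pooled embeddings, so
# every cosine similarity downstream reduces to a plain dot product.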
@torch.no_grad()
def encode_hf(model, tokenizer, texts, batch_size=128, max_len=512):
all_emb = []
for i in range(0, len(texts), batch_size):
batch = texts[i:i+batch_size]
inputs = tokenizer(batch, max_length=max_len, padding=True,
truncation=True, return_tensors="pt").to(DEVICE)
out = model(**inputs)
mask = inputs.attention_mask.unsqueeze(-1).float()
pooled = (out.last_hidden_state * mask).sum(1) / mask.sum(1).clamp(min=1)
all_emb.append(F.normalize(pooled, dim=-1).cpu())
return torch.cat(all_emb)
@torch.no_grad()
def encode_student(model, tokenizer, texts, batch_size=128, max_len=512):
all_emb = []
for i in range(0, len(texts), batch_size):
batch = texts[i:i+batch_size]
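        # padding="max_length" pads every batch to a fixed max_len; dynamic
        # padding (padding=True) would give identical embeddings since the
        # model masks pad tokens, at lower cost for short batches.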
inputs = tokenizer(batch, max_length=max_len, padding="max_length",
truncation=True, return_tensors="pt").to(DEVICE)
emb = model(inputs["input_ids"], inputs["attention_mask"])
all_emb.append(emb.cpu())
return torch.cat(all_emb)
# ══════════════════════════════════════════════════════════════════
# EVALUATION METRICS
# ══════════════════════════════════════════════════════════════════
def eval_sts(pairs, emb1, emb2):
    cosines = F.cosine_similarity(emb1, emb2, dim=-1).numpy()
    gold = np.array([p["score"] for p in pairs])
    # .statistic requires scipy >= 1.9; index with [0] on older versions.
    return {
        "spearman": float(spearmanr(cosines, gold).statistic),
        "pearson": float(pearsonr(cosines, gold).statistic),
        "cos_mean": float(cosines.mean()),
    }
def eval_mrpc(pairs, emb1, emb2):
cosines = F.cosine_similarity(emb1, emb2, dim=-1).numpy()
labels = np.array([p["label"] for p in pairs])
    # Oracle threshold sweep on the test split; applied identically to every
    # model, so comparisons stay fair, but absolute scores are an upper bound.
best_f1, best_thresh = 0, 0.5
for thresh in np.arange(0.5, 1.0, 0.01):
preds = (cosines > thresh).astype(int)
f1 = f1_score(labels, preds, zero_division=0)
if f1 > best_f1:
best_f1 = f1
best_thresh = thresh
preds = (cosines > best_thresh).astype(int)
return {
"f1": float(best_f1),
"accuracy": float(accuracy_score(labels, preds)),
"threshold": float(best_thresh),
}
def eval_retrieval(query_emb, corpus_emb, k_vals=(1, 5, 10)):
    # No paired ground truth is available, so instead of recall@k this reports
    # embedding-space spread: how close each query sits to its nearest corpus
    # neighbors, and how similar the queries are to one another.
    sim = query_emb @ corpus_emb.T
    results = {}
    for k in k_vals:
        k_eff = min(k, corpus_emb.shape[0])
        results[f"mean_top{k}_cos"] = sim.topk(k_eff, dim=1).values.mean().item()
    # Pairwise query similarity; zero the diagonal to exclude self-matches.
    self_sim = query_emb @ query_emb.T
    self_sim.fill_diagonal_(0)
    results["self_cos_mean"] = self_sim.mean().item()
    results["self_cos_max"] = self_sim.max().item()
    return results
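# For reference: if paired ground truth existed (query i aligned with corpus
# row i), the usual metric would be recall@k. A minimal sketch, unused by this
# benchmark; recall_at_k is a helper name introduced purely for illustration.
def recall_at_k(query_emb, corpus_emb, k=5):
    sim = query_emb @ corpus_emb.T                    # (Q, C) cosine matrix
    topk = sim.topk(min(k, corpus_emb.shape[0]), dim=1).indices
    gold = torch.arange(query_emb.shape[0]).unsqueeze(1)
    return (topk == gold).any(dim=1).float().mean().item()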
# ══════════════════════════════════════════════════════════════════
# MAIN
# ══════════════════════════════════════════════════════════════════
def run():
from transformers import AutoModel, AutoTokenizer
from huggingface_hub import hf_hub_download
# ── Load benchmarks ──
print(f"\n{'='*65}")
print("LOADING BENCHMARKS")
print(f"{'='*65}")
stsb = load_stsb()
sick = load_sick()
mrpc = load_mrpc()
ret_corpus, ret_queries = load_caption_retrieval(5000)
stsb_s1 = [p["sent1"] for p in stsb]
stsb_s2 = [p["sent2"] for p in stsb]
sick_s1 = [p["sent1"] for p in sick]
sick_s2 = [p["sent2"] for p in sick]
mrpc_s1 = [p["sent1"] for p in mrpc]
mrpc_s2 = [p["sent2"] for p in mrpc]
results = {}
# ── Load student from HuggingFace ──
print(f"\n{'='*65}")
print("LOADING: geolip-captionbert-8192")
print(f"{'='*65}")
repo_id = "AbstractPhil/geolip-captionbert-8192"
ckpt_path = hf_hub_download(repo_id=repo_id, filename="best_model.pt")
print(f" Downloaded: {ckpt_path}")
student_tok = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
student = CaptionEncoder(
vocab_size=student_tok.vocab_size,
max_len=8192, d_model=384, n_heads=6, n_layers=6,
d_ff=1536, output_dim=768, dropout=0.0,
pad_token_id=student_tok.pad_token_id).to(DEVICE)
student.load_state_dict(
torch.load(ckpt_path, weights_only=True, map_location=DEVICE))
student.eval()
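    # eval() already disables dropout; dropout=0.0 above just makes the
    # inference configuration explicit.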
n_params = sum(p.numel() for p in student.parameters())
print(f" Parameters: {n_params:,}")
# Encode
print(" Encoding STS-B...")
s_stsb1 = encode_student(student, student_tok, stsb_s1)
s_stsb2 = encode_student(student, student_tok, stsb_s2)
print(" Encoding SICK-R...")
s_sick1 = encode_student(student, student_tok, sick_s1)
s_sick2 = encode_student(student, student_tok, sick_s2)
print(" Encoding MRPC...")
s_mrpc1 = encode_student(student, student_tok, mrpc_s1)
s_mrpc2 = encode_student(student, student_tok, mrpc_s2)
print(" Encoding captions...")
s_corpus = encode_student(student, student_tok, ret_corpus)
s_queries = encode_student(student, student_tok, ret_queries)
r_stsb = eval_sts(stsb, s_stsb1, s_stsb2)
r_sick = eval_sts(sick, s_sick1, s_sick2)
r_mrpc = eval_mrpc(mrpc, s_mrpc1, s_mrpc2)
r_ret = eval_retrieval(s_queries, s_corpus)
results["captionbert"] = {
"stsb": r_stsb, "sick": r_sick, "mrpc": r_mrpc,
"retrieval": r_ret, "params": n_params,
}
print(f" STS-B: spearman={r_stsb['spearman']:.4f} pearson={r_stsb['pearson']:.4f}")
print(f" SICK-R: spearman={r_sick['spearman']:.4f} pearson={r_sick['pearson']:.4f}")
print(f" MRPC: f1={r_mrpc['f1']:.4f} acc={r_mrpc['accuracy']:.4f} thresh={r_mrpc['threshold']:.2f}")
print(f" Caption self-cos: mean={r_ret['self_cos_mean']:.4f} max={r_ret['self_cos_max']:.4f}")
del student
gc.collect()
torch.cuda.empty_cache()
# ── Evaluate teachers ──
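    # These are raw masked-LM checkpoints scored with mean pooling and no
    # sentence-level fine-tuning, so their absolute STS numbers will sit well
    # below dedicated sentence-embedding models.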
teachers = [
("google-bert/bert-base-uncased", "bert-base"),
("answerdotai/ModernBERT-base", "modern-bert"),
("FacebookAI/roberta-base", "roberta"),
("albert/albert-base-v2", "albert"),
("distilbert/distilbert-base-uncased", "distilbert"),
]
for model_name, short in teachers:
print(f"\n{'='*65}")
print(f"EVALUATING: {short}")
print(f"{'='*65}")
model = AutoModel.from_pretrained(model_name).to(DEVICE).eval()
tokenizer = AutoTokenizer.from_pretrained(model_name)
n_p = sum(p.numel() for p in model.parameters())
print(f" Parameters: {n_p:,}")
print(" Encoding STS-B...")
e1 = encode_hf(model, tokenizer, stsb_s1)
e2 = encode_hf(model, tokenizer, stsb_s2)
r_stsb = eval_sts(stsb, e1, e2)
print(" Encoding SICK-R...")
e1 = encode_hf(model, tokenizer, sick_s1)
e2 = encode_hf(model, tokenizer, sick_s2)
r_sick = eval_sts(sick, e1, e2)
print(" Encoding MRPC...")
e1 = encode_hf(model, tokenizer, mrpc_s1)
e2 = encode_hf(model, tokenizer, mrpc_s2)
r_mrpc = eval_mrpc(mrpc, e1, e2)
print(" Encoding captions...")
eq = encode_hf(model, tokenizer, ret_queries)
ec = encode_hf(model, tokenizer, ret_corpus)
r_ret = eval_retrieval(eq, ec)
results[short] = {
"stsb": r_stsb, "sick": r_sick, "mrpc": r_mrpc,
"retrieval": r_ret, "params": n_p,
}
print(f" STS-B: spearman={r_stsb['spearman']:.4f}")
print(f" SICK-R: spearman={r_sick['spearman']:.4f}")
print(f" MRPC: f1={r_mrpc['f1']:.4f}")
del model, tokenizer
gc.collect()
torch.cuda.empty_cache()
# ══════════════════════════════════════════════════════════════
# SUMMARY
# ══════════════════════════════════════════════════════════════
print(f"\n{'='*65}")
print("FULL BENCHMARK SUMMARY")
print(f"{'='*65}")
print(f"\n {'Model':<20} {'Params':>10} {'STS-B ρ':>9} {'SICK-R ρ':>9} {'MRPC F1':>9}")
print(f" {'-'*57}")
sorted_r = sorted(results.items(),
key=lambda x: x[1]["stsb"]["spearman"], reverse=True)
for name, r in sorted_r:
        marker = " ★" if name == "captionbert" else ""
print(f" {name:<20} {r['params']:>10,} "
f"{r['stsb']['spearman']:>9.4f} "
f"{r['sick']['spearman']:>9.4f} "
f"{r['mrpc']['f1']:>9.4f}{marker}")
# Detailed captionbert results
cb = results["captionbert"]
print(f"\n geolip-captionbert-8192 detailed:")
print(f" STS-B: spearman={cb['stsb']['spearman']:.4f} pearson={cb['stsb']['pearson']:.4f} mean_cos={cb['stsb']['cos_mean']:.4f}")
print(f" SICK-R: spearman={cb['sick']['spearman']:.4f} pearson={cb['sick']['pearson']:.4f} mean_cos={cb['sick']['cos_mean']:.4f}")
print(f" MRPC: f1={cb['mrpc']['f1']:.4f} acc={cb['mrpc']['accuracy']:.4f} threshold={cb['mrpc']['threshold']:.2f}")
print(f" Caption retrieval:")
for k, v in cb["retrieval"].items():
print(f" {k}: {v:.4f}")
# Rankings
print(f"\n Rankings:")
for bench in ["stsb", "sick"]:
ranked = sorted(results.items(),
key=lambda x: x[1][bench]["spearman"], reverse=True)
pos = next(i for i, (n, _) in enumerate(ranked) if n == "captionbert") + 1
print(f" {bench.upper()}: #{pos}/{len(ranked)}")
ranked_mrpc = sorted(results.items(),
key=lambda x: x[1]["mrpc"]["f1"], reverse=True)
pos = next(i for i, (n, _) in enumerate(ranked_mrpc) if n == "captionbert") + 1
print(f" MRPC: #{pos}/{len(ranked_mrpc)}")
# vs best teacher
best_name = max((n for n in results if n != "captionbert"),
key=lambda n: results[n]["stsb"]["spearman"])
best_stsb = results[best_name]["stsb"]["spearman"]
student_stsb = results["captionbert"]["stsb"]["spearman"]
print(f"\n vs Best teacher ({best_name}):")
print(f" STS-B gap: {student_stsb - best_stsb:+.4f}")
print(f" Param ratio: {results[best_name]['params'] / results['captionbert']['params']:.1f}Γ—")
# Save
save_path = "benchmark_captionbert_8192.json"
with open(save_path, "w") as f:
json.dump(results, f, indent=2, default=str)
print(f"\n Saved to {save_path}")
print(f"\n{'='*65}")
print("DONE")
print(f"{'='*65}")
if __name__ == "__main__":
run()