# AbstractPhil's picture
# Rename early_bench.py to benchmarks/early_bench.py
# c39f077 verified
# ============================================================================
# BENCHMARK: Distilled Consensus Student vs Individual BERTs
#
# Tests:
# 1. STS-B (Semantic Textual Similarity Benchmark) — Spearman correlation
# 2. SICK-R (Sentences Involving Compositional Knowledge) — Spearman
# 3. Retrieval precision on held-out consensus targets
#
# Compares:
# - Distilled student (19-23M params, no pretrained weights)
# - BERT-base-uncased (110M params)
# - ModernBERT-base (149M params)
# - RoBERTa-base (125M params)
# - ALBERT-base-v2 (12M params)
# - DistilBERT-base (66M params)
#
# All models evaluated on mean-pooled embeddings → cosine similarity
# ============================================================================
import os
import json
import torch
import torch.nn.functional as F
import numpy as np
from scipy.stats import spearmanr, pearsonr
from tqdm import tqdm
# Default to the CPU and switch to CUDA only when torch reports a usable GPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

# Script banner: a 65-character rule above and below the title line.
print(
    "=" * 65,
    "BENCHMARK: Consensus Student vs Individual BERTs",
    "=" * 65,
    sep="\n",
)
# ══════════════════════════════════════════════════════════════════
# LOAD BENCHMARKS
# ══════════════════════════════════════════════════════════════════
def load_stsb():
    """Fetch the STS-B test split and return it as a list of pair dicts.

    Rows are read from the ``test`` split of ``mteb/stsbenchmark-sts``. Each
    returned dict has the keys ``sent1``, ``sent2`` and ``score``, copied from
    the dataset columns ``sentence1``, ``sentence2`` and ``score``. A one-line
    summary (pair count and observed score range) is printed before returning.
    """
    from datasets import load_dataset

    split = load_dataset("mteb/stsbenchmark-sts", split="test")
    pairs = [
        {
            "sent1": row["sentence1"],
            "sent2": row["sentence2"],
            "score": row["score"],
        }
        for row in split
    ]

    # Collect the scores once so the range can be reported from a single list.
    scores = [pair["score"] for pair in pairs]
    lowest, highest = min(scores), max(scores)
    print(f" STS-B test: {len(pairs)} pairs, scores {lowest:.1f}-{highest:.1f}")
    return pairs
def load_sick():
    """Fetch the SICK-R test split and return it as a list of pair dicts.

    Rows are read from the ``test`` split of ``mteb/sickr-sts``. Every entry
    of the returned list maps ``sent1``/``sent2``/``score`` to the dataset's
    ``sentence1``/``sentence2``/``score`` values. The number of pairs and the
    minimum/maximum score are printed as a side effect.
    """
    from datasets import load_dataset

    records = load_dataset("mteb/sickr-sts", split="test")

    pairs = []
    for record in records:
        entry = dict(
            sent1=record["sentence1"],
            sent2=record["sentence2"],
            score=record["score"],
        )
        pairs.append(entry)

    n_pairs = len(pairs)
    score_lo = min(entry["score"] for entry in pairs)
    score_hi = max(entry["score"] for entry in pairs)
    print(f" SICK-R test: {n_pairs} pairs, scores {score_lo:.1f}-{score_hi:.1f}")
    return pairs
# ══════════════════════════════════════════════════════════════════
# ENCODE FUNCTIONS
# ══════════════════════════════════════════════════════════════════
@torch.no_grad()
def encode_with_hf_model(model, tokenizer, texts, batch_size=128, max_len=128):
    """Embed ``texts`` with an encoder model using masked mean pooling.

    Args:
        model: Module invoked as ``model(**inputs)``; its output must expose
            ``last_hidden_state``.
        tokenizer: Callable invoked with the batch of strings plus
            ``max_length``, ``padding=True``, ``truncation=True`` and
            ``return_tensors="pt"``; the result must support ``.to(device)``,
            ``**`` unpacking and an ``attention_mask`` attribute.
        texts: Sequence of strings, processed in slices of ``batch_size``.
            Must be non-empty, because ``torch.cat`` rejects an empty list.
        batch_size: Number of texts per forward pass.
        max_len: Truncation length handed to the tokenizer.

    Returns:
        A CPU tensor with one L2-normalized, mean-pooled row per input text.
    """
    # Send inputs to wherever the model's weights actually live instead of
    # assuming the module-level DEVICE; a model left on another device would
    # otherwise fail with a device mismatch. Parameter-less models fall back
    # to DEVICE, which is the original behaviour.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else DEVICE

    all_emb = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        inputs = tokenizer(batch, max_length=max_len, padding=True,
                           truncation=True, return_tensors="pt").to(device)
        out = model(**inputs)

        # Average only over real tokens: padded positions are zeroed by the
        # mask, and the clamp guards against division by zero for a row whose
        # mask is entirely zero.
        mask = inputs.attention_mask.unsqueeze(-1).float()
        pooled = (out.last_hidden_state * mask).sum(1) / mask.sum(1).clamp(min=1)
        all_emb.append(F.normalize(pooled, dim=-1).cpu())
    return torch.cat(all_emb)
@torch.no_grad()
def encode_with_student(student, tokenizer, texts, batch_size=128, max_len=128):
    """Embed ``texts`` with the distilled student encoder.

    Args:
        student: Module invoked as ``student(input_ids, attention_mask)``.
            Its output is collected as-is; no pooling or normalization is
            applied here.
        tokenizer: Callable invoked with the batch of strings plus
            ``max_length``, ``padding="max_length"``, ``truncation=True`` and
            ``return_tensors="pt"``; the result must support ``.to(device)``
            and item access for ``"input_ids"`` / ``"attention_mask"``.
        texts: Sequence of strings, processed in slices of ``batch_size``.
            Must be non-empty, because ``torch.cat`` rejects an empty list.
        batch_size: Number of texts per forward pass.
        max_len: Length every sequence is padded/truncated to.

    Returns:
        A CPU tensor holding the student's outputs for all texts, in order.
    """
    # Send inputs to wherever the student's weights actually live instead of
    # assuming the module-level DEVICE; parameter-less callables fall back to
    # DEVICE, which is the original behaviour.
    first_param = next(student.parameters(), None)
    device = first_param.device if first_param is not None else DEVICE

    all_emb = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        # Unlike the baseline encoder, every batch is padded to the full
        # max_len rather than to the longest sequence in the batch.
        inputs = tokenizer(batch, max_length=max_len, padding="max_length",
                           truncation=True, return_tensors="pt").to(device)
        emb = student(inputs["input_ids"], inputs["attention_mask"])
        all_emb.append(emb.cpu())
    return torch.cat(all_emb)
# ══════════════════════════════════════════════════════════════════
# EVALUATION
# ══════════════════════════════════════════════════════════════════
def eval_sts(pairs, emb1, emb2):
    """Score paired embeddings against gold similarity labels.

    Args:
        pairs: Sequence of dicts; only the ``"score"`` entry of each is read.
        emb1: Tensor of embeddings for the first sentence of every pair.
        emb2: Tensor of embeddings for the second sentence of every pair.

    Returns:
        Dict of Python floats with the keys ``spearman`` and ``pearson``
        (correlation between row-wise cosine similarity and the gold scores)
        plus ``cos_mean`` and ``cos_std`` (mean and standard deviation of the
        cosine similarities).
    """
    predicted = F.cosine_similarity(emb1, emb2, dim=-1).numpy()
    gold_scores = np.array([pair["score"] for pair in pairs])

    rank_corr = spearmanr(predicted, gold_scores).statistic
    linear_corr = pearsonr(predicted, gold_scores).statistic

    # Cast everything to built-in floats so the result is JSON-serializable.
    return dict(
        spearman=float(rank_corr),
        pearson=float(linear_corr),
        cos_mean=float(predicted.mean()),
        cos_std=float(predicted.std()),
    )
# ══════════════════════════════════════════════════════════════════
# STUDENT MODEL (must match training architecture)
# ══════════════════════════════════════════════════════════════════
import torch.nn as nn
class CaptionEncoder(nn.Module):
    """Transformer text encoder that emits one unit-norm vector per sequence.

    Token and learned position embeddings are summed, layer-normalized and
    passed through a pre-norm (``norm_first=True``) GELU Transformer encoder.
    The non-padding positions are mean-pooled, projected to ``output_dim`` by
    a small MLP, and L2-normalized.

    Attribute names define the checkpoint's parameter keys, so they must stay
    in sync with the training-time architecture.
    """

    def __init__(self, vocab_size=30522, max_len=128, d_model=384,
                 n_heads=6, n_layers=6, d_ff=1536, output_dim=768,
                 dropout=0.1, pad_token_id=0):
        super().__init__()
        self.pad_token_id = pad_token_id

        # Input embeddings: tokens (with a dedicated padding row) + positions.
        self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_token_id)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.emb_norm = nn.LayerNorm(d_model)
        self.emb_drop = nn.Dropout(dropout)

        # Encoder stack built from one batch-first, pre-norm layer template.
        layer_config = dict(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_ff,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(**layer_config), num_layers=n_layers)

        # Projection head mapping the pooled state to the output dimension.
        head = [
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.LayerNorm(d_model),
            nn.Linear(d_model, output_dim),
        ]
        self.output_proj = nn.Sequential(*head)

    def forward(self, input_ids, attention_mask=None):
        """Encode a 2-D batch of token ids into L2-normalized embeddings.

        When ``attention_mask`` is omitted, padding is inferred from positions
        where ``input_ids`` equals ``pad_token_id``.
        """
        _, seq_len = input_ids.shape
        pos_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)

        hidden = self.token_emb(input_ids) + self.pos_emb(pos_ids)
        hidden = self.emb_norm(hidden)
        hidden = self.emb_drop(hidden)

        # pad_mask is True at positions the encoder must ignore; keep holds
        # the float pooling weights for the positions that count.
        if attention_mask is None:
            pad_mask = input_ids == self.pad_token_id
            keep = (~pad_mask).unsqueeze(-1).float()
        else:
            pad_mask = ~attention_mask.bool()
            keep = attention_mask.unsqueeze(-1).float()

        hidden = self.encoder(hidden, src_key_padding_mask=pad_mask)

        # Masked mean; the clamp avoids dividing by zero for all-padding rows.
        summed = (hidden * keep).sum(1)
        counts = keep.sum(1).clamp(min=1)
        return F.normalize(self.output_proj(summed / counts), dim=-1)
# ══════════════════════════════════════════════════════════════════
# MAIN
# ══════════════════════════════════════════════════════════════════
def _locate_checkpoint(search_dirs, filenames=("best_model.pt", "final_model.pt")):
    """Return the first existing ``dir/filename`` path, searching in order, else None."""
    for directory in search_dirs:
        for filename in filenames:
            candidate = os.path.join(directory, filename)
            if os.path.exists(candidate):
                return candidate
    return None


def _evaluate_encoder(encode_fn, model, tokenizer, stsb, sick, n_params):
    """Embed both benchmarks with ``encode_fn``, print the scores, return the result entry."""
    print(" Encoding STS-B...")
    stsb_a = encode_fn(model, tokenizer, [p["sent1"] for p in stsb])
    stsb_b = encode_fn(model, tokenizer, [p["sent2"] for p in stsb])
    print(" Encoding SICK-R...")
    sick_a = encode_fn(model, tokenizer, [p["sent1"] for p in sick])
    sick_b = encode_fn(model, tokenizer, [p["sent2"] for p in sick])

    r_stsb = eval_sts(stsb, stsb_a, stsb_b)
    r_sick = eval_sts(sick, sick_a, sick_b)
    print(f" STS-B: spearman={r_stsb['spearman']:.4f} pearson={r_stsb['pearson']:.4f}")
    print(f" SICK-R: spearman={r_sick['spearman']:.4f} pearson={r_sick['pearson']:.4f}")
    return {"stsb": r_stsb, "sick": r_sick, "params": n_params}


def _print_summary(results):
    """Print the ranking table and the student-versus-best-baseline comparison."""
    print(f"\n{'='*65}")
    print("SUMMARY")
    print(f"{'='*65}")
    print(f"\n {'Model':<20} {'Params':>12} {'STS-B ρ':>10} {'SICK-R ρ':>10}")
    print(f" {'-'*52}")

    # Sort by STS-B spearman, best first.
    ranked = sorted(results.items(),
                    key=lambda item: item[1]["stsb"]["spearman"], reverse=True)
    for name, r in ranked:
        marker = " ★" if name == "student" else ""
        print(f" {name:<20} {r['params']:>10,} "
              f"{r['stsb']['spearman']:>10.4f} {r['sick']['spearman']:>10.4f}{marker}")

    # Student vs best individual baseline (by STS-B spearman).
    student_stsb = results["student"]["stsb"]["spearman"]
    student_params = results["student"]["params"]
    best_name = max((n for n in results if n != "student"),
                    key=lambda n: results[n]["stsb"]["spearman"])
    best_stsb = results[best_name]["stsb"]["spearman"]
    best_params = results[best_name]["params"]
    print(f"\n Student STS-B: {student_stsb:.4f} ({student_params:,} params)")
    print(f" Best teacher: {best_stsb:.4f} ({best_name}, {best_params:,} params)")
    print(f" Gap: {student_stsb - best_stsb:+.4f}")
    print(f" Param ratio: {best_params / student_params:.1f}×")


def run_benchmarks(
    student_dirs=(
        "/home/claude/consensus_200k/student",
        "/home/claude/distilled_consensus",
    ),
    save_path="/home/claude/benchmark_results.json",
):
    """Benchmark the distilled student against pretrained encoders on STS-B and SICK-R.

    The student checkpoint is loaded first; if none is found an error line is
    printed and the function returns without evaluating anything. Otherwise
    the student and each baseline model are embedded on both benchmarks,
    scored with ``eval_sts``, summarized in a table, and the per-model results
    are written to ``save_path`` as JSON.

    Args:
        student_dirs: Directories searched, in order, for ``best_model.pt``
            and then ``final_model.pt``. The first file found is used.
        save_path: Destination of the JSON results file. Its parent directory
            is created when missing.

    Returns:
        None.
    """
    from transformers import AutoModel, AutoTokenizer
    import gc

    # ── Load benchmarks ──
    print(f"\n{'='*65}")
    print("LOADING BENCHMARKS")
    print(f"{'='*65}")
    stsb = load_stsb()
    sick = load_sick()
    results = {}

    # ── Evaluate student ──
    print(f"\n{'='*65}")
    print("EVALUATING: Distilled Consensus Student")
    print(f"{'='*65}")
    student_tok = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")

    ckpt_path = _locate_checkpoint(student_dirs)
    if ckpt_path is None:
        print(" ERROR: No student checkpoint found!")
        return

    # Hyperparameters must match the architecture the checkpoint was trained
    # with; dropout is zeroed because the model is only used for inference.
    student = CaptionEncoder(
        vocab_size=student_tok.vocab_size,
        max_len=128, d_model=384, n_heads=6, n_layers=6,
        d_ff=1536, output_dim=768, dropout=0.0,
        pad_token_id=student_tok.pad_token_id).to(DEVICE)
    student.load_state_dict(
        torch.load(ckpt_path, weights_only=True, map_location=DEVICE))
    student.eval()
    n_params = sum(p.numel() for p in student.parameters())
    print(f" Loaded: {ckpt_path}")
    print(f" Parameters: {n_params:,}")

    results["student"] = _evaluate_encoder(
        encode_with_student, student, student_tok, stsb, sick, n_params)

    # Release the student before loading the (larger) baselines.
    del student
    gc.collect()
    torch.cuda.empty_cache()

    # ── Evaluate individual BERTs ──
    bert_models = [
        ("google-bert/bert-base-uncased", "bert-base"),
        ("answerdotai/ModernBERT-base", "modern-bert"),
        ("FacebookAI/roberta-base", "roberta"),
        ("albert/albert-base-v2", "albert"),
        ("distilbert/distilbert-base-uncased", "distilbert"),
    ]
    for model_name, short_name in bert_models:
        print(f"\n{'='*65}")
        print(f"EVALUATING: {short_name} ({model_name})")
        print(f"{'='*65}")
        model = AutoModel.from_pretrained(model_name).to(DEVICE).eval()
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        n_p = sum(p.numel() for p in model.parameters())
        print(f" Parameters: {n_p:,}")

        results[short_name] = _evaluate_encoder(
            encode_with_hf_model, model, tokenizer, stsb, sick, n_p)

        # Free this model's memory before the next one is loaded.
        del model
        gc.collect()
        torch.cuda.empty_cache()

    # ══════════════════════════════════════════════════════════════
    # SUMMARY
    # ══════════════════════════════════════════════════════════════
    _print_summary(results)

    # Save. Create the target directory first so a missing folder does not
    # discard the results of a long run.
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(save_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, default=str)
    print(f"\n Saved to {save_path}")
    print(f"\n{'='*65}")
    print("DONE")
    print(f"{'='*65}")
# Run the full benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    run_benchmarks()