# srt-adapter-v1.0 / scripts / run_mteb.py
# Commit 0182cbf by RiverRider (verified):
#   srt-adapter v1.0: v15a checkpoint with capability bench, MTEB English STS, hallucination AUROC
#!/usr/bin/env python3
"""Run the SRT-Adapter v1.0 community encoder on English STS tasks, print
per-task Spearman correlations, and write them to a JSON file.

This intentionally avoids the full `mteb` library dependency: it uses
HuggingFace `datasets` directly to load seven English STS tasks from the MTEB
STS suite -- STS12 through STS16, plus the SICK-R and STS-Benchmark subsets
that the v1.0 model card cites.
Usage:
python scripts/run_mteb.py --adapter adapter.pt --output benchmarks/mteb_my_run.json
"""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
from typing import Iterable
import numpy as np
import torch
import torch.nn.functional as F
from scipy.stats import spearmanr
from transformers import AutoTokenizer
# Repository root: the parent of the directory that contains this script.
HERE: Path = Path(__file__).resolve().parents[1]
# Prepend <root>/src to the import path so the `srt` imports below resolve
# from this checkout.
sys.path[0:0] = [str(HERE / "src")]
from srt.adapter import SRTAdapter # noqa: E402
from srt.config import SRTConfig # noqa: E402
# One row per STS task:
#   (HF dataset id, config, split, sentence1 col, sentence2 col, score col)
# Every task here is an `mteb/<name>-sts` dataset read with its default config
# (None), the "test" split, and identical column names, so the rows are
# generated from the name stem alone.
TASKS = [
    (f"mteb/{stem}-sts", None, "test", "sentence1", "sentence2", "score")
    for stem in ("stsbenchmark", "sickr", "sts12", "sts13", "sts14", "sts15", "sts16")
]
def load_model(backbone: str, adapter_path: str, device: str):
    """Build the SRT adapter on ``backbone``, load its weights, and fetch a tokenizer.

    Args:
        backbone: Model id/path used both as ``SRTConfig.backbone_name`` and for
            ``AutoTokenizer.from_pretrained``.
        adapter_path: Checkpoint handed to ``SRTAdapter.load_adapter``.
        device: Torch device the adapter is moved to.

    Returns:
        ``(model, tokenizer)`` with the model in eval mode. If the tokenizer
        defines no pad token, its EOS token is reused so batches can be padded.
    """
    config = SRTConfig()
    config.backbone_name = backbone
    config.dtype = "bfloat16"

    adapter_model = SRTAdapter(config).to(device)
    adapter_model.load_adapter(adapter_path)
    adapter_model.eval()

    tokenizer = AutoTokenizer.from_pretrained(backbone)
    # Padded batch encoding needs a pad token; fall back to EOS when absent.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return adapter_model, tokenizer
@torch.no_grad()
def encode_batch(model: SRTAdapter, tok, sents: list[str], device: str, max_seq_len: int) -> torch.Tensor:
    """Embed one batch of sentences as unit-length, mean-pooled vectors.

    Args:
        model: Adapter whose forward output exposes ``community_output.encoded``.
        tok: Tokenizer called with padding/truncation that returns
            ``input_ids`` and ``attention_mask`` tensors.
        sents: Sentences to embed.
        device: Torch device the token tensors are moved to.
        max_seq_len: Truncation length in tokens.

    Returns:
        A float32 CPU tensor with one L2-normalised row per sentence.
    """
    enc = tok(sents, padding=True, truncation=True, max_length=max_seq_len, return_tensors="pt")
    input_ids = enc["input_ids"].to(device)
    attn = enc["attention_mask"].to(device)
    out = model(input_ids=input_ids, attention_mask=attn)
    # Per-token features, treated as (batch, seq, hidden) by the pooling below.
    # Upcast before pooling: the model may run in bfloat16 (load_model sets
    # cfg.dtype = "bfloat16"), and summing/normalising in bf16 keeps only ~8 bits
    # of significand, adding rounding noise to the cosine similarities that
    # Spearman then ranks.
    encoded = out.community_output.encoded.float()
    # Padding positions get weight 0 so they do not contribute to the mean.
    mask = attn.unsqueeze(-1).to(encoded.dtype)
    # clamp_min avoids division by zero for a row with no real tokens.
    pooled = (encoded * mask).sum(1) / mask.sum(1).clamp_min(1.0)
    return F.normalize(pooled, p=2, dim=-1).cpu()
def encode_all(model, tok, sents: list[str], device: str, max_seq_len: int, batch_size: int) -> torch.Tensor:
    """Embed ``sents`` in consecutive slices of ``batch_size`` and stack the rows.

    Row ``i`` of the result corresponds to ``sents[i]``. ``torch.cat`` rejects an
    empty list, so ``sents`` must be non-empty.
    """
    starts = range(0, len(sents), batch_size)
    return torch.cat(
        [encode_batch(model, tok, sents[s:s + batch_size], device, max_seq_len) for s in starts],
        dim=0,
    )
def _score_pairs(
    model,
    tok,
    col1,
    col2,
    gold_scores,
    device: str,
    max_seq_len: int,
    batch_size: int,
) -> tuple[float, int]:
    """Return ``(spearman_rho, n_pairs)`` of pair cosine similarity vs. gold scores."""
    sents1 = [str(x) for x in col1]
    sents2 = [str(x) for x in col2]
    gold = np.array(gold_scores, dtype=float)
    emb1 = encode_all(model, tok, sents1, device, max_seq_len, batch_size)
    emb2 = encode_all(model, tok, sents2, device, max_seq_len, batch_size)
    # encode_batch L2-normalises every row, so the row-wise dot product is the
    # cosine similarity of each pair.
    sims = (emb1 * emb2).sum(-1).numpy()
    # Index 0 works for both the legacy SpearmanrResult namedtuple (SciPy < 1.9,
    # no `.statistic` attribute) and the newer SignificanceResult.
    rho = float(spearmanr(sims, gold)[0])
    return rho, len(gold)


def main() -> None:
    """Evaluate the adapter on every task in ``TASKS`` and write a JSON report.

    Command-line flags select the backbone, adapter checkpoint, output path,
    truncation length, batch size and device. A task whose dataset cannot be
    loaded is logged to stderr, recorded under an ``"error"`` key and skipped,
    so one unavailable dataset does not abort the run. The report holds each
    task's Spearman correlation and pair count, plus ``mean_spearman`` over the
    tasks that ran (``None`` if none did).
    """
    p = argparse.ArgumentParser(
        description="Score the SRT-Adapter community encoder on English STS tasks."
    )
    p.add_argument("--backbone", default="Qwen/Qwen2.5-7B",
                   help="Backbone id passed to SRTConfig and AutoTokenizer.")
    p.add_argument("--adapter", default=str(HERE / "adapter.pt"),
                   help="Adapter checkpoint to load.")
    p.add_argument("--output", default=str(HERE / "benchmarks" / "mteb_run.json"),
                   help="Path of the JSON results file (parent dirs are created).")
    p.add_argument("--max-seq-len", type=int, default=128,
                   help="Tokenizer truncation length.")
    p.add_argument("--batch-size", type=int, default=32,
                   help="Sentences per forward pass.")
    p.add_argument("--device", default=None,
                   help="Torch device; defaults to cuda when available, else cpu.")
    args = p.parse_args()
    # Deferred import: `--help` exits inside parse_args() above, before this runs.
    from datasets import load_dataset

    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    model, tok = load_model(args.backbone, args.adapter, device)
    results = {"backbone": args.backbone, "adapter": args.adapter, "tasks": {}}
    spearmans: list[float] = []
    for ds_id, cfg_name, split, c1, c2, cs in TASKS:
        try:
            ds = load_dataset(ds_id, cfg_name, split=split) if cfg_name else load_dataset(ds_id, split=split)
        except Exception as e:  # noqa: BLE001 - deliberate best effort, see docstring
            print(f"[skip] {ds_id}: {e}", file=sys.stderr)
            results["tasks"][ds_id] = {"error": str(e)}
            continue
        rho, n = _score_pairs(
            model, tok, ds[c1], ds[c2], ds[cs], device, args.max_seq_len, args.batch_size
        )
        results["tasks"][ds_id] = {"spearman": rho, "n": n}
        spearmans.append(rho)
        print(f"{ds_id:30s} spearman={rho:.4f} n={n}")
    results["mean_spearman"] = float(np.mean(spearmans)) if spearmans else None
    print(f"\nMean Spearman across {len(spearmans)} tasks: {results['mean_spearman']}")
    out = Path(args.output)
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_text(json.dumps(results, indent=2))
    print(f"Wrote {out}")
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()