#!/usr/bin/env python3 """Run the SRT-Adapter v1.0 community-encoder on the standard English STS suite and print per-task Spearman correlations. This intentionally avoids the full `mteb` library dependency: it uses HuggingFace `datasets` directly for the seven English STS tasks bundled by the MTEB STS English summary, plus the SICK-R and STS-Benchmark subsets that the v1.0 model card cites. Usage: python scripts/run_mteb.py --adapter adapter.pt --output benchmarks/mteb_my_run.json """ from __future__ import annotations import argparse import json import sys from pathlib import Path from typing import Iterable import numpy as np import torch import torch.nn.functional as F from scipy.stats import spearmanr from transformers import AutoTokenizer HERE = Path(__file__).resolve().parents[1] sys.path.insert(0, str(HERE / "src")) from srt.adapter import SRTAdapter # noqa: E402 from srt.config import SRTConfig # noqa: E402 # (HF dataset id, config, split, sentence1 col, sentence2 col, score col) TASKS = [ ("mteb/stsbenchmark-sts", None, "test", "sentence1", "sentence2", "score"), ("mteb/sickr-sts", None, "test", "sentence1", "sentence2", "score"), ("mteb/sts12-sts", None, "test", "sentence1", "sentence2", "score"), ("mteb/sts13-sts", None, "test", "sentence1", "sentence2", "score"), ("mteb/sts14-sts", None, "test", "sentence1", "sentence2", "score"), ("mteb/sts15-sts", None, "test", "sentence1", "sentence2", "score"), ("mteb/sts16-sts", None, "test", "sentence1", "sentence2", "score"), ] def load_model(backbone: str, adapter_path: str, device: str): cfg = SRTConfig() cfg.backbone_name = backbone cfg.dtype = "bfloat16" model = SRTAdapter(cfg).to(device) model.load_adapter(adapter_path) model.eval() tok = AutoTokenizer.from_pretrained(backbone) if tok.pad_token is None: tok.pad_token = tok.eos_token return model, tok @torch.no_grad() def encode_batch(model: SRTAdapter, tok, sents: list[str], device: str, max_seq_len: int) -> torch.Tensor: enc = tok(sents, padding=True, truncation=True, max_length=max_seq_len, return_tensors="pt") input_ids = enc["input_ids"].to(device) attn = enc["attention_mask"].to(device) out = model(input_ids=input_ids, attention_mask=attn) encoded = out.community_output.encoded mask = attn.unsqueeze(-1).to(encoded.dtype) pooled = (encoded * mask).sum(1) / mask.sum(1).clamp_min(1.0) return F.normalize(pooled, p=2, dim=-1).float().cpu() def encode_all(model, tok, sents: list[str], device: str, max_seq_len: int, batch_size: int) -> torch.Tensor: chunks = [] for i in range(0, len(sents), batch_size): chunks.append(encode_batch(model, tok, sents[i:i + batch_size], device, max_seq_len)) return torch.cat(chunks, dim=0) def main() -> None: p = argparse.ArgumentParser() p.add_argument("--backbone", default="Qwen/Qwen2.5-7B") p.add_argument("--adapter", default=str(HERE / "adapter.pt")) p.add_argument("--output", default=str(HERE / "benchmarks" / "mteb_run.json")) p.add_argument("--max-seq-len", type=int, default=128) p.add_argument("--batch-size", type=int, default=32) p.add_argument("--device", default=None) args = p.parse_args() from datasets import load_dataset device = args.device or ("cuda" if torch.cuda.is_available() else "cpu") model, tok = load_model(args.backbone, args.adapter, device) results = {"backbone": args.backbone, "adapter": args.adapter, "tasks": {}} spearmans: list[float] = [] for ds_id, cfg_name, split, c1, c2, cs in TASKS: try: ds = load_dataset(ds_id, cfg_name, split=split) if cfg_name else load_dataset(ds_id, split=split) except Exception as e: # noqa: BLE001 print(f"[skip] {ds_id}: {e}", file=sys.stderr) results["tasks"][ds_id] = {"error": str(e)} continue sents1 = [str(x) for x in ds[c1]] sents2 = [str(x) for x in ds[c2]] gold = np.array(ds[cs], dtype=float) emb1 = encode_all(model, tok, sents1, device, args.max_seq_len, args.batch_size) emb2 = encode_all(model, tok, sents2, device, args.max_seq_len, args.batch_size) sims = (emb1 * emb2).sum(-1).numpy() rho = float(spearmanr(sims, gold).statistic) results["tasks"][ds_id] = {"spearman": rho, "n": len(gold)} spearmans.append(rho) print(f"{ds_id:30s} spearman={rho:.4f} n={len(gold)}") results["mean_spearman"] = float(np.mean(spearmans)) if spearmans else None print(f"\nMean Spearman across {len(spearmans)} tasks: {results['mean_spearman']}") out = Path(args.output) out.parent.mkdir(parents=True, exist_ok=True) out.write_text(json.dumps(results, indent=2)) print(f"Wrote {out}") if __name__ == "__main__": main()