srt-adapter v1.0: v15a checkpoint with capability bench, MTEB English STS, hallucination AUROC
0182cbf verified | #!/usr/bin/env python3 | |
| """Run the SRT-Adapter v1.0 community-encoder on the standard English STS suite | |
| and print per-task Spearman correlations. | |
| This intentionally avoids the full `mteb` library dependency: it uses | |
| HuggingFace `datasets` directly for seven English STS tasks from the MTEB | |
| English STS suite: STS12-STS16 plus the SICK-R and STS-Benchmark subsets that | |
| the v1.0 model card cites. | |
| Usage: | |
| python scripts/run_mteb.py --adapter adapter.pt --output benchmarks/mteb_my_run.json | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import sys | |
| from pathlib import Path | |
| from typing import Iterable | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from scipy.stats import spearmanr | |
| from transformers import AutoTokenizer | |
| HERE = Path(__file__).resolve().parents[1] | |
| sys.path.insert(0, str(HERE / "src")) | |
| from srt.adapter import SRTAdapter # noqa: E402 | |
| from srt.config import SRTConfig # noqa: E402 | |
# Each entry: (HF dataset id, config, split, sentence1 col, sentence2 col, score col).
# Every task uses the default config, the "test" split and the same column
# names, so only the dataset id varies between entries.
TASKS = [
    (f"mteb/{task_name}-sts", None, "test", "sentence1", "sentence2", "score")
    for task_name in (
        "stsbenchmark",
        "sickr",
        "sts12",
        "sts13",
        "sts14",
        "sts15",
        "sts16",
    )
]
def load_model(backbone: str, adapter_path: str, device: str):
    """Build the SRT adapter model and the tokenizer for its backbone.

    Args:
        backbone: Name stored in ``SRTConfig.backbone_name`` and passed to
            ``AutoTokenizer.from_pretrained``.
        adapter_path: Path handed to ``SRTAdapter.load_adapter``.
        device: Device the model is moved to.

    Returns:
        A ``(model, tokenizer)`` pair; the model has been switched to eval
        mode and the tokenizer is guaranteed to have a pad token.
    """
    config = SRTConfig()
    config.backbone_name = backbone
    config.dtype = "bfloat16"

    srt_model = SRTAdapter(config).to(device)
    srt_model.load_adapter(adapter_path)
    srt_model.eval()

    tokenizer = AutoTokenizer.from_pretrained(backbone)
    # Batches are padded at encode time, so reuse EOS when the tokenizer
    # ships without a dedicated pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return srt_model, tokenizer
def encode_batch(model: SRTAdapter, tok, sents: list[str], device: str, max_seq_len: int) -> torch.Tensor:
    """Embed one batch of sentences with attention-masked mean pooling.

    Args:
        model: Callable model whose output exposes ``community_output.encoded``.
        tok: Tokenizer called with padding and truncation enabled; must return
            ``input_ids`` and ``attention_mask`` tensors.
        sents: Sentences to embed.
        device: Device the token tensors are moved to before the forward pass.
        max_seq_len: Truncation length passed to the tokenizer.

    Returns:
        L2-normalised float32 embeddings on the CPU, one row per sentence,
        detached from autograd.
    """
    enc = tok(sents, padding=True, truncation=True, max_length=max_seq_len, return_tensors="pt")
    input_ids = enc["input_ids"].to(device)
    attn = enc["attention_mask"].to(device)

    # Evaluation only: ``model.eval()`` does not switch autograd off. Without
    # ``no_grad`` every batch keeps an autograd graph alive, and the returned
    # embeddings may require grad, which makes a later ``.numpy()`` call fail.
    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=attn)
        # assumes ``encoded`` is (batch, seq, dim) - TODO confirm against SRTAdapter
        encoded = out.community_output.encoded
        # Zero out padded positions, then average over the real tokens only.
        mask = attn.unsqueeze(-1).to(encoded.dtype)
        # clamp_min guards against division by zero for an all-padding row.
        pooled = (encoded * mask).sum(1) / mask.sum(1).clamp_min(1.0)
        return F.normalize(pooled, p=2, dim=-1).float().cpu()
def encode_all(model, tok, sents: list[str], device: str, max_seq_len: int, batch_size: int) -> torch.Tensor:
    """Embed ``sents`` in consecutive slices of ``batch_size`` sentences.

    Each slice is passed to ``encode_batch``; the per-slice results are
    concatenated along dimension 0, preserving the input order.
    """
    batch_starts = range(0, len(sents), batch_size)
    embedded = [
        encode_batch(model, tok, sents[start:start + batch_size], device, max_seq_len)
        for start in batch_starts
    ]
    return torch.cat(embedded, dim=0)
def main() -> None:
    """Evaluate the adapter on every entry of ``TASKS`` and write a JSON report.

    Command-line flags select the backbone, adapter checkpoint, output path,
    tokenizer truncation length, batch size and device (CUDA when available,
    otherwise CPU, unless ``--device`` is given).

    For each task both sentence columns are embedded with ``encode_all`` and
    the row-wise dot product of the two embedding matrices is correlated with
    the gold scores using Spearman's rho. Since ``encode_batch`` L2-normalises
    its output, that dot product is the cosine similarity.

    Datasets that fail to load, and tasks whose correlation is undefined
    (NaN), are reported on stderr, recorded with an ``"error"`` entry and
    left out of the mean instead of aborting the run.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--backbone", default="Qwen/Qwen2.5-7B")
    parser.add_argument("--adapter", default=str(HERE / "adapter.pt"))
    parser.add_argument("--output", default=str(HERE / "benchmarks" / "mteb_run.json"))
    parser.add_argument("--max-seq-len", type=int, default=128)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--device", default=None)
    args = parser.parse_args()

    # Reject values that would otherwise only fail after the model has been
    # loaded (a zero step in ``range`` or an empty ``torch.cat``).
    if args.batch_size < 1:
        parser.error("--batch-size must be a positive integer")
    if args.max_seq_len < 1:
        parser.error("--max-seq-len must be a positive integer")

    # Imported after argument parsing, so ``--help`` and argument errors do
    # not require the `datasets` package.
    from datasets import load_dataset

    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    model, tok = load_model(args.backbone, args.adapter, device)

    results = {"backbone": args.backbone, "adapter": args.adapter, "tasks": {}}
    spearmans: list[float] = []
    for ds_id, cfg_name, split, c1, c2, cs in TASKS:
        try:
            # A falsy config selects the dataset's default configuration.
            ds = load_dataset(ds_id, cfg_name or None, split=split)
        except Exception as e:  # noqa: BLE001 - best effort: skip this task, keep going
            print(f"[skip] {ds_id}: {e}", file=sys.stderr)
            results["tasks"][ds_id] = {"error": str(e)}
            continue

        sents1 = [str(x) for x in ds[c1]]
        sents2 = [str(x) for x in ds[c2]]
        gold = np.array(ds[cs], dtype=float)

        emb1 = encode_all(model, tok, sents1, device, args.max_seq_len, args.batch_size)
        emb2 = encode_all(model, tok, sents2, device, args.max_seq_len, args.batch_size)
        sims = (emb1 * emb2).sum(-1).numpy()
        rho = float(spearmanr(sims, gold).statistic)

        if np.isnan(rho):
            # Spearman's rho is undefined for constant or NaN inputs. A NaN
            # would be serialised as the bare token `NaN`, which is not valid
            # JSON, and would turn the mean into NaN as well.
            message = "spearman correlation is undefined (NaN)"
            print(f"[skip] {ds_id}: {message}", file=sys.stderr)
            results["tasks"][ds_id] = {"spearman": None, "n": len(gold), "error": message}
            continue

        results["tasks"][ds_id] = {"spearman": rho, "n": len(gold)}
        spearmans.append(rho)
        print(f"{ds_id:30s} spearman={rho:.4f} n={len(gold)}")

    results["mean_spearman"] = float(np.mean(spearmans)) if spearmans else None
    print(f"\nMean Spearman across {len(spearmans)} tasks: {results['mean_spearman']}")

    out = Path(args.output)
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_text(json.dumps(results, indent=2), encoding="utf-8")
    print(f"Wrote {out}")
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()