srt-adapter v1.0: v15a checkpoint with capability bench, MTEB English STS, hallucination AUROC
0182cbf verified | #!/usr/bin/env python3 | |
| """Encode sentences with SRT-Adapter v1.0 and print pairwise cosine similarity. | |
| Uses `community_output.encoded` (mean-pooled, L2-normalized) as the sentence | |
| embedding. This is the output that v15a's contrastive training optimized; it | |
| is a real STS embedding (English-mean Spearman ~0.589 across 17 MTEB splits). | |
| Usage: | |
| cd examples | |
| python encode_sentences.py --sentences \ | |
| "A man is playing guitar." \ | |
| "Someone is performing music." \ | |
| "The cat sat on the mat." | |
| First run downloads Qwen/Qwen2.5-7B (~15 GB). | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import sys | |
| from pathlib import Path | |
| import torch | |
| import torch.nn.functional as F | |
| from transformers import AutoTokenizer | |
HERE = Path(__file__).resolve().parent
# Put this checkout's `src/` directory first on the import path so the `srt`
# package imported below resolves from the repository sources.
sys.path.insert(0, str(HERE.parent.joinpath("src").resolve()))
| from srt.adapter import SRTAdapter # noqa: E402 | |
| from srt.config import ( # noqa: E402 | |
| SRTConfig, MAHConfig, RRMConfig, BENConfig, CommunityConfig, LossConfig, | |
| ) | |
def build_config(config_path: Path) -> SRTConfig:
    """Construct an ``SRTConfig`` from the adapter's JSON config file.

    Args:
        config_path: Path to the JSON config (``config.json`` by default).

    Returns:
        An ``SRTConfig`` populated from the top-level keys. The ``mah``,
        ``rrm``, ``ben``, ``community`` and ``loss`` objects are expanded into
        their respective sub-config classes.

    Raises:
        KeyError: if any required key is absent from the JSON.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    # JSON text is UTF-8 (RFC 8259). Decode it as such instead of relying on the
    # platform's locale encoding (e.g. cp1252 on Windows).
    raw = json.loads(config_path.read_text(encoding="utf-8"))
    # Only `loss` is filtered: keys LossConfig does not declare are dropped so
    # extra entries in that section are not passed to its constructor. The
    # other sections are forwarded unfiltered.
    loss_kwargs = {
        k: v for k, v in raw["loss"].items()
        if k in LossConfig.__dataclass_fields__
    }
    return SRTConfig(
        backbone_id=raw["backbone_id"],
        backbone_dtype=raw["backbone_dtype"],
        mah_layer_indices=list(raw["mah_layer_indices"]),
        rrm_inject_indices=list(raw["rrm_inject_indices"]),
        community_layer_idx=raw["community_layer_idx"],
        num_mah_layers=raw["num_mah_layers"],
        mah=MAHConfig(**raw["mah"]),
        rrm=RRMConfig(**raw["rrm"]),
        ben=BENConfig(**raw["ben"]),
        community=CommunityConfig(**raw["community"]),
        loss=LossConfig(**loss_kwargs),
    )
def load_weights(model: SRTAdapter, weights_path: Path) -> None:
    """Load adapter weights from *weights_path* into *model*, non-strictly.

    A ``.safetensors`` file is read with ``safetensors.torch.load_file``. Any
    other suffix is treated as a ``torch.save`` checkpoint and loaded onto the
    CPU with ``weights_only=True``.

    Keys the model has but the file lacks are ignored silently. Keys present in
    the file but unknown to the model produce one warning line on stderr that
    lists at most four of them.
    """
    path_str = str(weights_path)
    if weights_path.suffix != ".safetensors":
        state = torch.load(path_str, map_location="cpu", weights_only=True)
    else:
        from safetensors.torch import load_file
        state = load_file(path_str)

    # strict=False: the checkpoint holds only the adapter's trainable
    # submodules, so the frozen backbone's parameters are expected to be
    # reported as missing.
    _missing, extra = model.load_state_dict(state, strict=False)
    if not extra:
        return
    ellipsis = "..." if len(extra) > 4 else ""
    print(f"warning: unexpected keys {extra[:4]}{ellipsis}", file=sys.stderr)
def encode(model: SRTAdapter, tok: AutoTokenizer, sentences: list[str], device: str, max_seq_len: int = 128) -> torch.Tensor:
    """Embed *sentences* as mean-pooled, L2-normalized ``community_output.encoded``.

    Args:
        model: The SRT adapter, called with ``input_ids`` / ``attention_mask``.
        tok: Backbone tokenizer. It must define a pad token because the batch
            is padded.
        sentences: Input strings, tokenized together as one padded batch.
        device: Device the token tensors are moved to (should match *model*).
        max_seq_len: Sequences longer than this many tokens are truncated.

    Returns:
        A float32 tensor of shape ``(len(sentences), d_community)``. Each row
        has unit L2 norm, so ``embs @ embs.T`` gives pairwise cosine similarity.
        Gradients are not tracked.
    """
    enc = tok(
        sentences,
        padding=True,
        truncation=True,
        max_length=max_seq_len,
        return_tensors="pt",
    )
    input_ids = enc["input_ids"].to(device)
    attn = enc["attention_mask"].to(device)
    # Inference only. Without no_grad, autograd would record the forward pass
    # through the adapter's trainable submodules and keep their activations
    # alive for a backward pass that never happens.
    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=attn)
        # Pool in float32. Summing T vectors in a low-precision dtype
        # (e.g. bf16) loses accuracy in the mean.
        encoded = out.community_output.encoded.float()  # (B, T, d_community)
        mask = attn.unsqueeze(-1).to(encoded.dtype)
        # Masked mean: padding positions contribute nothing. The clamp avoids
        # 0/0 for a row with no real tokens.
        pooled = (encoded * mask).sum(1) / mask.sum(1).clamp_min(1.0)
        return F.normalize(pooled, p=2, dim=-1)
def _default_weights() -> Path:
    """Return ``adapter.safetensors`` above this script's folder if it exists, else ``adapter.pt`` there."""
    st = HERE.parent / "adapter.safetensors"
    return st if st.exists() else HERE.parent / "adapter.pt"


def _print_similarity(sentences: list[str], sims: torch.Tensor) -> None:
    """Print the numbered sentences and an aligned n x n similarity matrix to stdout."""
    n = len(sentences)
    print("\n=== Sentences ===")
    for i, s in enumerate(sentences):
        print(f"  [{i}] {s}")
    print("\n=== Cosine similarity matrix ===")
    # One host transfer instead of n*n `.item()` calls (each a sync on GPU).
    values = sims.tolist()
    # Size columns from the widest index label so alignment holds for n >= 10.
    label_w = len(f"[{n - 1}]")
    cell_w = max(label_w, len("+0.0000"))
    print(" " * (label_w + 4) + "  ".join(f"[{j}]".center(cell_w) for j in range(n)))
    for i, row in enumerate(values):
        cells = "  ".join(f"{v:+.4f}".rjust(cell_w) for v in row)
        label = f"[{i}]"
        print(f"  {label:<{label_w}}  {cells}")


def main() -> None:
    """CLI entry point: encode ``--sentences`` and print their cosine similarities.

    ``--config`` defaults to ``config.json`` in the directory above this
    script. The weights default to ``adapter.safetensors`` there, else
    ``adapter.pt``. ``--device`` defaults to CUDA when available, else CPU.
    If the config or weights file does not exist, exits through argparse with
    a usage error.
    """
    p = argparse.ArgumentParser(
        description="Encode sentences with SRT-Adapter and print pairwise cosine similarity."
    )
    p.add_argument("--sentences", nargs="+", required=True)
    p.add_argument("--config", default=str(HERE.parent / "config.json"))
    p.add_argument("--weights", default=None, help="Path to adapter.safetensors or adapter.pt; auto-detected if not provided.")
    p.add_argument("--device", default=None)
    p.add_argument("--max-seq-len", type=int, default=128)
    args = p.parse_args()

    # Fail fast with a clear message instead of a traceback deep inside loading.
    config_path = Path(args.config)
    if not config_path.is_file():
        p.error(f"config not found: {config_path}")
    weights = Path(args.weights) if args.weights is not None else _default_weights()
    if not weights.is_file():
        p.error(f"adapter weights not found: {weights}")

    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    config = build_config(config_path)

    print(f"Loading {config.backbone_id} (frozen) and adapter from {weights} ...", file=sys.stderr)
    tok = AutoTokenizer.from_pretrained(config.backbone_id)
    if tok.pad_token is None:
        # Batched encoding pads, so fall back to EOS when no pad token exists.
        tok.pad_token = tok.eos_token
    model = SRTAdapter(config).to(device)
    load_weights(model, weights)
    model.eval()

    embs = encode(model, tok, args.sentences, device, args.max_seq_len)
    # Rows are unit-norm, so the Gram matrix holds cosine similarities.
    _print_similarity(args.sentences, embs @ embs.T)


if __name__ == "__main__":
    main()