srt-adapter-v1.0 / examples /encode_sentences.py
RiverRider's picture
srt-adapter v1.0: v15a checkpoint with capability bench, MTEB English STS, hallucination AUROC
0182cbf verified
#!/usr/bin/env python3
"""Encode sentences with SRT-Adapter v1.0 and print pairwise cosine similarity.

Uses `community_output.encoded` (mean-pooled, L2-normalized) as the sentence
embedding. This is the output that v15a's contrastive training optimized; it
is a real STS embedding (English-mean Spearman ~0.589 across 17 MTEB splits).

Usage:
    cd examples
    python encode_sentences.py --sentences \
        "A man is playing guitar." \
        "Someone is performing music." \
        "The cat sat on the mat."

First run downloads Qwen/Qwen2.5-7B (~15 GB).
"""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
# Absolute path of the directory that contains this script.
HERE = Path(__file__).resolve().parent
# Put the `src` directory that sits next to this script's directory at the
# front of the import path, so the `srt` package imported right after this can
# be found there without installing it.
sys.path.insert(0, str((HERE.parent / "src").resolve()))
from srt.adapter import SRTAdapter # noqa: E402
from srt.config import ( # noqa: E402
SRTConfig, MAHConfig, RRMConfig, BENConfig, CommunityConfig, LossConfig,
)
def build_config(config_path: Path) -> SRTConfig:
    """Build an ``SRTConfig`` from the JSON file at *config_path*.

    The file must hold a JSON object with the scalar/list keys
    ``backbone_id``, ``backbone_dtype``, ``mah_layer_indices``,
    ``rrm_inject_indices``, ``community_layer_idx`` and ``num_mah_layers``,
    plus the nested objects ``mah``, ``rrm``, ``ben``, ``community`` and
    ``loss``, which are expanded as keyword arguments into their respective
    config classes.

    Args:
        config_path: Location of the JSON configuration file.

    Returns:
        The assembled ``SRTConfig``.

    Raises:
        OSError: If the file cannot be read.
        ValueError: If the file is not valid UTF-8 or not valid JSON
            (``UnicodeDecodeError`` / ``json.JSONDecodeError``).
        KeyError: If one of the required keys is absent.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform's locale
    # encoding, which would mis-decode non-ASCII content on some systems.
    raw = json.loads(config_path.read_text(encoding="utf-8"))

    # Only the loss section is filtered: entries that are not declared
    # LossConfig dataclass fields are dropped rather than passed through.
    loss_fields = LossConfig.__dataclass_fields__
    loss_kwargs = {k: v for k, v in raw["loss"].items() if k in loss_fields}

    return SRTConfig(
        backbone_id=raw["backbone_id"],
        backbone_dtype=raw["backbone_dtype"],
        mah_layer_indices=list(raw["mah_layer_indices"]),
        rrm_inject_indices=list(raw["rrm_inject_indices"]),
        community_layer_idx=raw["community_layer_idx"],
        num_mah_layers=raw["num_mah_layers"],
        mah=MAHConfig(**raw["mah"]),
        rrm=RRMConfig(**raw["rrm"]),
        ben=BENConfig(**raw["ben"]),
        community=CommunityConfig(**raw["community"]),
        loss=LossConfig(**loss_kwargs),
    )
def load_weights(model: SRTAdapter, weights_path: Path) -> None:
    """Load adapter weights from *weights_path* into *model* in place.

    Files ending in ``.safetensors`` are read with
    ``safetensors.torch.load_file``; any other suffix is read with
    ``torch.load`` onto the CPU with ``weights_only=True``. The state dict is
    applied non-strictly, so keys the model has but the file lacks are
    tolerated, and keys the file has but the model lacks produce a warning on
    stderr.

    Args:
        model: Module whose ``load_state_dict`` receives the checkpoint.
        weights_path: Path to the checkpoint file.

    Raises:
        RuntimeError: If the checkpoint is non-empty but none of its keys
            exist in *model*, i.e. nothing from the file was applied.
    """
    if weights_path.suffix == ".safetensors":
        # Imported lazily so safetensors is only needed for this format.
        from safetensors.torch import load_file

        state = load_file(str(weights_path))
    else:
        state = torch.load(str(weights_path), map_location="cpu", weights_only=True)

    # SRTAdapter exposes only the trainable submodules; backbone weights are
    # not in the state-dict. The missing-keys list will name backbone params
    # and that's fine, so it is deliberately not inspected.
    _missing, unexpected = model.load_state_dict(state, strict=False)

    # Every key being unexpected means no tensor from the file reached the
    # model; continuing would silently run without the checkpoint applied.
    if state and len(unexpected) == len(state):
        raise RuntimeError(
            f"none of the {len(state)} keys in {weights_path} match the model "
            f"(e.g. {list(unexpected[:4])}); is this the right checkpoint?"
        )

    if unexpected:
        print(f"warning: unexpected keys {unexpected[:4]}{'...' if len(unexpected) > 4 else ''}", file=sys.stderr)
@torch.no_grad()
def encode(model: SRTAdapter, tok: AutoTokenizer, sentences: list[str], device: str, max_seq_len: int = 128) -> torch.Tensor:
    """Embed *sentences* as masked-mean-pooled, L2-normalized vectors.

    Args:
        model: Callable taking ``input_ids`` and ``attention_mask`` keyword
            arguments; its result must expose ``community_output.encoded``.
        tok: Tokenizer invoked with padding and truncation enabled,
            returning PyTorch tensors.
        sentences: Texts to embed, processed as a single batch.
        device: Device the token tensors are moved to before the forward pass.
        max_seq_len: Truncation length handed to the tokenizer.

    Returns:
        One unit-norm row per sentence. Runs with gradient tracking disabled.
    """
    batch = tok(
        sentences,
        padding=True,
        truncation=True,
        max_length=max_seq_len,
        return_tensors="pt",
    )
    token_ids = batch["input_ids"].to(device)
    attention = batch["attention_mask"].to(device)

    # Per-token features, (B, T, d_community).
    token_states = model(input_ids=token_ids, attention_mask=attention).community_output.encoded

    # Zero out padded positions, then average over the real tokens only.
    weights = attention.unsqueeze(-1).to(token_states.dtype)
    summed = (token_states * weights).sum(1)
    # Clamp so a row with no attended tokens divides by 1 instead of 0.
    counts = weights.sum(1).clamp_min(1.0)
    return F.normalize(summed / counts, p=2, dim=-1)
def main() -> None:
    """Command-line entry point: embed sentences and print their similarities.

    Parses the CLI arguments, builds the config, loads the tokenizer, the
    model and the adapter weights, encodes the given sentences and prints the
    sentence list followed by the pairwise cosine-similarity matrix on stdout.
    Progress messages go to stderr.

    The config and weights paths are validated before the tokenizer and model
    are created, so a wrong path is reported immediately instead of after the
    backbone has been fetched and instantiated. Invalid paths terminate the
    program through ``ArgumentParser.error`` (usage message, exit status 2).
    """
    p = argparse.ArgumentParser()
    p.add_argument("--sentences", nargs="+", required=True)
    p.add_argument("--config", default=str(HERE.parent / "config.json"))
    p.add_argument("--weights", default=None, help="Path to adapter.safetensors or adapter.pt; auto-detected if not provided.")
    p.add_argument("--device", default=None)
    p.add_argument("--max-seq-len", type=int, default=128)
    args = p.parse_args()

    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")

    config_path = Path(args.config)
    if not config_path.is_file():
        p.error(f"config file not found: {config_path}")

    if args.weights is not None:
        weights_path = Path(args.weights)
        if not weights_path.is_file():
            p.error(f"weights file not found: {weights_path}")
    else:
        # Auto-detect next to the config default, preferring safetensors.
        candidates = [HERE.parent / "adapter.safetensors", HERE.parent / "adapter.pt"]
        weights_path = next((c for c in candidates if c.is_file()), None)
        if weights_path is None:
            looked = ", ".join(str(c) for c in candidates)
            p.error(f"no adapter weights found (looked for {looked}); pass --weights")
    weights = str(weights_path)

    config = build_config(config_path)

    print(f"Loading {config.backbone_id} (frozen) and adapter from {weights} ...", file=sys.stderr)
    tok = AutoTokenizer.from_pretrained(config.backbone_id)
    # Padding is required for batching; fall back to EOS when no pad token is set.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    model = SRTAdapter(config).to(device)
    load_weights(model, Path(weights))
    model.eval()

    embs = encode(model, tok, args.sentences, device, args.max_seq_len)
    # Rows are unit-norm, so the Gram matrix holds the cosine similarities.
    sims = embs @ embs.T

    print("\n=== Sentences ===")
    for i, s in enumerate(args.sentences):
        print(f" [{i}] {s}")

    print("\n=== Cosine similarity matrix ===")
    n = len(args.sentences)
    header = " " + " ".join(f" [{j}] " for j in range(n))
    print(header)
    # Convert once to nested Python floats instead of one .item() per cell.
    for i, sim_row in enumerate(sims.tolist()):
        row = " ".join(f"{value:+.4f}" for value in sim_row)
        print(f" [{i}] {row}")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()