#!/usr/bin/env python3 """Encode sentences with SRT-Adapter v1.0 and print pairwise cosine similarity. Uses `community_output.encoded` (mean-pooled, L2-normalized) as the sentence embedding. This is the output that v15a's contrastive training optimized; it is a real STS embedding (English-mean Spearman ~0.589 across 17 MTEB splits). Usage: cd examples python encode_sentences.py --sentences \ "A man is playing guitar." \ "Someone is performing music." \ "The cat sat on the mat." First run downloads Qwen/Qwen2.5-7B (~15 GB). """ from __future__ import annotations import argparse import json import sys from pathlib import Path import torch import torch.nn.functional as F from transformers import AutoTokenizer HERE = Path(__file__).resolve().parent sys.path.insert(0, str((HERE.parent / "src").resolve())) from srt.adapter import SRTAdapter # noqa: E402 from srt.config import ( # noqa: E402 SRTConfig, MAHConfig, RRMConfig, BENConfig, CommunityConfig, LossConfig, ) def build_config(config_path: Path) -> SRTConfig: raw = json.loads(config_path.read_text()) return SRTConfig( backbone_id=raw["backbone_id"], backbone_dtype=raw["backbone_dtype"], mah_layer_indices=list(raw["mah_layer_indices"]), rrm_inject_indices=list(raw["rrm_inject_indices"]), community_layer_idx=raw["community_layer_idx"], num_mah_layers=raw["num_mah_layers"], mah=MAHConfig(**raw["mah"]), rrm=RRMConfig(**raw["rrm"]), ben=BENConfig(**raw["ben"]), community=CommunityConfig(**raw["community"]), loss=LossConfig(**{ k: v for k, v in raw["loss"].items() if k in LossConfig.__dataclass_fields__ }), ) def load_weights(model: SRTAdapter, weights_path: Path) -> None: if weights_path.suffix == ".safetensors": from safetensors.torch import load_file state = load_file(str(weights_path)) else: state = torch.load(str(weights_path), map_location="cpu", weights_only=True) missing, unexpected = model.load_state_dict(state, strict=False) # SRTAdapter exposes only the trainable submodules; backbone weights are not in the state-dict. # `missing` will list backbone params and that's fine. if unexpected: print(f"warning: unexpected keys {unexpected[:4]}{'...' if len(unexpected) > 4 else ''}", file=sys.stderr) @torch.no_grad() def encode(model: SRTAdapter, tok: AutoTokenizer, sentences: list[str], device: str, max_seq_len: int = 128) -> torch.Tensor: enc = tok( sentences, padding=True, truncation=True, max_length=max_seq_len, return_tensors="pt", ) input_ids = enc["input_ids"].to(device) attn = enc["attention_mask"].to(device) out = model(input_ids=input_ids, attention_mask=attn) encoded = out.community_output.encoded # (B, T, d_community) mask = attn.unsqueeze(-1).to(encoded.dtype) pooled = (encoded * mask).sum(1) / mask.sum(1).clamp_min(1.0) return F.normalize(pooled, p=2, dim=-1) def main() -> None: p = argparse.ArgumentParser() p.add_argument("--sentences", nargs="+", required=True) p.add_argument("--config", default=str(HERE.parent / "config.json")) p.add_argument("--weights", default=None, help="Path to adapter.safetensors or adapter.pt; auto-detected if not provided.") p.add_argument("--device", default=None) p.add_argument("--max-seq-len", type=int, default=128) args = p.parse_args() device = args.device or ("cuda" if torch.cuda.is_available() else "cpu") config = build_config(Path(args.config)) weights = args.weights if weights is None: st = HERE.parent / "adapter.safetensors" pt = HERE.parent / "adapter.pt" weights = str(st if st.exists() else pt) print(f"Loading {config.backbone_id} (frozen) and adapter from {weights} ...", file=sys.stderr) tok = AutoTokenizer.from_pretrained(config.backbone_id) if tok.pad_token is None: tok.pad_token = tok.eos_token model = SRTAdapter(config).to(device) load_weights(model, Path(weights)) model.eval() embs = encode(model, tok, args.sentences, device, args.max_seq_len) sims = embs @ embs.T print("\n=== Sentences ===") for i, s in enumerate(args.sentences): print(f" [{i}] {s}") print("\n=== Cosine similarity matrix ===") n = len(args.sentences) header = " " + " ".join(f" [{j}] " for j in range(n)) print(header) for i in range(n): row = " ".join(f"{sims[i, j].item():+.4f}" for j in range(n)) print(f" [{i}] {row}") if __name__ == "__main__": main()