File size: 4,700 Bytes
#!/usr/bin/env python3
"""Encode sentences with SRT-Adapter v1.0 and print pairwise cosine similarity.
Uses `community_output.encoded` (mean-pooled, L2-normalized) as the sentence
embedding. This is the output that v15a's contrastive training optimized; it
is a real STS embedding (English-mean Spearman ~0.589 across 17 MTEB splits).
Usage:
cd examples
python encode_sentences.py --sentences \
"A man is playing guitar." \
"Someone is performing music." \
"The cat sat on the mat."
First run downloads Qwen/Qwen2.5-7B (~15 GB).
"""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
# Directory that contains this script; the default config and adapter weight
# paths are resolved relative to its parent.
HERE = Path(__file__).resolve().parent
# Put the sibling `src/` directory first on the import path so the `srt`
# imports below resolve from this checkout.
sys.path.insert(0, str((HERE.parent / "src").resolve()))
from srt.adapter import SRTAdapter # noqa: E402
from srt.config import ( # noqa: E402
SRTConfig, MAHConfig, RRMConfig, BENConfig, CommunityConfig, LossConfig,
)
def build_config(config_path: Path) -> SRTConfig:
    """Build an ``SRTConfig`` from the JSON file at *config_path*.

    The top-level scalar/list entries are copied across, and the ``mah``,
    ``rrm``, ``ben`` and ``community`` sections are expanded into their
    respective config classes. Only the ``loss`` section is filtered: keys
    that are not declared fields of ``LossConfig`` are dropped before
    construction.

    Raises:
        OSError: if the file cannot be read.
        json.JSONDecodeError: if the file is not valid JSON.
        KeyError: if a required top-level key is absent.
    """
    # JSON is UTF-8 by specification; decode explicitly so parsing does not
    # depend on the platform's locale encoding (e.g. cp1252 on Windows).
    raw = json.loads(config_path.read_text(encoding="utf-8"))

    # Drop loss-section keys that LossConfig does not declare.
    known_loss_fields = LossConfig.__dataclass_fields__
    loss_kwargs = {
        key: value
        for key, value in raw["loss"].items()
        if key in known_loss_fields
    }

    return SRTConfig(
        backbone_id=raw["backbone_id"],
        backbone_dtype=raw["backbone_dtype"],
        mah_layer_indices=list(raw["mah_layer_indices"]),
        rrm_inject_indices=list(raw["rrm_inject_indices"]),
        community_layer_idx=raw["community_layer_idx"],
        num_mah_layers=raw["num_mah_layers"],
        mah=MAHConfig(**raw["mah"]),
        rrm=RRMConfig(**raw["rrm"]),
        ben=BENConfig(**raw["ben"]),
        community=CommunityConfig(**raw["community"]),
        loss=LossConfig(**loss_kwargs),
    )
def load_weights(model: SRTAdapter, weights_path: Path) -> None:
    """Load a checkpoint from *weights_path* into *model* in place.

    Files ending in ``.safetensors`` are read with ``safetensors``; any other
    suffix is read with ``torch.load`` on CPU with ``weights_only=True``.
    Loading is non-strict. Checkpoint keys the model does not recognise are
    reported on stderr (at most the first four are shown).
    """
    path_str = str(weights_path)
    if weights_path.suffix != ".safetensors":
        state = torch.load(path_str, map_location="cpu", weights_only=True)
    else:
        from safetensors.torch import load_file

        state = load_file(path_str)

    # Missing keys are deliberately ignored: the adapter state-dict does not
    # carry the backbone parameters, so they always show up as missing.
    extra = model.load_state_dict(state, strict=False).unexpected_keys
    if not extra:
        return

    preview = extra[:4]
    ellipsis = "..." if len(extra) > 4 else ""
    print(f"warning: unexpected keys {preview}{ellipsis}", file=sys.stderr)
@torch.no_grad()
def encode(model: SRTAdapter, tok: AutoTokenizer, sentences: list[str], device: str, max_seq_len: int = 128) -> torch.Tensor:
    """Embed *sentences* as L2-normalized, mean-pooled vectors.

    The batch is tokenized with padding and truncation to *max_seq_len*, moved
    to *device*, and passed through *model*. The per-token
    ``community_output.encoded`` states are averaged over the positions the
    attention mask marks as real tokens, then each row is scaled to unit
    L2 norm. Runs with gradient tracking disabled.

    Returns:
        One embedding row per input sentence.
    """
    batch = tok(
        sentences,
        padding=True,
        truncation=True,
        max_length=max_seq_len,
        return_tensors="pt",
    )
    attention_mask = batch["attention_mask"].to(device)
    token_ids = batch["input_ids"].to(device)

    outputs = model(input_ids=token_ids, attention_mask=attention_mask)
    token_states = outputs.community_output.encoded  # presumably (B, T, d_community)

    # Broadcast the mask over the feature axis so padded positions contribute zero.
    keep = attention_mask.unsqueeze(-1).to(token_states.dtype)
    token_sum = (token_states * keep).sum(1)
    # Clamp the token count so a fully masked row cannot divide by zero.
    token_count = keep.sum(1).clamp_min(1.0)
    return F.normalize(token_sum / token_count, p=2, dim=-1)
def _format_similarity_matrix(sims: torch.Tensor) -> list[str]:
    """Return the text lines (header first) of a column-aligned similarity table."""
    values = sims.float().cpu().tolist()
    n = len(values)
    # Size labels for the widest index so rows stay aligned once n >= 10.
    label_width = len(f"[{n - 1}]")
    # Every value renders as e.g. "+0.1234"; never make a column narrower than its label.
    cell_width = max(len("+0.0000"), label_width)

    header_cells = " ".join(f"[{j}]".center(cell_width) for j in range(n))
    lines = [(" " * (label_width + 2) + header_cells).rstrip()]
    for i, row in enumerate(values):
        label = f"[{i}]".rjust(label_width)
        cells = " ".join(f"{value:+.4f}".rjust(cell_width) for value in row)
        lines.append(f" {label} {cells}")
    return lines


def main() -> None:
    """Command-line entry point: embed the given sentences and print their similarities.

    Parses the arguments, builds the config, locates the adapter weights
    (``adapter.safetensors`` is preferred over ``adapter.pt`` next to the
    config when ``--weights`` is omitted), loads tokenizer and model, encodes
    the sentences, and prints the sentence list followed by the pairwise
    cosine-similarity matrix.

    Raises:
        FileNotFoundError: if the adapter weights file does not exist.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--sentences", nargs="+", required=True)
    p.add_argument("--config", default=str(HERE.parent / "config.json"))
    p.add_argument("--weights", default=None, help="Path to adapter.safetensors or adapter.pt; auto-detected if not provided.")
    p.add_argument("--device", default=None)
    p.add_argument("--max-seq-len", type=int, default=128)
    args = p.parse_args()

    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    config = build_config(Path(args.config))

    weights = args.weights
    if weights is None:
        st = HERE.parent / "adapter.safetensors"
        pt = HERE.parent / "adapter.pt"
        weights = str(st if st.exists() else pt)
    # Fail fast: check for the adapter file before the tokenizer and model are
    # built, rather than discovering it is missing only at weight-loading time.
    if not Path(weights).exists():
        raise FileNotFoundError(
            f"adapter weights not found: {weights} "
            "(pass --weights pointing at adapter.safetensors or adapter.pt)"
        )

    print(f"Loading {config.backbone_id} (frozen) and adapter from {weights} ...", file=sys.stderr)
    tok = AutoTokenizer.from_pretrained(config.backbone_id)
    if tok.pad_token is None:
        # Batched padding needs a pad token; fall back to EOS when none is defined.
        tok.pad_token = tok.eos_token

    model = SRTAdapter(config).to(device)
    load_weights(model, Path(weights))
    model.eval()

    embs = encode(model, tok, args.sentences, device, args.max_seq_len)
    # Rows are unit-length, so the Gram matrix holds the cosine similarities.
    sims = embs @ embs.T

    print("\n=== Sentences ===")
    for i, s in enumerate(args.sentences):
        print(f" [{i}] {s}")

    print("\n=== Cosine similarity matrix ===")
    for line in _format_similarity_matrix(sims):
        print(line)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|