File size: 4,700 Bytes
0182cbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#!/usr/bin/env python3
"""Encode sentences with SRT-Adapter v1.0 and print pairwise cosine similarity.

Uses `community_output.encoded` (mean-pooled, L2-normalized) as the sentence
embedding. This is the output that v15a's contrastive training optimized; it
is a real STS embedding (English-mean Spearman ~0.589 across 17 MTEB splits).

Usage:
    cd examples
    python encode_sentences.py --sentences \
        "A man is playing guitar." \
        "Someone is performing music." \
        "The cat sat on the mat."

First run downloads Qwen/Qwen2.5-7B (~15 GB).
"""

from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer

# Directory that contains this script; default --config / --weights paths in
# main() are resolved relative to its parent directory.
HERE = Path(__file__).resolve().parent
# Prepend the sibling `src` directory to sys.path so the `srt` imports that
# follow can be resolved from there. Must run before those imports.
sys.path.insert(0, str((HERE.parent / "src").resolve()))

from srt.adapter import SRTAdapter  # noqa: E402
from srt.config import (  # noqa: E402
    SRTConfig, MAHConfig, RRMConfig, BENConfig, CommunityConfig, LossConfig,
)


def build_config(config_path: Path) -> SRTConfig:
    """Build an ``SRTConfig`` from a JSON config file.

    Args:
        config_path: Path to a JSON file with the top-level keys
            ``backbone_id``, ``backbone_dtype``, ``mah_layer_indices``,
            ``rrm_inject_indices``, ``community_layer_idx``,
            ``num_mah_layers``, ``mah``, ``rrm``, ``ben``, ``community`` and
            ``loss``.

    Returns:
        The populated ``SRTConfig``. The nested ``mah`` / ``rrm`` / ``ben`` /
        ``community`` sections are passed to their config classes verbatim;
        the ``loss`` section is filtered down to the fields ``LossConfig``
        declares, so extra keys in that section are ignored.

    Raises:
        KeyError: if a required top-level key is absent.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    # Feed raw bytes to json.loads so it detects the encoding (UTF-8/16/32)
    # itself instead of relying on the platform's locale default, which is
    # what Path.read_text() without an explicit encoding would use.
    raw = json.loads(config_path.read_bytes())

    loss_fields = LossConfig.__dataclass_fields__
    loss_kwargs = {k: v for k, v in raw["loss"].items() if k in loss_fields}

    return SRTConfig(
        backbone_id=raw["backbone_id"],
        backbone_dtype=raw["backbone_dtype"],
        mah_layer_indices=list(raw["mah_layer_indices"]),
        rrm_inject_indices=list(raw["rrm_inject_indices"]),
        community_layer_idx=raw["community_layer_idx"],
        num_mah_layers=raw["num_mah_layers"],
        mah=MAHConfig(**raw["mah"]),
        rrm=RRMConfig(**raw["rrm"]),
        ben=BENConfig(**raw["ben"]),
        community=CommunityConfig(**raw["community"]),
        loss=LossConfig(**loss_kwargs),
    )


def load_weights(model: SRTAdapter, weights_path: Path) -> None:
    """Load adapter parameters from ``weights_path`` into ``model`` in place.

    A ``.safetensors`` file is read with ``safetensors.torch.load_file``; any
    other suffix is read with ``torch.load`` onto the CPU with
    ``weights_only=True``. Keys in the file that the model does not know are
    reported on stderr (first four only); keys the model has but the file
    lacks are not reported.
    """
    location = str(weights_path)
    if weights_path.suffix != ".safetensors":
        checkpoint = torch.load(location, map_location="cpu", weights_only=True)
    else:
        from safetensors.torch import load_file
        checkpoint = load_file(location)

    # Non-strict load: the checkpoint holds only the trainable submodules, so
    # the backbone parameters show up as "missing" -- that is expected and is
    # deliberately ignored here.
    _missing, extra = model.load_state_dict(checkpoint, strict=False)
    if not extra:
        return

    suffix = "..." if len(extra) > 4 else ""
    print(f"warning: unexpected keys {extra[:4]}{suffix}", file=sys.stderr)


@torch.no_grad()
def encode(model: SRTAdapter, tok: AutoTokenizer, sentences: list[str], device: str, max_seq_len: int = 128) -> torch.Tensor:
    """Return one L2-normalized embedding per sentence.

    The sentences are tokenized as a single padded batch (truncated to
    ``max_seq_len`` tokens), run through ``model``, and the per-token
    ``community_output.encoded`` states are averaged over the positions the
    attention mask marks as real tokens. Runs with gradients disabled.

    Returns:
        Tensor with one unit-length row per input sentence.
    """
    batch = tok(
        sentences,
        padding=True,
        truncation=True,
        max_length=max_seq_len,
        return_tensors="pt",
    )
    token_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)

    outputs = model(input_ids=token_ids, attention_mask=attention_mask)
    hidden = outputs.community_output.encoded  # (B, T, d_community)

    # Masked mean over the token axis: padding positions contribute zero to
    # the sum and are excluded from the count; the clamp guards against a
    # division by zero for an all-padding row.
    weights = attention_mask.unsqueeze(-1).to(hidden.dtype)
    summed = (hidden * weights).sum(1)
    counts = weights.sum(1).clamp_min(1.0)
    return F.normalize(summed / counts, p=2, dim=-1)


def main() -> None:
    """CLI entry point: embed the given sentences and print their similarities.

    Parses the command line, builds the adapter from the JSON config, loads
    the adapter weights, encodes ``--sentences`` and prints the sentences
    followed by their pairwise cosine-similarity matrix on stdout. Progress
    and warnings go to stderr.

    Raises:
        FileNotFoundError: if the weights file (given via ``--weights`` or
            auto-detected next to the config) does not exist. This is checked
            before the tokenizer and model are created so that a wrong path
            fails immediately instead of after the backbone has been loaded.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--sentences", nargs="+", required=True)
    p.add_argument("--config", default=str(HERE.parent / "config.json"))
    p.add_argument("--weights", default=None, help="Path to adapter.safetensors or adapter.pt; auto-detected if not provided.")
    p.add_argument("--device", default=None)
    p.add_argument("--max-seq-len", type=int, default=128)
    args = p.parse_args()

    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    config = build_config(Path(args.config))

    weights = args.weights
    if weights is None:
        # Prefer the safetensors file; fall back to the pickle checkpoint.
        st = HERE.parent / "adapter.safetensors"
        pt = HERE.parent / "adapter.pt"
        weights = str(st if st.exists() else pt)
    if not Path(weights).is_file():
        # Fail fast: everything below this point is expensive.
        raise FileNotFoundError(f"adapter weights not found: {weights}")
    print(f"Loading {config.backbone_id} (frozen) and adapter from {weights} ...", file=sys.stderr)

    tok = AutoTokenizer.from_pretrained(config.backbone_id)
    if tok.pad_token is None:
        # Batch padding needs a pad token; reuse EOS when none is defined.
        tok.pad_token = tok.eos_token
    model = SRTAdapter(config).to(device)
    load_weights(model, Path(weights))
    model.eval()

    embs = encode(model, tok, args.sentences, device, args.max_seq_len)
    # Rows are unit-length, so the Gram matrix is the cosine-similarity matrix.
    sims = embs @ embs.T

    print("\n=== Sentences ===")
    for i, s in enumerate(args.sentences):
        print(f"  [{i}] {s}")

    print("\n=== Cosine similarity matrix ===")
    n = len(args.sentences)
    # One device-to-host transfer instead of one .item() call per cell.
    rows = sims.float().cpu().tolist()
    labels = [f"[{k}]" for k in range(n)]
    label_w = len(labels[-1])  # widest row label
    cell_w = 7  # width of a formatted "+0.1234" value
    # Right-align each column label over its value column so the header
    # lines up with the rows below it.
    header = " " * (2 + label_w + 1) + " ".join(f"{lab:>{cell_w}}" for lab in labels)
    print(header)
    for lab, values in zip(labels, rows):
        row = " ".join(f"{v:+.4f}" for v in values)
        print(f"  {lab:>{label_w}} {row}")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()