# Examples

## `load_and_score.py`

End-to-end demo. Loads the v8a adapter on top of a frozen `Qwen/Qwen2.5-7B` backbone and prints all four semiotic readouts for an input passage.

```bash
cd examples
python -m pip install -r ../requirements.txt
python load_and_score.py --text "Vaccine mandates are an obvious public health win."
```

The first run downloads `Qwen/Qwen2.5-7B` (~15 GB) from the Hugging Face Hub.

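## Programmatic use

The snippet below needs a CUDA GPU and resolves `src/`, `config.json` and `adapter.pt` relative to the current working directory.

```python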
import inspect
import json
import sys
from pathlib import Path

import torch
from transformers import AutoTokenizer

# "src", "config.json" and "adapter.pt" are resolved against the current
# working directory.
sys.path.insert(0, "src")

from srt.config import (SRTConfig, MAHConfig, RRMConfig, BENConfig,
                        CommunityConfig, LossConfig)
from srt.adapter import SRTAdapter

# JSON is UTF-8 by spec; do not depend on the platform's default encoding.
raw = json.loads(Path("config.json").read_text(encoding="utf-8"))

# Forward only the loss options LossConfig's constructor actually accepts, so
# keys in config.json that this LossConfig does not know about are skipped.
loss_params = inspect.signature(LossConfig).parameters

config = SRTConfig(
    backbone_id=raw["backbone_id"],
    backbone_dtype=raw["backbone_dtype"],
    mah_layer_indices=list(raw["mah_layer_indices"]),
    rrm_inject_indices=list(raw["rrm_inject_indices"]),
    community_layer_idx=raw["community_layer_idx"],
    num_mah_layers=raw["num_mah_layers"],
    mah=MAHConfig(**raw["mah"]),
    rrm=RRMConfig(**raw["rrm"]),
    ben=BENConfig(**raw["ben"]),
    community=CommunityConfig(**raw["community"]),
    loss=LossConfig(**{k: v for k, v in raw["loss"].items()
                       if k in loss_params}),
)

model = SRTAdapter(config).cuda().eval()

# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered checkpoint cannot run arbitrary code when it is loaded.
state = torch.load("adapter.pt", map_location="cuda", weights_only=True)

# strict=False: presumably adapter.pt holds only the adapter weights, not the
# frozen backbone's, so missing keys are expected. Unexpected keys are not;
# strict=False would drop them silently, so report them.
incompatible = model.load_state_dict(state, strict=False)
if incompatible.unexpected_keys:
    print(f"warning: {len(incompatible.unexpected_keys)} checkpoint entries "
          f"matched no key in the model and were ignored, e.g. "
          f"{incompatible.unexpected_keys[:5]}", file=sys.stderr)

tok = AutoTokenizer.from_pretrained(config.backbone_id)
enc = tok("Freedom means different things to different people.",
          return_tensors="pt").to("cuda")

with torch.no_grad():
    out = model(input_ids=enc.input_ids, attention_mask=enc.attention_mask)

print("logits :", out.logits.shape)  # (1, T, V)
print("community vec :", out.community_output.vector.shape)  # (1, 64)
print("divergences :", [d.shape for d in out.divergences])  # 3× (1, T, 256)
print("r_hat :", out.ben_output.r_hat.shape)  # (1, T)
print("regime logits :", out.ben_output.regime_logits.shape)  # (1, T, 2)
```
