| """Unsupervised Community Discovery Head. |
| |
| Discovers discourse communities from backbone hidden states without |
| predefined labels. A discourse community (in Peirce's framework) is a |
| group of language users who share interpretive norms — they assign similar |
| interpretants to the same representamens. |
| |
| The community head runs at an early backbone layer (before MAH hooks) and |
| produces a soft assignment over K learned prototypes. The resulting community |
| vector conditions how MAH computes divergence, so the same sign can produce |
| different divergence patterns in different community contexts. |
| |
| Training signal: the community prototypes are pulled apart by the semiotic |
| losses — if assigning text to different communities helps the model predict |
| divergence better, it will learn to separate them. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from srt.config import CommunityConfig |
|
|
|
|
@dataclass
class CommunityOutput:
    """Result bundle returned by the community discovery head.

    Two shapes of result are possible, depending on the head's mode:

    * Prototype mode (cfg.use_prototypes=True): every field is populated and
      `vector` is the softmax-weighted mixture of the learned prototypes.
    * Continuous-trajectory mode (cfg.use_prototypes=False, v8a): there is no
      prototype basis, so `logits` and `weights` are None and `vector` is the
      same tensor as `encoded`.
    """

    # Temperature-scaled cosine similarities between the encoded state and
    # each prototype; None in continuous-trajectory mode.
    logits: torch.Tensor | None
    # Softmax of `logits` over the prototype axis; None in
    # continuous-trajectory mode.
    weights: torch.Tensor | None
    # The community vector handed to downstream consumers.
    vector: torch.Tensor
    # Encoder output for the pooled hidden state (before any prototype mixing).
    encoded: torch.Tensor
|
|
|
|
class CommunityDiscoveryHead(nn.Module):
    """Map pooled hidden states to a discourse-community vector.

    Prototype mode (cfg.use_prototypes=True, the v3-v7 architecture):
        pooled hidden state -> encoder -> cosine similarity against K learned
        prototypes -> softmax assignment weights -> weighted mixture of the
        prototypes, which is returned as the community vector.

    Continuous-trajectory mode (cfg.use_prototypes=False, v8a):
        pooled hidden state -> encoder, and the encoder output itself is the
        community vector; no discrete basis is involved. This mode was
        motivated by the v7 PCA finding that the prototype tensors barely
        move from their random init: the encoder was already doing the
        discriminative work, and the soft-argmax over K anchors was throwing
        information away.
    """

    def __init__(self, cfg: CommunityConfig, d_backbone: int) -> None:
        """Build the encoder and, in prototype mode, the prototype table.

        Args:
            cfg: Community settings. Reads `temperature`, `use_prototypes`
                and `d_community`; `num_prototypes` is read only when
                `use_prototypes` is true.
            d_backbone: Width of the backbone hidden states fed to `forward`.
        """
        super().__init__()
        self.temperature = cfg.temperature
        self.use_prototypes = cfg.use_prototypes

        # Single linear projection into community space, then SiLU.
        projection = nn.Linear(d_backbone, cfg.d_community)
        self.encoder = nn.Sequential(projection, nn.SiLU())

        # The prototype table exists only in prototype mode; it is created
        # after the encoder so parameter initialisation order is stable.
        self.prototypes = (
            nn.Embedding(cfg.num_prototypes, cfg.d_community)
            if cfg.use_prototypes
            else None
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> CommunityOutput:
        """Pool, encode, and (optionally) soft-assign to prototypes.

        Args:
            hidden_states: (B, T, d_backbone) from an early backbone layer.
            attention_mask: Optional (B, T) padding mask (1 = real, 0 = pad).
                When omitted, every position contributes to the mean.

        Returns:
            CommunityOutput. In prototype mode `logits` and `weights` are
            populated and `vector` is the prototype-weighted mixture. In
            trajectory mode (use_prototypes=False) `logits` and `weights`
            are None and `vector` equals `encoded`.
        """
        # Mean-pool over the sequence axis, ignoring padded positions when a
        # mask is supplied.
        if attention_mask is None:
            pooled = hidden_states.mean(dim=1)
        else:
            keep = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
            token_sum = (hidden_states * keep).sum(dim=1)
            # clamp keeps the divisor >= 1 so an all-zero mask row cannot
            # divide by zero.
            token_count = keep.sum(dim=1).clamp(min=1)
            pooled = token_sum / token_count

        encoded = self.encoder(pooled)

        if self.use_prototypes:
            anchors = self.prototypes.weight
            # Cosine similarity: both sides are L2-normalised before the
            # matrix product, then sharpened/softened by the temperature.
            unit_encoded = F.normalize(encoded, dim=-1)
            unit_anchors = F.normalize(anchors, dim=-1)
            logits = (unit_encoded @ unit_anchors.T) / self.temperature

            weights = F.softmax(logits, dim=-1)
            # The mixture uses the raw (un-normalised) prototype vectors.
            mixture = weights @ anchors
            return CommunityOutput(
                logits=logits, weights=weights, vector=mixture, encoded=encoded,
            )

        # Trajectory mode: the encoder output is the community vector.
        return CommunityOutput(
            logits=None, weights=None, vector=encoded, encoded=encoded,
        )
|
|