SGJM / src /sgjm /modules /backbone.py
adampippert's picture
SGJM 2026.6.5 — code/docs
e51ccda verified
Raw
History Blame Contribute Delete
1.29 kB
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Protocol, Sequence, runtime_checkable
@dataclass(frozen=True)
class BackboneState:
    """Immutable snapshot of a backbone after consuming a token sequence.

    The dataclass is frozen: fields cannot be reassigned after construction,
    and equality/hashing are derived from the field values. Advancing a state
    therefore means building a new instance rather than mutating this one.
    """

    # Every token consumed so far, in the order it was consumed.
    tokens: tuple[int, ...]
    # Latent vector derived from ``tokens``; presumably its length equals the
    # producing backbone's ``latent_dim`` — confirm for non-stub backbones.
    latent: tuple[float, ...]
@runtime_checkable
class Backbone(Protocol):
    """Structural interface for encoders that carry a ``BackboneState``.

    Any object exposing a ``latent_dim`` attribute plus ``encode`` and
    ``step`` methods satisfies this protocol. Because it is decorated with
    ``runtime_checkable``, ``isinstance`` can test for those members; that
    check looks only at member presence, not at signatures or return types.
    """

    # Declared width of the latent vectors this backbone produces (not
    # enforced by the protocol itself).
    latent_dim: int

    def encode(self, tokens: Sequence[int]) -> BackboneState:
        """Build a state from a complete token sequence."""

    def step(self, state: BackboneState, token: int) -> BackboneState:
        """Advance ``state`` by one ``token`` and return the resulting state."""
@dataclass
class StubBackbone:
    """Deterministic, dependency-free backbone built from sums of sines.

    Latents depend only on ``latent_dim``, ``seed`` and the token sequence,
    so identical inputs always yield identical states. The name suggests a
    placeholder/test role; it does not implement a learned model.
    """

    latent_dim: int = 16  # number of latent components produced
    seed: int = 0  # scales every sine frequency, varying the latents

    def encode(self, tokens: Sequence[int]) -> BackboneState:
        """Encode ``tokens`` from scratch into a new ``BackboneState``."""
        vector = self._latent_for(tokens)
        frozen_tokens = tuple(tokens)
        return BackboneState(tokens=frozen_tokens, latent=vector)

    def step(self, state: BackboneState, token: int) -> BackboneState:
        """Append ``token`` to ``state`` and return the resulting new state.

        The latent is recomputed from the entire token history rather than
        updated incrementally, so each call costs
        O(len(history) * latent_dim).
        """
        history = state.tokens + (token,)
        vector = self._latent_for(history)
        return BackboneState(tokens=history, latent=vector)

    def _latent_for(self, tokens: Sequence[int]) -> tuple[float, ...]:
        """Return the L2-normalised latent vector for ``tokens``.

        Component ``d`` is the sum over token positions ``i`` (0-based) of
        ``sin((seed + 1) * (i + 1) * (d + 1) * (int(t) + 1) * 0.017)``.
        The vector is then divided by its Euclidean norm. When that norm is
        zero (e.g. no tokens), the vector is returned unscaled instead of
        dividing by zero.
        """
        sums = [0.0] * self.latent_dim
        for position, tok in enumerate(tokens, start=1):
            # ``position`` is already 1-based, standing in for ``(i + 1)``.
            sums = [
                running
                + math.sin(
                    (self.seed + 1) * position * (dim + 1) * (int(tok) + 1) * 0.017
                )
                for dim, running in enumerate(sums)
            ]
        magnitude = math.sqrt(sum(component * component for component in sums))
        if not magnitude:
            magnitude = 1.0
        return tuple(component / magnitude for component in sums)