| from __future__ import annotations |
|
|
| import math |
| from dataclasses import dataclass |
| from typing import Protocol, Sequence, runtime_checkable |
|
|
|
|
@dataclass(frozen=True)
class BackboneState:
    """Immutable snapshot pairing a token sequence with its latent vector.

    Because the dataclass is frozen, fields cannot be reassigned after
    construction, and instances compare and hash by value.
    """

    # Token ids the state was built from, in order.
    tokens: tuple[int, ...]
    # Latent vector associated with ``tokens``.
    latent: tuple[float, ...]
|
|
|
|
@runtime_checkable
class Backbone(Protocol):
    """Structural interface for an encoder that maps tokens to a BackboneState.

    Any object exposing a ``latent_dim`` attribute together with ``encode``
    and ``step`` methods satisfies this protocol. Because it is
    ``runtime_checkable``, ``isinstance`` checks only that those members
    exist; it does not check their signatures.
    """

    # Size of the latent vectors the backbone produces (presumably; StubBackbone
    # sizes its latent by this value — confirm for other implementations).
    latent_dim: int

    def encode(self, tokens: Sequence[int]) -> BackboneState:
        """Build a state from a complete sequence of token ids."""
        ...

    def step(self, state: BackboneState, token: int) -> BackboneState:
        """Return the state that results from feeding one more ``token`` after ``state``."""
        ...
|
|
|
|
@dataclass
class StubBackbone:
    """Deterministic, dependency-free backbone built from sinusoids.

    The latent for a token sequence is a per-dimension sum of sine terms over
    (position, dimension, token id), scaled by ``seed``, then L2-normalised.
    The same ``latent_dim``, ``seed`` and tokens always give the same latent.
    """

    # Number of components in every latent vector this backbone produces.
    latent_dim: int = 16
    # Scales every sine argument, so different seeds give different latents.
    seed: int = 0

    def encode(self, tokens: Sequence[int]) -> BackboneState:
        """Encode a whole token sequence into a new state.

        Args:
            tokens: Token ids. Any iterable works: it is materialised once,
                so a one-shot iterator is not exhausted by the latent
                computation before the ids are recorded.

        Returns:
            A BackboneState holding the ids as a tuple and their latent.
        """
        # Materialise once and reuse. Reading ``tokens`` twice would leave the
        # state with an empty ``tokens`` tuple when given an iterator.
        token_tuple = tuple(tokens)
        return BackboneState(tokens=token_tuple, latent=self._latent_for(token_tuple))

    def step(self, state: BackboneState, token: int) -> BackboneState:
        """Return a new state with ``token`` appended to ``state.tokens``.

        ``state`` is not modified. The latent is recomputed from the full
        token history, so each call costs O(len(tokens) * latent_dim).
        """
        new_tokens = state.tokens + (token,)
        return BackboneState(tokens=new_tokens, latent=self._latent_for(new_tokens))

    def _latent_for(self, tokens: Sequence[int]) -> tuple[float, ...]:
        """Return the L2-normalised sinusoidal latent for ``tokens``.

        Component ``d`` is the sum over positions ``i`` of
        ``sin((seed + 1) * (i + 1) * (d + 1) * (int(token) + 1) * 0.017)``.
        If every component is zero (e.g. empty input), the zero vector is
        returned instead of dividing by zero.
        """
        dim = self.latent_dim
        seed_factor = self.seed + 1
        acc = [0.0] * dim
        for i, t in enumerate(tokens):
            # Hoist the per-token factors, but keep the original left-to-right
            # multiplication order so every sin() argument is bit-identical.
            pos_factor = seed_factor * (i + 1)
            tok_factor = int(t) + 1
            for d in range(dim):
                acc[d] += math.sin(pos_factor * (d + 1) * tok_factor * 0.017)
        # `or 1.0` guards the all-zero case against division by zero.
        norm = math.sqrt(sum(x * x for x in acc)) or 1.0
        return tuple(x / norm for x in acc)
|
|