from dataclasses import dataclass @dataclass(slots=True) class ReframrConfig: embedding_dim: int = 16 state_dim: int = 32 timescales: tuple[float, ...] = (1.0, 0.5, 0.25, 0.125) window_size: int = 2 regularization: float = 1e-3 min_frequency: int = 1 max_vocab: int | None = 256 tokenizer_vocab_size: int = 256 tokenizer_min_pair_frequency: int = 2 max_training_examples: int | None = 60000 max_transition_contexts_per_order: int | None = 4096 max_transition_next_tokens: int = 4 lowercase: bool = False default_reasoning_profile: str = "none" def to_dict(self) -> dict[str, object]: return { "embedding_dim": self.embedding_dim, "state_dim": self.state_dim, "timescales": list(self.timescales), "window_size": self.window_size, "regularization": self.regularization, "min_frequency": self.min_frequency, "max_vocab": self.max_vocab, "tokenizer_vocab_size": self.tokenizer_vocab_size, "tokenizer_min_pair_frequency": self.tokenizer_min_pair_frequency, "max_training_examples": self.max_training_examples, "max_transition_contexts_per_order": self.max_transition_contexts_per_order, "max_transition_next_tokens": self.max_transition_next_tokens, "lowercase": self.lowercase, "default_reasoning_profile": self.default_reasoning_profile, } @classmethod def from_dict(cls, payload: dict[str, object]) -> "ReframrConfig": return cls( embedding_dim=int(payload["embedding_dim"]), state_dim=int(payload["state_dim"]), timescales=tuple(float(value) for value in payload["timescales"]), window_size=int(payload["window_size"]), regularization=float(payload["regularization"]), min_frequency=int(payload["min_frequency"]), max_vocab=( int(payload.get("max_vocab", 256)) if payload.get("max_vocab", 256) is not None else None ), tokenizer_vocab_size=int(payload.get("tokenizer_vocab_size", 256)), tokenizer_min_pair_frequency=int(payload.get("tokenizer_min_pair_frequency", 2)), max_training_examples=( int(payload["max_training_examples"]) if payload.get("max_training_examples") is not None else None ), max_transition_contexts_per_order=( int(payload["max_transition_contexts_per_order"]) if payload.get("max_transition_contexts_per_order") is not None else None ), max_transition_next_tokens=int(payload.get("max_transition_next_tokens", 4)), lowercase=bool(payload.get("lowercase", False)), default_reasoning_profile=str(payload.get("default_reasoning_profile", "none")), )