from dataclasses import dataclass


@dataclass(slots=True)
class ReframrConfig:
    """Configuration for Reframr; serializes to and from a plain dict for persistence."""

    embedding_dim: int = 16
    state_dim: int = 32
    timescales: tuple[float, ...] = (1.0, 0.5, 0.25, 0.125)
    window_size: int = 2
    regularization: float = 1e-3
    min_frequency: int = 1
    max_vocab: int | None = 256
    tokenizer_vocab_size: int = 256
    tokenizer_min_pair_frequency: int = 2
    max_training_examples: int | None = 60000
    max_memory_examples: int | None = None
    max_state_tokens_per_document: int | None = 768
    max_transition_contexts_per_order: int | None = 4096
    max_transition_next_tokens: int = 4
    lowercase: bool = False
    default_reasoning_profile: str = "none"
    layout_profile: str = "rfm-base"
    effective_parameter_target: int = 0

    def to_dict(self) -> dict[str, object]:
        """Return a plain-dict representation of every field (timescales stored as a list)."""
        return {
            "embedding_dim": self.embedding_dim,
            "state_dim": self.state_dim,
            "timescales": list(self.timescales),
            "window_size": self.window_size,
            "regularization": self.regularization,
            "min_frequency": self.min_frequency,
            "max_vocab": self.max_vocab,
            "tokenizer_vocab_size": self.tokenizer_vocab_size,
            "tokenizer_min_pair_frequency": self.tokenizer_min_pair_frequency,
            "max_training_examples": self.max_training_examples,
            "max_memory_examples": self.max_memory_examples,
            "max_state_tokens_per_document": self.max_state_tokens_per_document,
            "max_transition_contexts_per_order": self.max_transition_contexts_per_order,
            "max_transition_next_tokens": self.max_transition_next_tokens,
            "lowercase": self.lowercase,
            "default_reasoning_profile": self.default_reasoning_profile,
            "layout_profile": self.layout_profile,
            "effective_parameter_target": self.effective_parameter_target,
        }

    @classmethod
    def from_dict(cls, payload: dict[str, object]) -> "ReframrConfig":
        """Rebuild a config from a to_dict-style payload; missing optional keys fall back to defaults."""
        return cls(
            embedding_dim=int(payload["embedding_dim"]),
            state_dim=int(payload["state_dim"]),
            timescales=tuple(float(value) for value in payload["timescales"]),
            window_size=int(payload["window_size"]),
            regularization=float(payload["regularization"]),
            min_frequency=int(payload["min_frequency"]),
            max_vocab=(
                int(payload.get("max_vocab", 256))
                if payload.get("max_vocab", 256) is not None
                else None
            ),
            tokenizer_vocab_size=int(payload.get("tokenizer_vocab_size", 256)),
            tokenizer_min_pair_frequency=int(payload.get("tokenizer_min_pair_frequency", 2)),
            max_training_examples=(
                int(payload["max_training_examples"])
                if payload.get("max_training_examples") is not None
                else None
            ),
            max_memory_examples=(
                int(payload["max_memory_examples"])
                if payload.get("max_memory_examples") is not None
                else None
            ),
            max_state_tokens_per_document=(
                int(payload["max_state_tokens_per_document"])
                if payload.get("max_state_tokens_per_document") is not None
                else 768
            ),
            max_transition_contexts_per_order=(
                int(payload["max_transition_contexts_per_order"])
                if payload.get("max_transition_contexts_per_order") is not None
                else None
            ),
            max_transition_next_tokens=int(payload.get("max_transition_next_tokens", 4)),
            lowercase=bool(payload.get("lowercase", False)),
            default_reasoning_profile=str(payload.get("default_reasoning_profile", "none")),
            layout_profile=str(payload.get("layout_profile", "rfm-base")),
            effective_parameter_target=int(payload.get("effective_parameter_target", 0)),
        )
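

# Illustrative round-trip sketch (not part of the original module): the field
# values chosen below are assumptions; the block only exercises the
# to_dict/from_dict pair defined above.
if __name__ == "__main__":
    config = ReframrConfig(embedding_dim=32, lowercase=True)
    payload = config.to_dict()
    restored = ReframrConfig.from_dict(payload)
    # from_dict converts the serialized timescales list back to a tuple, so the
    # round trip reproduces the original dataclass exactly.
    assert restored == config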