Antreas committed on
Commit
6b4e066
·
verified ·
1 Parent(s): bbae1b8

Enable AutoModel loading

Browse files
Files changed (1) hide show
  1. config.py +2 -158
config.py CHANGED
@@ -1,161 +1,5 @@
1
- """Model configuration for Ogma."""
2
 
3
- from __future__ import annotations
4
-
5
- from dataclasses import dataclass, field
6
- from enum import StrEnum
7
- from typing import Any
8
 
9
  __all__ = ["OgmaConfig", "VariantType", "PoolingType", "TaskToken"]
10
-
11
-
12
- class VariantType(StrEnum):
13
- """Architecture variant identifiers."""
14
-
15
- TRANSFORMER = "transformer"
16
- DEEP_NARROW = "deep_narrow"
17
- CONV = "conv"
18
- LINEAR_ATTENTION = "linear_attention"
19
- MLP_MIXER = "mlp_mixer"
20
- TRANSFORMER_RESA = "transformer_resa"
21
- GLA = "gla"
22
-
23
-
24
- class PoolingType(StrEnum):
25
- """Pooling strategy identifiers."""
26
-
27
- TASK_TOKEN = "task_token"
28
- LATENT_ATTENTION = "latent_attention"
29
- MEAN = "mean"
30
-
31
-
32
- class TaskToken(StrEnum):
33
- """Task token identifiers for asymmetric encoding."""
34
-
35
- QRY = "QRY"
36
- DOC = "DOC"
37
- SYM = "SYM"
38
-
39
-
40
- @dataclass
41
- class OgmaConfig:
42
- """Configuration for an Ogma model instance.
43
-
44
- Args:
45
- variant: Architecture variant to use.
46
- d_embed: Token embedding dimension (from teacher PCA).
47
- d_model: Internal model dimension after projection.
48
- n_layers: Number of fusion layers/blocks.
49
- n_heads: Number of attention heads (attention variants only).
50
- vocab_size: Vocabulary size for embedding table.
51
- max_seq_len: Maximum sequence length.
52
- matryoshka_dims: Nested output dimensions for Matryoshka.
53
- pooling: Pooling strategy.
54
- d_output: Final output dimension.
55
- ffn_mult: SwiGLU FFN hidden dimension multiplier.
56
- conv_kernel_size: Kernel size for conv variant.
57
- spatial_rank: Rank of spatial mixing in MLP mixer.
58
- n_random_features: Random features for linear attention.
59
- dropout: Dropout rate (0 for inference).
60
- """
61
-
62
- variant: VariantType = VariantType.TRANSFORMER
63
- d_embed: int = 128
64
- d_model: int = 256
65
- n_layers: int = 1
66
- n_heads: int = 4
67
- vocab_size: int = 30_000
68
- max_seq_len: int = 512
69
- matryoshka_dims: list[int] = field(
70
- default_factory=lambda: [32, 64, 128, 256]
71
- )
72
- pooling: PoolingType = PoolingType.TASK_TOKEN
73
- d_output: int = 256
74
- ffn_mult: float = 8 / 3 # SwiGLU: 8/3 * d_model ≈ 683 for d=256
75
- conv_kernel_size: int = 7
76
- spatial_rank: int = 32
77
- n_random_features: int = 128
78
- dropout: float = 0.0
79
-
80
- # ReSA scorer settings
81
- scorer_type: str = "dot"
82
- scorer_alpha_init: float = 0.1
83
- scorer_hidden: int = 0 # 0 defaults to d_head
84
-
85
- # GLA (Gated Linear Attention) settings
86
- gla_expand_k: float = 0.5 # key dim expansion (key_dim = d_model * expand_k)
87
- gla_expand_v: float = 1.0 # value dim expansion (value_dim = d_model * expand_v)
88
- gla_gate_low_rank_dim: int = 16 # low-rank dim for gating projection
89
- gla_gate_logit_normalizer: int = 16 # normalizer for gate logits
90
- gla_use_short_conv: bool = True # whether to use short conv on Q,K,V
91
- gla_conv_size: int = 4 # short conv kernel size
92
-
93
- # Special token IDs
94
- pad_id: int = 0
95
- unk_id: int = 1
96
- bos_id: int = 2
97
- eos_id: int = 3
98
- qry_id: int = 4
99
- doc_id: int = 5
100
- sym_id: int = 6
101
- n_special_tokens: int = 7
102
-
103
- @property
104
- def d_head(self) -> int:
105
- """Per-head dimension."""
106
- return self.d_model // self.n_heads
107
-
108
- @property
109
- def ffn_hidden(self) -> int:
110
- """SwiGLU FFN hidden dimension."""
111
- return int(self.d_model * self.ffn_mult)
112
-
113
- def task_token_id(self, task: TaskToken) -> int:
114
- """Return token ID for a task token."""
115
- mapping = {
116
- TaskToken.QRY: self.qry_id,
117
- TaskToken.DOC: self.doc_id,
118
- TaskToken.SYM: self.sym_id,
119
- }
120
- return mapping[task]
121
-
122
- def to_dict(self) -> dict[str, Any]:
123
- """Serialize config to dictionary."""
124
- return {
125
- "variant": self.variant.value,
126
- "d_embed": self.d_embed,
127
- "d_model": self.d_model,
128
- "n_layers": self.n_layers,
129
- "n_heads": self.n_heads,
130
- "vocab_size": self.vocab_size,
131
- "max_seq_len": self.max_seq_len,
132
- "matryoshka_dims": self.matryoshka_dims,
133
- "pooling": self.pooling.value,
134
- "d_output": self.d_output,
135
- "ffn_mult": self.ffn_mult,
136
- "conv_kernel_size": self.conv_kernel_size,
137
- "spatial_rank": self.spatial_rank,
138
- "n_random_features": self.n_random_features,
139
- "dropout": self.dropout,
140
- "scorer_type": self.scorer_type,
141
- "scorer_alpha_init": self.scorer_alpha_init,
142
- "scorer_hidden": self.scorer_hidden,
143
- "gla_expand_k": self.gla_expand_k,
144
- "gla_expand_v": self.gla_expand_v,
145
- "gla_gate_low_rank_dim": self.gla_gate_low_rank_dim,
146
- "gla_gate_logit_normalizer": self.gla_gate_logit_normalizer,
147
- "gla_use_short_conv": self.gla_use_short_conv,
148
- "gla_conv_size": self.gla_conv_size,
149
- }
150
-
151
- @classmethod
152
- def from_dict(cls, data: dict[str, Any]) -> OgmaConfig:
153
- """Deserialize config from dictionary."""
154
- data = dict(data)
155
- if "variant" in data:
156
- data["variant"] = VariantType(data["variant"])
157
- if "pooling" in data:
158
- data["pooling"] = PoolingType(data["pooling"])
159
- known = {f.name for f in cls.__dataclass_fields__.values()}
160
- filtered = {k: v for k, v in data.items() if k in known}
161
- return cls(**filtered)
 
1
+ """Compatibility exports for Ogma configuration."""
2
 
3
+ from .configuration_ogma import OgmaConfig, PoolingType, TaskToken, VariantType
 
 
 
 
4
 
5
  __all__ = ["OgmaConfig", "VariantType", "PoolingType", "TaskToken"]