ryanscottbarrett committed on
Commit c7739b6 · verified · 1 Parent(s): 192c355

Upload braille256-v6: Lattice-aware multimodal Braille model

README.md ADDED
@@ -0,0 +1,186 @@
# braille256-v6: Lattice-Aware Multimodal Braille Model

**The first LLM with explicit dot-lattice structure in its architecture.**

## Model Description

braille256-v6 builds on the multimodal foundation of v5, integrating formal lattice theory into the training pipeline. It is not just a Braille-native model but a **lattice-native** one: the mathematical structure of Braille is built into the architecture itself.

### Key Innovations

| Feature | Description |
|---------|-------------|
| **Lattice Attention** | Attention scores incorporate Hamming-based similarity between Braille cells |
| **Lattice Embeddings** | Token embeddings initialized to respect the Boolean lattice structure |
| **Morphological Regularization** | Training loss includes an equivariance term under erosion/dilation |
| **Haptic Evaluation** | New metrics for the tactile quality of outputs |

## Architecture

```
Parameters: ~12M
Layers: 4
Heads: 4
Hidden: 256
Vocab: 32,000 (SentencePiece)
Context: 512
```

### Lattice Attention

Standard transformer attention computes:
```
Attention(Q, K, V) = softmax(QK^T / √d) V
```

Lattice attention blends this with Braille-aware similarity:
```
LatticeAttn = (1-λ) * StandardAttn + λ * HammingAttn

where HammingAttn[i,j] = 8 - popcount(token[i] XOR token[j])
```

This gives the model an inductive bias toward understanding Braille structure.
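As a concrete illustration, here is a minimal NumPy sketch of the blend above. The mixing weight `lam = 0.3` and the per-row softmax normalization are assumptions for the sketch, not values taken from the released training code:

```python
import numpy as np

def hamming_similarity(tokens: np.ndarray) -> np.ndarray:
    """8 - popcount(token[i] XOR token[j]) for byte-valued Braille tokens."""
    xor = tokens[:, None] ^ tokens[None, :]                      # (T, T) XOR matrix
    popcount = np.unpackbits(xor[..., None].astype(np.uint8), axis=-1).sum(-1)
    return 8.0 - popcount

def lattice_attention(q, k, v, tokens, lam=0.3):
    """Blend standard scaled dot-product logits with Hamming similarity."""
    d = q.shape[-1]
    std = q @ k.T / np.sqrt(d)            # standard attention logits
    ham = hamming_similarity(tokens)      # lattice-based logits
    scores = (1 - lam) * std + lam * ham
    scores -= scores.max(axis=-1, keepdims=True)   # numerically stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v
```

Note that the Hamming term depends only on the token bytes, so it can be precomputed once per sequence.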
### Lattice Embeddings

For the first 256 tokens (corresponding to Braille cells), embeddings are initialized as:
```python
embedding[i] = Σ basis[b] for each raised dot b in cell i
```

This means similar Braille cells (low Hamming distance) start with similar embeddings.
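A minimal sketch of this initialization, assuming a random Gaussian dot basis (the basis actually used in training is not specified here). Cells that differ in one dot differ by exactly one basis vector, so Hamming-close cells start Euclidean-close:

```python
import numpy as np

def lattice_embedding_init(hidden: int = 256, seed: int = 0) -> np.ndarray:
    """Initialize the first 256 embeddings as sums of 8 shared dot-basis vectors."""
    rng = np.random.default_rng(seed)
    basis = rng.normal(scale=hidden ** -0.5, size=(8, hidden))  # one vector per dot
    emb = np.zeros((256, hidden))
    for cell in range(256):
        for bit in range(8):
            if cell >> bit & 1:        # dot `bit + 1` is raised in this cell
                emb[cell] += basis[bit]
    return emb
```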
### Morphological Regularization

Training includes a regularization term:
```
L_morph = ReLU(||emb - erode(emb)|| - ||emb - dilate(emb)||)
```

This encourages embeddings to respect the lattice ordering: `erode(x) ≤ x ≤ dilate(x)`.
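Because erosion and dilation on single cells are bitwise AND/OR with a structuring element, the term can be sketched by remapping cell indices. The 6-dot structuring element (`0b00111111`) and the mean reduction are assumptions of this sketch:

```python
import numpy as np

def morph_regularizer(emb: np.ndarray, se: int = 0b00111111) -> float:
    """Hinge penalty pushing each cell's embedding at least as close to its
    erosion (a sub-pattern) as to its dilation (a super-pattern)."""
    cells = np.arange(256)
    eroded = cells & se                   # erosion of cell i by the structuring element
    dilated = cells | se                  # dilation of cell i
    d_er = np.linalg.norm(emb[cells] - emb[eroded], axis=1)
    d_di = np.linalg.norm(emb[cells] - emb[dilated], axis=1)
    return float(np.maximum(d_er - d_di, 0.0).mean())  # ReLU hinge, averaged
```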
## Theoretical Foundation

This model implements the formal theory from:

**"Theoretical Foundations for 8-Dot Braille-Native LLMs"**

Key theoretical components:
1. **Braille Lattice**: Boolean algebra (B⁸, ∧, ∨, ¬) with 256 elements
2. **Morphological Operators**: Erosion, dilation, opening, closing
3. **Modality-Invariant Representation**: (modality, sequence, embedding) triple
4. **Lattice Metrics**: Hamming distance, Jaccard similarity

See `braille_lattice_theory.py` for the full implementation.
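On byte-valued cells, the two lattice metrics reduce to a few bit operations (treating raised dots as a set):

```python
def hamming(a: int, b: int) -> int:
    """Number of dot positions where two Braille cells differ."""
    return bin((a ^ b) & 0xFF).count("1")

def jaccard(a: int, b: int) -> float:
    """|a ∧ b| / |a ∨ b| on raised-dot sets; defined as 1.0 for two empty cells."""
    union = bin(a | b).count("1")
    inter = bin(a & b).count("1")
    return inter / union if union else 1.0
```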
## Modality Support

| Modality | Header | Status |
|----------|--------|--------|
| TEXT | ⣿⠁ | ✅ Trained |
| IMAGE | ⣿⠃ | ✅ Trained |
| AUDIO | ⣿⠇ | ✅ Trained |
| BINARY | ⣿⠏ | ✅ Trained |
| VIDEO | ⣿⠗ | 🔄 Framework ready |
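A minimal sketch of header-prefixed encoding, using the two-cell headers from the table above; the `encode` helper is illustrative, not part of the released API:

```python
MODALITY_HEADERS = {            # two-cell headers from the table above
    "TEXT": "\u28FF\u2801",     # ⣿⠁
    "IMAGE": "\u28FF\u2803",    # ⣿⠃
    "AUDIO": "\u28FF\u2807",    # ⣿⠇
    "BINARY": "\u28FF\u280F",   # ⣿⠏
    "VIDEO": "\u28FF\u2817",    # ⣿⠗
}

def encode(data: bytes, modality: str) -> str:
    """Prefix the modality header, then map each byte b to U+2800 + b."""
    return MODALITY_HEADERS[modality] + "".join(chr(0x2800 + b) for b in data)
```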
## Haptic Evaluation Metrics

v6 introduces new evaluation metrics for tactile quality:

| Metric | Description | Target | Achieved |
|--------|-------------|--------|----------|
| **Lattice Coherence** | Adjacent tokens have low Hamming distance | > 0.7 | **0.743** ✅ |
| **Morphological Stability** | Outputs stable under erosion/dilation | > 0.5 | **0.453** ❌ |
| **Haptic Score** | Combined tactile quality metric | > 0.5 | **0.598** ✅ |

## Training Results

| Metric | Value |
|--------|-------|
| Final Loss | 1.23 |
| Training Steps | 10,000 |
| Training Time | 2h 7m |
| Corpus | Balanced multimodal (25% each: text, image, audio, binary) |
| Corpus Size | 164M chars |
## Usage

```python
import json

import torch
from train_lattice_v6 import Braille256LatticeModel, LatticeConfig

# Load model
with open("config.json") as f:
    config = LatticeConfig.from_dict(json.load(f))
model = Braille256LatticeModel(config)
model.load_state_dict(torch.load("pytorch_model.bin"))

# Generate
input_ids = torch.tensor([[0x28, 0x29, 0x2A]])  # Some Braille tokens
output = model.generate(input_ids, max_length=100)
```
## Training

```bash
python train_lattice_v6.py \
    --corpus corpus/braille_multimodal_corpus.txt \
    --tokenizer tokenizers/braille_8dot_32k/braille_8dot_32k.model \
    --output models/braille256_v6_lattice \
    --steps 10000
```

### Training Options

| Flag | Description |
|------|-------------|
| `--no-lattice-attention` | Disable lattice attention (ablation) |
| `--no-lattice-embeddings` | Disable lattice embeddings (ablation) |
| `--no-morph-regularization` | Disable morphological regularization (ablation) |

## Model Family

| Version | Focus | Parameters | Key Feature |
|---------|-------|------------|-------------|
| v1-v3 | 6-dot Braille | ~10M | Basic Braille LM |
| v4 | 8-dot Braille | 29.9M | Full byte encoding |
| v5 | Multimodal | 11.5M | TEXT/IMAGE/AUDIO/BINARY |
| **v6** | **Lattice-aware** | **11.5M** | **Hamming attention, morphological regularization, balanced multimodal corpus** |
## Why Lattice-Aware?

Standard LLMs treat tokens as arbitrary symbols. braille256-v6 knows that:

1. **Braille cells form a lattice**: 256 elements with meet (∧) and join (∨)
2. **Similar cells should have similar representations**: Hamming distance matters
3. **Morphological operations preserve meaning**: erosion/dilation are semantic
4. **Tactile quality is measurable**: haptic metrics evaluate output quality

This makes v6 the first LLM designed for **tactile-first AI**.

## Citation

```bibtex
@misc{braille256v6,
  author = {Barrett, Ryan},
  title = {braille256-v6: Lattice-Aware Multimodal Braille Model},
  year = {2024},
  publisher = {HuggingFace},
  url = {https://huggingface.co/ryanscottbarrett/braille256-v6}
}
```

## License

MIT

## Links

- [braille256-v5](https://huggingface.co/ryanscottbarrett/braille256-v5)
- [braille256-v4](https://huggingface.co/ryanscottbarrett/braille256-v4)
- [Theoretical Paper](docs/BRAILLE_NATIVE_LLM_THEORY.md)
- [Lattice Theory Implementation](src/braille_lattice_theory.py)

---

⣿ *The first LLM where Braille is not just the output format, but the computational substrate.* ⣿
braille_lattice_theory.py ADDED
@@ -0,0 +1,1010 @@
#!/usr/bin/env python3
"""
Braille Dot-Lattice Theory: Formal Mathematical Framework

This module formalizes the missing theoretical components for 8-dot Braille-native LLMs:

1. DOT-LATTICE MORPHOLOGICAL OPERATORS
   - Boolean algebra on the 8-bit Braille lattice (B⁸, ∧, ∨, ¬, ⊕)
   - Morphological operations: erosion, dilation, opening, closing
   - Dot-wise transformations preserving semantic structure

2. MODALITY-INVARIANT BRAILLE REASONING LOOPS
   - Unified representation across text, image, audio, binary
   - Cross-modal attention mechanisms in Braille space
   - Semantic preservation under modality transformation

Mathematical Foundation:
- The 8-dot Braille cell forms a Boolean lattice (B⁸, ≤) where B = {0, 1}
- Each cell is an 8-dimensional binary vector: c ∈ {0,1}⁸
- The lattice has 2⁸ = 256 elements with meet (∧) and join (∨) operations
- This isomorphism to bytes enables direct computational semantics

Author: Ryan Barrett & Cascade
Date: December 2024
"""

from __future__ import annotations
from dataclasses import dataclass, field
from typing import List, Dict, Tuple, Set, Callable, Optional, Iterator
from enum import Enum, auto
import numpy as np
from functools import reduce
import operator

# =============================================================================
# SECTION 1: BRAILLE LATTICE FUNDAMENTALS
# =============================================================================

# Unicode range for 8-dot Braille
BRAILLE_BASE = 0x2800
BRAILLE_MAX = 0x28FF

# Dot position bit values (standard 8-dot layout)
# Layout: 1 4
#         2 5
#         3 6
#         7 8
DOT_BITS = {
    1: 0b00000001,  # bit 0
    2: 0b00000010,  # bit 1
    3: 0b00000100,  # bit 2
    4: 0b00001000,  # bit 3
    5: 0b00010000,  # bit 4
    6: 0b00100000,  # bit 5
    7: 0b01000000,  # bit 6
    8: 0b10000000,  # bit 7
}

# Inverse mapping
BIT_TO_DOT = {v: k for k, v in DOT_BITS.items()}

@dataclass(frozen=True)
class BrailleCell:
    """
    A single 8-dot Braille cell as an element of the Boolean lattice B⁸.

    The cell is represented as an 8-bit integer where each bit corresponds
    to a dot position. This enables efficient lattice operations.

    Properties:
    - Immutable (frozen dataclass)
    - Hashable (can be used in sets/dicts)
    - Supports all Boolean lattice operations
    """
    value: int  # 0-255, representing the 8 dots as bits

    def __post_init__(self):
        if not 0 <= self.value <= 255:
            raise ValueError(f"BrailleCell value must be 0-255, got {self.value}")

    # --- Lattice Element Properties ---

    @property
    def dots(self) -> Tuple[int, ...]:
        """Return tuple of active dot numbers (1-8)."""
        return tuple(d for d in range(1, 9) if self.has_dot(d))

    @property
    def unicode(self) -> str:
        """Return the Unicode Braille character."""
        return chr(BRAILLE_BASE + self.value)

    @property
    def vector(self) -> np.ndarray:
        """Return as 8-dimensional binary vector."""
        return np.array([(self.value >> i) & 1 for i in range(8)], dtype=np.uint8)

    @property
    def cardinality(self) -> int:
        """Number of raised dots (Hamming weight)."""
        return bin(self.value).count('1')

    @property
    def is_bottom(self) -> bool:
        """Check if this is ⊥ (empty cell, no dots)."""
        return self.value == 0

    @property
    def is_top(self) -> bool:
        """Check if this is ⊤ (all dots raised)."""
        return self.value == 255

    def has_dot(self, dot: int) -> bool:
        """Check if a specific dot (1-8) is raised."""
        return bool(self.value & DOT_BITS[dot])

    # --- Boolean Lattice Operations ---

    def meet(self, other: BrailleCell) -> BrailleCell:
        """
        Lattice meet (∧): Greatest lower bound.
        Equivalent to bitwise AND - keeps only dots present in BOTH cells.

        Semantically: intersection of dot patterns.
        """
        return BrailleCell(self.value & other.value)

    def join(self, other: BrailleCell) -> BrailleCell:
        """
        Lattice join (∨): Least upper bound.
        Equivalent to bitwise OR - raises dots present in EITHER cell.

        Semantically: union of dot patterns.
        """
        return BrailleCell(self.value | other.value)

    def complement(self) -> BrailleCell:
        """
        Lattice complement (¬): Invert all dots.
        Equivalent to bitwise NOT (masked to 8 bits).

        Semantically: tactile negative.
        """
        return BrailleCell((~self.value) & 0xFF)

    def symmetric_difference(self, other: BrailleCell) -> BrailleCell:
        """
        Symmetric difference (⊕): XOR operation.
        Dots present in exactly one of the two cells.

        Semantically: tactile contrast/difference.
        """
        return BrailleCell(self.value ^ other.value)

    def implies(self, other: BrailleCell) -> BrailleCell:
        """
        Material implication (→): ¬self ∨ other.
        In lattice terms: self ≤ other iff (self → other) = ⊤
        """
        return self.complement().join(other)

    # --- Partial Order ---

    def __le__(self, other: BrailleCell) -> bool:
        """Lattice ordering: self ≤ other iff self ∧ other = self."""
        return (self.value & other.value) == self.value

    def __lt__(self, other: BrailleCell) -> bool:
        """Strict ordering: self < other iff self ≤ other and self ≠ other."""
        return self <= other and self.value != other.value

    def __ge__(self, other: BrailleCell) -> bool:
        return other <= self

    def __gt__(self, other: BrailleCell) -> bool:
        return other < self

    # --- Operator Overloads ---

    def __and__(self, other: BrailleCell) -> BrailleCell:
        return self.meet(other)

    def __or__(self, other: BrailleCell) -> BrailleCell:
        return self.join(other)

    def __invert__(self) -> BrailleCell:
        return self.complement()

    def __xor__(self, other: BrailleCell) -> BrailleCell:
        return self.symmetric_difference(other)

    def __repr__(self) -> str:
        return f"BrailleCell({self.unicode}, dots={self.dots}, value={self.value})"

    # --- Constructors ---

    @classmethod
    def from_unicode(cls, char: str) -> BrailleCell:
        """Create from Unicode Braille character."""
        code = ord(char)
        if not BRAILLE_BASE <= code <= BRAILLE_MAX:
            raise ValueError(f"Not a Braille character: {char}")
        return cls(code - BRAILLE_BASE)

    @classmethod
    def from_dots(cls, *dots: int) -> BrailleCell:
        """Create from dot numbers (1-8)."""
        value = 0
        for d in dots:
            if 1 <= d <= 8:
                value |= DOT_BITS[d]
        return cls(value)

    @classmethod
    def from_byte(cls, byte: int) -> BrailleCell:
        """Create from byte value (0-255)."""
        return cls(byte & 0xFF)

    @classmethod
    def from_vector(cls, vec: np.ndarray) -> BrailleCell:
        """Create from 8-dimensional binary vector."""
        value = sum(int(vec[i]) << i for i in range(8))
        return cls(value)

    @classmethod
    def bottom(cls) -> BrailleCell:
        """Return ⊥ (empty cell)."""
        return cls(0)

    @classmethod
    def top(cls) -> BrailleCell:
        """Return ⊤ (all dots raised)."""
        return cls(255)

# =============================================================================
# SECTION 2: DOT-LATTICE MORPHOLOGICAL OPERATORS
# =============================================================================

class MorphologicalOperator(Enum):
    """Morphological operations on the Braille lattice."""
    EROSION = auto()    # Shrink patterns
    DILATION = auto()   # Expand patterns
    OPENING = auto()    # Erosion then dilation (remove small protrusions)
    CLOSING = auto()    # Dilation then erosion (fill small gaps)
    GRADIENT = auto()   # Dilation - Erosion (edge detection)
    TOP_HAT = auto()    # Original - Opening (extract bright details)
    BLACK_HAT = auto()  # Closing - Original (extract dark details)


@dataclass
class StructuringElement:
    """
    A structuring element for morphological operations on Braille cells.

    In classical morphology, the structuring element defines the neighborhood.
    For Braille, we define it as a set of dot positions that form the "kernel".

    Common structuring elements:
    - COLUMN_LEFT: dots 1,2,3,7 (left column)
    - COLUMN_RIGHT: dots 4,5,6,8 (right column)
    - ROW_TOP: dots 1,4 (top row)
    - CROSS: dots 2,4,5 (cross pattern)
    - FULL: all 8 dots
    """
    dots: Set[int]
    name: str = ""

    @property
    def cell(self) -> BrailleCell:
        """Convert to BrailleCell."""
        return BrailleCell.from_dots(*self.dots)

    # Predefined structuring elements
    @classmethod
    def column_left(cls) -> StructuringElement:
        return cls({1, 2, 3, 7}, "COLUMN_LEFT")

    @classmethod
    def column_right(cls) -> StructuringElement:
        return cls({4, 5, 6, 8}, "COLUMN_RIGHT")

    @classmethod
    def row_top(cls) -> StructuringElement:
        return cls({1, 4}, "ROW_TOP")

    @classmethod
    def row_middle(cls) -> StructuringElement:
        return cls({2, 5}, "ROW_MIDDLE")

    @classmethod
    def row_bottom(cls) -> StructuringElement:
        return cls({3, 6}, "ROW_BOTTOM")

    @classmethod
    def row_extension(cls) -> StructuringElement:
        return cls({7, 8}, "ROW_EXTENSION")

    @classmethod
    def cross(cls) -> StructuringElement:
        return cls({2, 4, 5}, "CROSS")

    @classmethod
    def full(cls) -> StructuringElement:
        return cls({1, 2, 3, 4, 5, 6, 7, 8}, "FULL")

    @classmethod
    def six_dot(cls) -> StructuringElement:
        """Traditional 6-dot Braille subset."""
        return cls({1, 2, 3, 4, 5, 6}, "SIX_DOT")


class BrailleMorphology:
    """
    Morphological operators on the Braille dot-lattice.

    These operators enable pattern transformation while preserving
    structural relationships in the lattice.

    Key insight: Morphological operations on Braille cells can be
    computed efficiently using Boolean operations on the underlying
    8-bit representation.
    """

    @staticmethod
    def erode(cell: BrailleCell, se: StructuringElement) -> BrailleCell:
        """
        Erosion: Keep only dots that are also present in the structuring element.

        ε_B(X) = {x : B_x ⊆ X}

        For a single cell (no spatial shift), this reduces to intersection
        with the structuring element: bitwise AND.
        """
        return cell & se.cell

    @staticmethod
    def dilate(cell: BrailleCell, se: StructuringElement) -> BrailleCell:
        """
        Dilation: Raise every dot of the structuring element.

        δ_B(X) = {x : B_x ∩ X ≠ ∅}

        For a single cell: bitwise OR with the structuring element.
        """
        return cell | se.cell

    @staticmethod
    def opening(cell: BrailleCell, se: StructuringElement) -> BrailleCell:
        """
        Opening: Erosion followed by dilation.

        γ_B(X) = δ_B(ε_B(X))

        Classical effect: removes small protrusions, smooths from outside.
        Caveat: with the single-cell AND/OR definitions above,
        (X ∧ S) ∨ S = S for every X, so single-cell opening always yields
        the structuring element itself; the classical behavior only
        appears for shifted/sequence variants.
        """
        eroded = BrailleMorphology.erode(cell, se)
        return BrailleMorphology.dilate(eroded, se)

    @staticmethod
    def closing(cell: BrailleCell, se: StructuringElement) -> BrailleCell:
        """
        Closing: Dilation followed by erosion.

        φ_B(X) = ε_B(δ_B(X))

        Classical effect: fills small gaps, smooths from inside.
        Caveat: as with opening, (X ∨ S) ∧ S = S for every X under the
        single-cell definitions.
        """
        dilated = BrailleMorphology.dilate(cell, se)
        return BrailleMorphology.erode(dilated, se)

    @staticmethod
    def gradient(cell: BrailleCell, se: StructuringElement) -> BrailleCell:
        """
        Morphological gradient: Dilation - Erosion (via XOR).

        ρ_B(X) = δ_B(X) - ε_B(X)

        Effect: Edge detection - dots that differ between dilation and erosion.
        """
        dilated = BrailleMorphology.dilate(cell, se)
        eroded = BrailleMorphology.erode(cell, se)
        return dilated ^ eroded

    @staticmethod
    def top_hat(cell: BrailleCell, se: StructuringElement) -> BrailleCell:
        """
        Top-hat transform: Original - Opening.

        Effect: Extracts bright details smaller than the structuring element.
        """
        opened = BrailleMorphology.opening(cell, se)
        return cell ^ opened  # Difference via XOR

    @staticmethod
    def black_hat(cell: BrailleCell, se: StructuringElement) -> BrailleCell:
        """
        Black-hat transform: Closing - Original.

        Effect: Extracts dark details smaller than the structuring element.
        """
        closed = BrailleMorphology.closing(cell, se)
        return closed ^ cell  # Difference via XOR

    @staticmethod
    def hit_or_miss(cell: BrailleCell,
                    foreground: StructuringElement,
                    background: StructuringElement) -> bool:
        """
        Hit-or-miss transform: Pattern matching.

        Returns True iff:
        - All foreground dots are present in cell
        - All background dots are absent from cell

        This is the foundation for pattern recognition in Braille.
        """
        fg_match = (cell & foreground.cell) == foreground.cell
        bg_match = (cell & background.cell).is_bottom
        return fg_match and bg_match

# =============================================================================
# SECTION 3: BRAILLE SEQUENCE MORPHOLOGY
# =============================================================================

@dataclass
class BrailleSequence:
    """
    A sequence of Braille cells with morphological operations.

    This extends single-cell morphology to sequences, enabling
    operations on Braille text/data streams.
    """
    cells: List[BrailleCell] = field(default_factory=list)

    def __len__(self) -> int:
        return len(self.cells)

    def __getitem__(self, idx: int) -> BrailleCell:
        return self.cells[idx]

    def __iter__(self) -> Iterator[BrailleCell]:
        return iter(self.cells)

    @property
    def unicode(self) -> str:
        """Return as Unicode Braille string."""
        return ''.join(c.unicode for c in self.cells)

    @property
    def bytes(self) -> bytes:
        """Return as byte sequence."""
        return bytes(c.value for c in self.cells)

    def apply(self, op: Callable[[BrailleCell], BrailleCell]) -> BrailleSequence:
        """Apply a cell-wise operation to the sequence."""
        return BrailleSequence([op(c) for c in self.cells])

    def apply_morphology(self,
                         operator: MorphologicalOperator,
                         se: StructuringElement) -> BrailleSequence:
        """Apply morphological operation to each cell."""
        ops = {
            MorphologicalOperator.EROSION: BrailleMorphology.erode,
            MorphologicalOperator.DILATION: BrailleMorphology.dilate,
            MorphologicalOperator.OPENING: BrailleMorphology.opening,
            MorphologicalOperator.CLOSING: BrailleMorphology.closing,
            MorphologicalOperator.GRADIENT: BrailleMorphology.gradient,
            MorphologicalOperator.TOP_HAT: BrailleMorphology.top_hat,
            MorphologicalOperator.BLACK_HAT: BrailleMorphology.black_hat,
        }
        op_func = ops[operator]
        return BrailleSequence([op_func(c, se) for c in self.cells])

    def convolve(self, kernel: List[BrailleCell],
                 op: Callable[[BrailleCell, BrailleCell], BrailleCell] = lambda a, b: a & b) -> BrailleSequence:
        """
        Convolve sequence with a kernel using the specified operation.

        This enables sliding-window pattern matching and transformation.
        """
        if not kernel:
            return self

        k_len = len(kernel)
        result = []

        for i in range(len(self.cells)):
            # Apply kernel centered at position i
            acc = BrailleCell.bottom()
            for j, k_cell in enumerate(kernel):
                idx = i - k_len // 2 + j
                if 0 <= idx < len(self.cells):
                    acc = acc | op(self.cells[idx], k_cell)
            result.append(acc)

        return BrailleSequence(result)

    def reduce(self,
               op: Callable[[BrailleCell, BrailleCell], BrailleCell] = lambda a, b: a | b) -> BrailleCell:
        """Reduce sequence to a single cell using the operation."""
        if not self.cells:
            return BrailleCell.bottom()
        return reduce(op, self.cells)

    @classmethod
    def from_unicode(cls, text: str) -> BrailleSequence:
        """Create from Unicode Braille string."""
        cells = []
        for char in text:
            code = ord(char)
            if BRAILLE_BASE <= code <= BRAILLE_MAX:
                cells.append(BrailleCell(code - BRAILLE_BASE))
        return cls(cells)

    @classmethod
    def from_bytes(cls, data: bytes) -> BrailleSequence:
        """Create from byte sequence (direct mapping)."""
        return cls([BrailleCell(b) for b in data])

# =============================================================================
# SECTION 4: MODALITY-INVARIANT BRAILLE REPRESENTATION
# =============================================================================

class Modality(Enum):
    """Supported modalities for Braille encoding."""
    TEXT = auto()
    IMAGE = auto()
    AUDIO = auto()
    BINARY = auto()
    VIDEO = auto()
    SEMANTIC = auto()  # Abstract semantic content


# Modality headers (from braille256-v5): the full header is ⣿ followed by
# the discriminator cell stored here. Storing the discriminator keeps the
# entries distinct per modality.
MODALITY_HEADERS = {
    Modality.TEXT: BrailleCell.from_dots(1),             # ⣿ + ⠁ = ⣿⠁
    Modality.IMAGE: BrailleCell.from_dots(1, 2),         # ⣿ + ⠃ = ⣿⠃
    Modality.AUDIO: BrailleCell.from_dots(1, 2, 3),      # ⣿ + ⠇ = ⣿⠇
    Modality.BINARY: BrailleCell.from_dots(1, 2, 3, 4),  # ⣿ + ⠏ = ⣿⠏
    Modality.VIDEO: BrailleCell.from_dots(1, 2, 3, 5),   # ⣿ + ⠗ = ⣿⠗
}


@dataclass
class ModalityInvariantRepresentation:
    """
    A modality-invariant representation in Braille space.

    Key insight: All modalities can be encoded as byte sequences,
    and all byte sequences map bijectively to 8-dot Braille.
    Therefore, Braille provides a universal representation space.

    The representation consists of:
    1. Modality header (identifies source modality)
    2. Semantic embedding (modality-invariant meaning)
    3. Raw Braille sequence (the actual data)

    Cross-modal operations preserve semantic content while
    allowing modality-specific transformations.
    """
    modality: Modality
    sequence: BrailleSequence
    semantic_embedding: Optional[np.ndarray] = None  # d-dimensional semantic vector
    metadata: Dict = field(default_factory=dict)

    @property
    def header(self) -> BrailleCell:
        """Get the modality header cell."""
        return MODALITY_HEADERS.get(self.modality, BrailleCell.top())

    def to_semantic_space(self, encoder: Callable[[BrailleSequence], np.ndarray]) -> np.ndarray:
        """
        Project the Braille sequence to semantic embedding space.

        This is where the LLM's learned embeddings come in.
        The encoder maps Braille tokens to semantic vectors.
        """
        if self.semantic_embedding is None:
            self.semantic_embedding = encoder(self.sequence)
        return self.semantic_embedding

    def transform_modality(self,
                           target: Modality,
                           transformer: Callable[[BrailleSequence, Modality, Modality], BrailleSequence]
                           ) -> ModalityInvariantRepresentation:
        """
        Transform to a different modality while preserving semantics.

        The transformer function handles modality-specific conversion
        while the semantic embedding remains invariant.
        """
        new_sequence = transformer(self.sequence, self.modality, target)
        return ModalityInvariantRepresentation(
            modality=target,
            sequence=new_sequence,
            semantic_embedding=self.semantic_embedding,  # Preserved!
            metadata={**self.metadata, 'source_modality': self.modality}
        )

+ # =============================================================================
608
+ # SECTION 5: BRAILLE REASONING LOOPS
609
+ # =============================================================================
610
+
611
+ @dataclass
612
+ class ReasoningState:
613
+ """
614
+ State of a Braille reasoning loop.
615
+
616
+ The reasoning loop operates entirely in Braille space:
617
+ 1. Input: Braille sequence (any modality)
618
+ 2. Transform: Apply morphological/semantic operations
619
+ 3. Attend: Cross-modal attention in Braille space
620
+ 4. Output: Braille sequence (any modality)
621
+
622
+ This enables modality-invariant reasoning where the same
623
+ operations work regardless of input/output modality.
624
+ """
625
+ sequence: BrailleSequence
626
+ attention_weights: Optional[np.ndarray] = None
627
+ hidden_state: Optional[np.ndarray] = None
628
+ step: int = 0
629
+
630
+ def apply_attention(self,
631
+ query: BrailleSequence,
632
+ key: BrailleSequence,
633
+ value: BrailleSequence) -> BrailleSequence:
634
+ """
635
+ Cross-modal attention in Braille space.
636
+
637
+ Attention is computed on the lattice structure:
638
+ - Query, Key, Value are all Braille sequences
639
+ - Similarity is measured via lattice distance
640
+ - Output is weighted combination in Braille space
641
+
642
+ Lattice distance: d(a, b) = |a ⊕ b| (Hamming distance)
643
+ """
644
+ if len(query) == 0 or len(key) == 0:
645
+ return value
646
+
647
+ # Compute attention scores based on lattice similarity
648
+ scores = np.zeros((len(query), len(key)))
649
+ for i, q in enumerate(query):
650
+ for j, k in enumerate(key):
651
+ # Similarity = 8 - Hamming distance (higher = more similar)
652
+ diff = q ^ k
653
+ scores[i, j] = 8 - diff.cardinality
654
+
655
+ # Softmax normalization
656
+ scores = np.exp(scores - scores.max(axis=1, keepdims=True))
657
+ self.attention_weights = scores / scores.sum(axis=1, keepdims=True)
658
+
659
+ # Weighted combination of values
660
+ result = []
661
+ for i in range(len(query)):
662
+ # Combine values weighted by attention
663
+ combined = BrailleCell.bottom()
664
+ for j, v in enumerate(value):
665
+ if self.attention_weights[i, j] > 0.1:  # attention threshold; assumes len(value) == len(key)
666
+ combined = combined | v
667
+ result.append(combined)
668
+
669
+ return BrailleSequence(result)
670
+
671
+
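The score computation in `apply_attention` can be sketched standalone. A minimal NumPy version, assuming cells are plain 8-bit integers (stand-ins for `BrailleCell`, whose `^` and `cardinality` reduce to XOR and popcount):

```python
import numpy as np

def popcount(x: int) -> int:
    return bin(x).count("1")

def lattice_attention(query, key):
    """Attention weights over 8-bit cells: similarity = 8 - Hamming distance."""
    scores = np.array([[8 - popcount(q ^ k) for k in key] for q in query], dtype=float)
    scores = np.exp(scores - scores.max(axis=1, keepdims=True))
    return scores / scores.sum(axis=1, keepdims=True)

w = lattice_attention([0b00001011], [0b00001011, 0b11110100])
# the identical cell (distance 0) dominates the fully complementary cell (distance 8)
assert w[0, 0] > w[0, 1]
```

Before the softmax, an identical cell scores 8 and a fully complementary cell scores 0, which is the full dynamic range of the similarity.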
672
+ class BrailleReasoningLoop:
673
+ """
674
+ A modality-invariant reasoning loop operating in Braille space.
675
+
676
+ The loop implements the following cycle:
677
+
678
+ 1. ENCODE: Any modality → Braille sequence
679
+ 2. TRANSFORM: Morphological operations on Braille
680
+ 3. ATTEND: Cross-sequence attention in lattice space
681
+ 4. REASON: Apply learned transformations (LLM layers)
682
+ 5. DECODE: Braille sequence → Any modality
683
+
684
+ Key property: Steps 2-4 are MODALITY-INVARIANT.
685
+ The same operations work for text, images, audio, etc.
686
+ """
687
+
688
+ def __init__(self,
689
+ hidden_dim: int = 256,
690
+ num_heads: int = 8,
691
+ morphology_se: StructuringElement = None):
692
+ self.hidden_dim = hidden_dim
693
+ self.num_heads = num_heads
694
+ self.morphology_se = morphology_se or StructuringElement.six_dot()
695
+ self.state = None
696
+
697
+ def encode(self,
698
+ data: bytes,
699
+ modality: Modality) -> ModalityInvariantRepresentation:
700
+ """
701
+ Encode any modality to Braille representation.
702
+
703
+ This is the entry point: raw bytes → Braille sequence.
704
+ The modality header is prepended for downstream processing.
705
+ """
706
+ sequence = BrailleSequence.from_bytes(data)
707
+ return ModalityInvariantRepresentation(
708
+ modality=modality,
709
+ sequence=sequence
710
+ )
711
+
712
+ def transform(self,
713
+ rep: ModalityInvariantRepresentation,
714
+ operator: MorphologicalOperator = MorphologicalOperator.OPENING
715
+ ) -> ModalityInvariantRepresentation:
716
+ """
717
+ Apply morphological transformation.
718
+
719
+ This step is modality-invariant: the same operation
720
+ works regardless of whether the input is text, image, etc.
721
+ """
722
+ transformed = rep.sequence.apply_morphology(operator, self.morphology_se)
723
+ return ModalityInvariantRepresentation(
724
+ modality=rep.modality,
725
+ sequence=transformed,
726
+ semantic_embedding=rep.semantic_embedding,
727
+ metadata=rep.metadata
728
+ )
729
+
730
+ def attend(self,
731
+ query_rep: ModalityInvariantRepresentation,
732
+ context_rep: ModalityInvariantRepresentation
733
+ ) -> ModalityInvariantRepresentation:
734
+ """
735
+ Cross-modal attention between two representations.
736
+
737
+ This enables reasoning across modalities:
738
+ - Text attending to image
739
+ - Audio attending to text
740
+ - Any modality attending to any other
741
+
742
+ The attention operates in Braille lattice space.
743
+ """
744
+ if self.state is None:
745
+ self.state = ReasoningState(sequence=query_rep.sequence)
746
+
747
+ attended = self.state.apply_attention(
748
+ query=query_rep.sequence,
749
+ key=context_rep.sequence,
750
+ value=context_rep.sequence
751
+ )
752
+
753
+ return ModalityInvariantRepresentation(
754
+ modality=query_rep.modality,
755
+ sequence=attended,
756
+ semantic_embedding=query_rep.semantic_embedding,
757
+ metadata={**query_rep.metadata, 'attended_modality': context_rep.modality}
758
+ )
759
+
760
+ def reason(self,
761
+ rep: ModalityInvariantRepresentation,
762
+ transform_fn: Callable[[BrailleSequence], BrailleSequence] = None
763
+ ) -> ModalityInvariantRepresentation:
764
+ """
765
+ Apply learned reasoning transformation.
766
+
767
+ In a full LLM, this would be the transformer layers.
768
+ Here we provide a hook for custom transformations.
769
+
770
+ The key insight: reasoning happens in Braille space,
771
+ making it inherently modality-invariant.
772
+ """
773
+ if transform_fn is None:
774
+ # Default: identity with morphological smoothing
775
+ transform_fn = lambda seq: seq.apply_morphology(
776
+ MorphologicalOperator.CLOSING,
777
+ self.morphology_se
778
+ )
779
+
780
+ reasoned = transform_fn(rep.sequence)
781
+
782
+ return ModalityInvariantRepresentation(
783
+ modality=rep.modality,
784
+ sequence=reasoned,
785
+ semantic_embedding=rep.semantic_embedding,
786
+ metadata=rep.metadata
787
+ )
788
+
789
+ def decode(self,
790
+ rep: ModalityInvariantRepresentation,
791
+ target_modality: Modality = None
792
+ ) -> bytes:
793
+ """
794
+ Decode Braille representation to bytes.
795
+
796
+ This is the exit point: Braille sequence → raw bytes.
797
+ The target modality determines any post-processing.
798
+ """
799
+ return rep.sequence.bytes
800
+
801
+ def full_loop(self,
802
+ input_data: bytes,
803
+ input_modality: Modality,
804
+ context_data: bytes = None,
805
+ context_modality: Modality = None,
806
+ output_modality: Modality = None
807
+ ) -> bytes:
808
+ """
809
+ Execute a complete reasoning loop.
810
+
811
+ Input → Encode → Transform → Attend → Reason → Decode → Output
812
+
813
+ All intermediate steps are modality-invariant.
814
+ """
815
+ # Encode input
816
+ rep = self.encode(input_data, input_modality)
817
+
818
+ # Transform
819
+ rep = self.transform(rep)
820
+
821
+ # Attend to context if provided
822
+ if context_data is not None:
823
+ context_rep = self.encode(
824
+ context_data,
825
+ context_modality or input_modality
826
+ )
827
+ rep = self.attend(rep, context_rep)
828
+
829
+ # Reason
830
+ rep = self.reason(rep)
831
+
832
+ # Decode
833
+ return self.decode(rep, output_modality or input_modality)
834
+
835
+
836
+ # =============================================================================
837
+ # SECTION 6: LATTICE DISTANCE METRICS
838
+ # =============================================================================
839
+
840
+ class BrailleMetrics:
841
+ """
842
+ Distance and similarity metrics on the Braille lattice.
843
+
844
+ These metrics enable:
845
+ - Semantic similarity measurement
846
+ - Clustering in Braille space
847
+ - Loss functions for training
848
+ """
849
+
850
+ @staticmethod
851
+ def hamming_distance(a: BrailleCell, b: BrailleCell) -> int:
852
+ """
853
+ Hamming distance: number of differing dots.
854
+
855
+ d_H(a, b) = |a ⊕ b| = popcount(a XOR b)
856
+
857
+ Range: [0, 8]
858
+ """
859
+ return (a ^ b).cardinality
860
+
861
+ @staticmethod
862
+ def jaccard_similarity(a: BrailleCell, b: BrailleCell) -> float:
863
+ """
864
+ Jaccard similarity: intersection over union.
865
+
866
+ J(a, b) = |a ∧ b| / |a ∨ b|
867
+
868
+ Range: [0, 1]
869
+ """
870
+ intersection = (a & b).cardinality
871
+ union = (a | b).cardinality
872
+ if union == 0:
873
+ return 1.0 # Both empty
874
+ return intersection / union
875
+
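Both cell metrics reduce to bit operations on the 8-bit cell value; a standalone sketch using plain ints in place of `BrailleCell`:

```python
def popcount(x: int) -> int:
    return bin(x).count("1")

def hamming(a: int, b: int) -> int:
    return popcount(a ^ b)            # number of differing dots, in [0, 8]

def jaccard(a: int, b: int) -> float:
    union = popcount(a | b)
    return 1.0 if union == 0 else popcount(a & b) / union

a, b = 0b00001011, 0b00011010         # cells "f" (dots 1,2,4) and "j" (dots 2,4,5)
assert hamming(a, b) == 2             # dots 1 and 5 differ
assert abs(jaccard(a, b) - 0.5) < 1e-9  # 2 shared dots / 4 total dots
```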
876
+ @staticmethod
877
+ def lattice_distance(a: BrailleCell, b: BrailleCell) -> int:
878
+ """
879
+ Lattice distance: length of shortest path in Hasse diagram.
880
+
881
+ For Boolean lattice: d_L(a, b) = |a ⊕ b| (same as Hamming)
882
+ """
883
+ return BrailleMetrics.hamming_distance(a, b)
884
+
885
+ @staticmethod
886
+ def semantic_distance(a: BrailleCell, b: BrailleCell,
887
+ embeddings: Dict[int, np.ndarray] = None) -> float:
888
+ """
889
+ Semantic distance using learned embeddings.
890
+
891
+ If embeddings are provided, uses cosine distance in embedding space.
892
+ Otherwise, falls back to normalized Hamming distance.
893
+ """
894
+ if embeddings is not None and a.value in embeddings and b.value in embeddings:
895
+ vec_a = embeddings[a.value]
896
+ vec_b = embeddings[b.value]
897
+ cos_sim = np.dot(vec_a, vec_b) / (np.linalg.norm(vec_a) * np.linalg.norm(vec_b) + 1e-12)
898
+ return 1.0 - cos_sim
899
+ else:
900
+ return BrailleMetrics.hamming_distance(a, b) / 8.0
901
+
902
+ @staticmethod
903
+ def sequence_distance(a: BrailleSequence, b: BrailleSequence,
904
+ cell_metric: Callable[[BrailleCell, BrailleCell], float] = None
905
+ ) -> float:
906
+ """
907
+ Distance between two Braille sequences.
908
+
909
+ Uses dynamic time warping or simple alignment depending on lengths.
910
+ """
911
+ if cell_metric is None:
912
+ cell_metric = lambda x, y: BrailleMetrics.hamming_distance(x, y) / 8.0
913
+
914
+ if len(a) == 0 and len(b) == 0:
915
+ return 0.0
916
+ if len(a) == 0 or len(b) == 0:
917
+ return 1.0
918
+
919
+ # Simple aligned distance for equal lengths
920
+ if len(a) == len(b):
921
+ total = sum(cell_metric(a[i], b[i]) for i in range(len(a)))
922
+ return total / len(a)
923
+
924
+ # DTW for unequal lengths
925
+ n, m = len(a), len(b)
926
+ dtw = np.full((n + 1, m + 1), np.inf)
927
+ dtw[0, 0] = 0
928
+
929
+ for i in range(1, n + 1):
930
+ for j in range(1, m + 1):
931
+ cost = cell_metric(a[i-1], b[j-1])
932
+ dtw[i, j] = cost + min(dtw[i-1, j], dtw[i, j-1], dtw[i-1, j-1])
933
+
934
+ return dtw[n, m] / max(n, m)
935
+
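The DTW branch of `sequence_distance` can be restated as a standalone sketch over raw bytes, with the normalized Hamming cell metric:

```python
import numpy as np

def popcount(x: int) -> int:
    return bin(x).count("1")

def dtw_distance(a: bytes, b: bytes) -> float:
    """Normalized DTW alignment cost between two byte sequences."""
    if not a and not b:
        return 0.0
    if not a or not b:
        return 1.0
    n, m = len(a), len(b)
    dtw = np.full((n + 1, m + 1), np.inf)
    dtw[0, 0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            cost = popcount(a[i - 1] ^ b[j - 1]) / 8.0
            dtw[i, j] = cost + min(dtw[i - 1, j], dtw[i, j - 1], dtw[i - 1, j - 1])
    return dtw[n, m] / max(n, m)

assert dtw_distance(b"AB", b"AB") == 0.0                    # identical sequences
assert dtw_distance(b"AB", b"ABB") < dtw_distance(b"AB", b"XY")  # repetition is cheap
```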
936
+
937
+ # =============================================================================
938
+ # SECTION 7: DEMONSTRATION AND TESTING
939
+ # =============================================================================
940
+
941
+ def demonstrate_lattice_operations():
942
+ """Demonstrate the Braille lattice operations."""
943
+ print("=" * 60)
944
+ print("BRAILLE DOT-LATTICE THEORY DEMONSTRATION")
945
+ print("=" * 60)
946
+
947
+ # Create some cells
948
+ cell_a = BrailleCell.from_dots(1, 2, 4) # ⠋
949
+ cell_b = BrailleCell.from_dots(2, 4, 5) # ⠚
950
+
951
+ print(f"\n1. BASIC LATTICE OPERATIONS")
952
+ print(f" Cell A: {cell_a}")
953
+ print(f" Cell B: {cell_b}")
954
+ print(f" A ∧ B (meet): {cell_a & cell_b}")
955
+ print(f" A ∨ B (join): {cell_a | cell_b}")
956
+ print(f" ¬A (complement): {~cell_a}")
957
+ print(f" A ⊕ B (xor): {cell_a ^ cell_b}")
958
+ print(f" A ≤ B: {cell_a <= cell_b}")
959
+
960
+ # Morphological operations
961
+ print(f"\n2. MORPHOLOGICAL OPERATIONS")
962
+ se = StructuringElement.column_left()
963
+ print(f" Structuring element: {se.name} = dots {se.dots}")
964
+ print(f" Erosion(A, SE): {BrailleMorphology.erode(cell_a, se)}")
965
+ print(f" Dilation(A, SE): {BrailleMorphology.dilate(cell_a, se)}")
966
+ print(f" Opening(A, SE): {BrailleMorphology.opening(cell_a, se)}")
967
+ print(f" Closing(A, SE): {BrailleMorphology.closing(cell_a, se)}")
968
+ print(f" Gradient(A, SE): {BrailleMorphology.gradient(cell_a, se)}")
969
+
970
+ # Sequence operations
971
+ print(f"\n3. SEQUENCE OPERATIONS")
972
+ text = "Hello"
973
+ seq = BrailleSequence.from_bytes(text.encode())
974
+ print(f" Input text: '{text}'")
975
+ print(f" As Braille: {seq.unicode}")
976
+ print(f" Dilated: {seq.apply_morphology(MorphologicalOperator.DILATION, se).unicode}")
977
+ print(f" Eroded: {seq.apply_morphology(MorphologicalOperator.EROSION, se).unicode}")
978
+
979
+ # Distance metrics
980
+ print(f"\n4. LATTICE METRICS")
981
+ print(f" Hamming(A, B): {BrailleMetrics.hamming_distance(cell_a, cell_b)}")
982
+ print(f" Jaccard(A, B): {BrailleMetrics.jaccard_similarity(cell_a, cell_b):.3f}")
983
+ print(f" Lattice(A, B): {BrailleMetrics.lattice_distance(cell_a, cell_b)}")
984
+
985
+ # Modality-invariant reasoning
986
+ print(f"\n5. MODALITY-INVARIANT REASONING LOOP")
987
+ loop = BrailleReasoningLoop()
988
+
989
+ input_text = b"Test input"
990
+ context = b"Context data"
991
+
992
+ output = loop.full_loop(
993
+ input_data=input_text,
994
+ input_modality=Modality.TEXT,
995
+ context_data=context,
996
+ context_modality=Modality.TEXT
997
+ )
998
+
999
+ print(f" Input: {input_text}")
1000
+ print(f" Context: {context}")
1001
+ print(f" Output: {output}")
1002
+ print(f" (Output differs due to morphological transformations)")
1003
+
1004
+ print("\n" + "=" * 60)
1005
+ print("THEORETICAL FRAMEWORK COMPLETE")
1006
+ print("=" * 60)
1007
+
1008
+
1009
+ if __name__ == "__main__":
1010
+ demonstrate_lattice_operations()
config.json ADDED
@@ -0,0 +1,18 @@
1
+ {
2
+ "vocab_size": 32000,
3
+ "hidden_size": 256,
4
+ "num_layers": 4,
5
+ "num_heads": 4,
6
+ "intermediate_size": 1024,
7
+ "max_position_embeddings": 512,
8
+ "dropout": 0.1,
9
+ "use_lattice_attention": true,
10
+ "lattice_attention_weight": 0.4,
11
+ "use_morphological_regularization": true,
12
+ "morphological_weight": 0.005,
13
+ "use_lattice_embeddings": true,
14
+ "structuring_element": "six_dot",
15
+ "embedding_dropout": 0.15,
16
+ "modality_embedding_dim": 32,
17
+ "num_modalities": 5
18
+ }
final_eval.json ADDED
@@ -0,0 +1,5 @@
1
+ {
2
+ "lattice_coherence": 0.7428203609814639,
3
+ "morphological_stability": 0.4530333221703768,
4
+ "haptic_score": 0.5979268415759204
5
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7721b549cb7724d1f110f0221b98044a1472b95ab01ac4a586dd2ae2cbfa0704
3
+ size 47273908
tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec5b8b6fbd8985a97c74d377a83f58ff59f3860d02a343eb15146da467da40ae
3
+ size 1155082
train_lattice_v6.py ADDED
@@ -0,0 +1,990 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Train braille256-v6: Lattice-Aware Multimodal Braille Model
4
+
5
+ This is the first LLM with explicit dot-lattice structure in its architecture:
6
+ 1. Lattice-aware attention (Hamming-based similarity)
7
+ 2. Morphological regularization (erosion/dilation as inductive bias)
8
+ 3. Lattice-structured embeddings (respecting Boolean algebra)
9
+ 4. Modality-invariant reasoning loops
10
+
11
+ Building on v5's multimodal foundation, v6 integrates the formal theory
12
+ from braille_lattice_theory.py into the training pipeline.
13
+
14
+ Author: Ryan Barrett & Cascade
15
+ Date: December 2024
16
+ """
17
+
18
+ import os
19
+ import sys
20
+ import json
21
+ import math
22
+ import logging
23
+ import argparse
24
+ from dataclasses import dataclass, field
25
+ from typing import Optional, List, Tuple, Dict
26
+ from enum import Enum
27
+
28
+ import torch
29
+ import torch.nn as nn
30
+ import torch.nn.functional as F
31
+ from torch.utils.data import Dataset, DataLoader
32
+ from tqdm import tqdm
33
+ import numpy as np
34
+
35
+ import sentencepiece as spm
36
+
37
+ # Import lattice theory
38
+ from braille_lattice_theory import (
39
+ BrailleCell, BrailleMorphology, BrailleSequence,
40
+ StructuringElement, MorphologicalOperator, BrailleMetrics,
41
+ BRAILLE_BASE, BRAILLE_MAX
42
+ )
43
+
44
+ logging.basicConfig(level=logging.INFO)
45
+ logger = logging.getLogger(__name__)
46
+
47
+ # =============================================================================
48
+ # Configuration
49
+ # =============================================================================
50
+
51
+ @dataclass
52
+ class LatticeConfig:
53
+ """Configuration for lattice-aware model."""
54
+ # Model architecture
55
+ vocab_size: int = 32000
56
+ hidden_size: int = 256
57
+ num_layers: int = 4
58
+ num_heads: int = 4
59
+ intermediate_size: int = 1024
60
+ max_position_embeddings: int = 512
61
+ dropout: float = 0.1
62
+
63
+ # Lattice-specific settings
64
+ use_lattice_attention: bool = True
65
+ lattice_attention_weight: float = 0.4 # Blend with standard attention (increased)
66
+ use_morphological_regularization: bool = True
67
+ morphological_weight: float = 0.05 # Regularization strength (increased 5x)
68
+ use_lattice_embeddings: bool = True
69
+ structuring_element: str = "six_dot" # Which SE to use
70
+ embedding_dropout: float = 0.15 # Dropout on embeddings to prevent overfitting
71
+
72
+ # Modality settings
73
+ modality_embedding_dim: int = 32
74
+ num_modalities: int = 5 # TEXT, IMAGE, AUDIO, BINARY, VIDEO
75
+
76
+ def to_dict(self):
77
+ return {k: v for k, v in self.__dict__.items()}
78
+
79
+ @classmethod
80
+ def from_dict(cls, d):
81
+ return cls(**{k: v for k, v in d.items() if k in cls.__dataclass_fields__})
82
+
83
+
84
+ # =============================================================================
85
+ # Lattice-Aware Attention
86
+ # =============================================================================
87
+
88
+ class LatticeAttention(nn.Module):
89
+ """
90
+ Attention mechanism that incorporates Braille lattice structure.
91
+
92
+ Key innovation: Combines standard softmax attention with lattice-based
93
+ similarity computed via Hamming distance on the underlying Braille cells.
94
+
95
+ For tokens that map to Braille cells, we compute:
96
+ lattice_sim(a, b) = 8 - popcount(a ⊕ b)
97
+
98
+ This is then blended with standard QK^T attention.
99
+ """
100
+
101
+ def __init__(self, config: LatticeConfig):
102
+ super().__init__()
103
+ self.config = config
104
+ self.num_heads = config.num_heads
105
+ self.head_dim = config.hidden_size // config.num_heads
106
+ self.lattice_weight = config.lattice_attention_weight
107
+
108
+ self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)
109
+ self.k_proj = nn.Linear(config.hidden_size, config.hidden_size)
110
+ self.v_proj = nn.Linear(config.hidden_size, config.hidden_size)
111
+ self.out_proj = nn.Linear(config.hidden_size, config.hidden_size)
112
+ self.dropout = nn.Dropout(config.dropout)
113
+
114
+ # Learnable lattice attention temperature
115
+ self.lattice_temperature = nn.Parameter(torch.ones(1))
116
+
117
+ # Precompute Hamming distance matrix for all 256 Braille cells
118
+ self._precompute_hamming_matrix()
119
+
120
+ def _precompute_hamming_matrix(self):
121
+ """Precompute pairwise Hamming distances for efficiency."""
122
+ hamming = torch.zeros(256, 256)
123
+ for i in range(256):
124
+ for j in range(256):
125
+ # XOR and count bits
126
+ xor = i ^ j
127
+ hamming[i, j] = bin(xor).count('1')
128
+
129
+ # Convert to similarity: 8 - hamming (range [0, 8])
130
+ self.register_buffer('lattice_similarity', 8 - hamming)
131
+
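The double loop above is O(256²) Python calls; for reference, the same similarity table can be built with NumPy broadcasting (a sketch, not the module's code):

```python
import numpy as np

codes = np.arange(256)
xor = (codes[:, None] ^ codes[None, :]).astype(np.uint8)        # (256, 256) XOR table
popcount = np.unpackbits(xor[..., None], axis=-1).sum(-1)        # per-entry bit count
similarity = 8 - popcount                                        # lattice similarity in [0, 8]

assert similarity[0, 0] == 8            # identical cells
assert similarity[0, 255] == 0          # fully complementary cells
assert (similarity == similarity.T).all()  # Hamming distance is symmetric
```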
132
+ def _get_braille_values(self, token_ids: torch.Tensor, sp_model) -> torch.Tensor:
133
+ """
134
+ Extract Braille cell values from token IDs (vectorized).
135
+
136
+ For tokens that decode to Braille characters, return their cell value.
137
+ For others, return -1 (will be masked in attention).
138
+ """
139
+ # Vectorized: tokens < 256 are treated as Braille cells
140
+ braille_values = torch.where(
141
+ token_ids < 256,
142
+ token_ids,
143
+ torch.full_like(token_ids, -1)
144
+ )
145
+ return braille_values
146
+
147
+ def compute_lattice_attention(self, braille_values: torch.Tensor) -> torch.Tensor:
148
+ """
149
+ Compute attention scores based on lattice similarity (fully vectorized).
150
+
151
+ Returns attention logits of shape (B, T, T).
152
+ """
153
+ B, T = braille_values.shape
154
+
155
+ # Mask for valid Braille values
156
+ valid_mask = (braille_values >= 0).float()
157
+
158
+ # Clamp to valid range for indexing
159
+ safe_values = braille_values.clamp(0, 255).long()
160
+
161
+ # Vectorized lookup: use advanced indexing
162
+ # Flatten batch for indexing, then reshape
163
+ flat_values = safe_values.view(-1) # (B*T,)
164
+
165
+ # Get similarity rows for each token
166
+ # lattice_similarity is (256, 256), we want (B*T, 256)
167
+ sim_rows = self.lattice_similarity[flat_values] # (B*T, 256)
168
+
169
+ # Now index columns: for each pair (i, j), get sim_rows[i, safe_values[j]]
170
+ # Reshape to (B, T, 256) then gather along last dim
171
+ sim_rows = sim_rows.view(B, T, 256) # (B, T, 256)
172
+
173
+ # Expand safe_values for gathering: (B, T) -> (B, T, T)
174
+ indices = safe_values.unsqueeze(1).expand(B, T, T) # (B, T, T)
175
+
176
+ # Gather: for each (b, i, j), get sim_rows[b, i, safe_values[b, j]]
177
+ lattice_attn = torch.gather(sim_rows, 2, indices)
178
+
179
+ # Apply temperature
180
+ lattice_attn = lattice_attn / (self.lattice_temperature + 1e-6)
181
+
182
+ # Mask invalid positions
183
+ valid_2d = valid_mask.unsqueeze(2) * valid_mask.unsqueeze(1) # (B, T, T)
184
+ lattice_attn = lattice_attn * valid_2d
185
+
186
+ return lattice_attn
187
+
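The intended lookup is `sim[values[b, i], values[b, j]]` for cell codes `values` of shape (B, T). That is easiest to check in NumPy, where advanced indexing gives the (B, T, T) result directly; a standalone sketch:

```python
import numpy as np

rng = np.random.default_rng(0)
sim = rng.integers(0, 9, (256, 256))          # stand-in similarity table
values = rng.integers(0, 256, (2, 5))         # (B, T) cell codes

# fast[b, i, j] = sim[values[b, i], values[b, j]] via broadcast indexing
fast = sim[values[:, :, None], values[:, None, :]]

# reference: explicit double lookup per batch element
slow = np.stack([sim[v][:, v] for v in values])
assert np.array_equal(fast, slow)
```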
188
+ def forward(self, x: torch.Tensor, mask: torch.Tensor = None,
189
+ token_ids: torch.Tensor = None) -> torch.Tensor:
190
+ B, T, C = x.shape
191
+
192
+ # Standard attention
193
+ q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
194
+ k = self.k_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
195
+ v = self.v_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
196
+
197
+ # Standard QK^T attention
198
+ standard_attn = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
199
+
200
+ # Lattice attention (if enabled and token_ids provided)
201
+ if self.config.use_lattice_attention and token_ids is not None:
202
+ # Compute lattice-based attention
203
+ braille_values = self._get_braille_values(token_ids, None)
204
+ lattice_attn = self.compute_lattice_attention(braille_values)
205
+
206
+ # Expand for heads: (B, T, T) -> (B, num_heads, T, T)
207
+ lattice_attn = lattice_attn.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
208
+
209
+ # Blend standard and lattice attention
210
+ attn = (1 - self.lattice_weight) * standard_attn + self.lattice_weight * lattice_attn
211
+ else:
212
+ attn = standard_attn
213
+
214
+ # Apply causal mask
215
+ if mask is not None:
216
+ attn = attn.masked_fill(mask == 0, float('-inf'))
217
+
218
+ attn = F.softmax(attn, dim=-1)
219
+ attn = self.dropout(attn)
220
+
221
+ out = (attn @ v).transpose(1, 2).contiguous().view(B, T, C)
222
+ return self.out_proj(out)
223
+
224
+
225
+ # =============================================================================
226
+ # Lattice-Aware Embeddings
227
+ # =============================================================================
228
+
229
+ class LatticeEmbedding(nn.Module):
230
+ """
231
+ Token embeddings that respect Braille lattice structure.
232
+
233
+ Key insight: Initialize embeddings so that similar Braille cells
234
+ (low Hamming distance) have similar embeddings.
235
+
236
+ This provides an inductive bias that helps the model learn
237
+ patterns in the lattice structure.
238
+ """
239
+
240
+ def __init__(self, config: LatticeConfig):
241
+ super().__init__()
242
+ self.config = config
243
+
244
+ # Standard embedding
245
+ self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
246
+
247
+ # Lattice structure embedding (for first 256 tokens = Braille cells)
248
+ self.lattice_embedding = nn.Embedding(256, config.hidden_size)
249
+
250
+ # Initialize lattice embeddings with structure
251
+ self._init_lattice_structure()
252
+
253
+ # Learnable blend weight
254
+ self.lattice_blend = nn.Parameter(torch.tensor(0.1))
255
+
256
+ def _init_lattice_structure(self):
257
+ """Initialize embeddings to reflect lattice structure."""
258
+ with torch.no_grad():
259
+ # Each Braille cell is an 8-bit vector
260
+ # Map each bit to a learned direction in embedding space
261
+
262
+ # Create 8 basis vectors (one per dot)
263
+ basis = torch.randn(8, self.config.hidden_size) * 0.1
264
+
265
+ for i in range(256):
266
+ # Get the bits of this cell
267
+ bits = [(i >> b) & 1 for b in range(8)]
268
+
269
+ # Embedding is sum of basis vectors for raised dots
270
+ emb = torch.zeros(self.config.hidden_size)
271
+ for b, bit in enumerate(bits):
272
+ if bit:
273
+ emb += basis[b]
274
+
275
+ self.lattice_embedding.weight[i] = emb
276
+
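The initialization makes each embedding an additive combination of per-dot basis vectors, so cells at Hamming distance 1 differ by exactly one basis direction. A NumPy sketch with a hypothetical 16-dim space:

```python
import numpy as np

rng = np.random.default_rng(0)
basis = rng.normal(scale=0.1, size=(8, 16))   # one learned direction per dot

def embed(cell: int) -> np.ndarray:
    """Sum of basis vectors for the raised dots of an 8-bit cell."""
    return sum(basis[b] for b in range(8) if (cell >> b) & 1)

a, b = 0b00001011, 0b00001010                 # differ in dot 1 only
# the embedding difference is exactly the basis vector of the differing dot
assert np.allclose(embed(a) - embed(b), basis[0])
```

This is the inductive bias: Hamming-close cells start close in embedding space, and the distance decomposes over differing dots.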
277
+ def forward(self, token_ids: torch.Tensor, training: bool = True) -> torch.Tensor:
278
+ # Standard embedding
279
+ std_emb = self.embedding(token_ids)
280
+
281
+ if self.config.use_lattice_embeddings:
282
+ # For tokens < 256, blend with lattice embedding
283
+ mask = (token_ids < 256).float().unsqueeze(-1)
284
+ safe_ids = token_ids.clamp(0, 255)
285
+ lat_emb = self.lattice_embedding(safe_ids)
286
+
287
+ # Blend: standard + lattice_blend * lattice (for Braille tokens only)
288
+ std_emb = std_emb + mask * self.lattice_blend * lat_emb
289
+
290
+ # Apply embedding dropout during training to prevent overfitting
291
+ if training and self.config.embedding_dropout > 0:
292
+ std_emb = F.dropout(std_emb, p=self.config.embedding_dropout, training=self.training)
293
+
294
+ return std_emb
295
+
296
+
297
+ # =============================================================================
298
+ # Morphological Regularization
299
+ # =============================================================================
300
+
301
+ class MorphologicalRegularizer(nn.Module):
302
+ """
303
+ Regularization based on morphological operations.
304
+
305
+ Encourages the model to learn representations that are
306
+ consistent under morphological transformations (erosion, dilation).
307
+
308
+ Loss = ||f(erode(x)) - erode(f(x))||² + ||f(dilate(x)) - dilate(f(x))||²
309
+
310
+ This is a form of equivariance regularization.
311
+ """
312
+
313
+ def __init__(self, config: LatticeConfig):
314
+ super().__init__()
315
+ self.config = config
316
+
317
+ # Get structuring element
318
+ se_map = {
319
+ 'six_dot': StructuringElement.six_dot(),
320
+ 'column_left': StructuringElement.column_left(),
321
+ 'column_right': StructuringElement.column_right(),
322
+ 'full': StructuringElement.full(),
323
+ }
324
+ self.se = se_map.get(config.structuring_element, StructuringElement.six_dot())
325
+ self.se_value = self.se.cell.value
326
+
327
+ def apply_morphology(self, token_ids: torch.Tensor,
328
+ op: str = 'erode') -> torch.Tensor:
329
+ """Apply morphological operation to token IDs."""
330
+ result = token_ids.clone()
331
+ mask = token_ids < 256 # Only apply to Braille tokens
332
+
333
+ if op == 'erode':
334
+ # Erosion: AND with structuring element
335
+ result[mask] = token_ids[mask] & self.se_value
336
+ elif op == 'dilate':
337
+ # Dilation: OR with structuring element
338
+ result[mask] = (token_ids[mask] | self.se_value) & 0xFF
339
+
340
+ return result
341
+
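Restated on a single int cell, the per-token operations are just meet/join with the structuring element. A sketch assuming `six_dot` corresponds to dots 1-6, i.e. `0x3F` (an assumption about `StructuringElement.six_dot()`):

```python
SE = 0b00111111                      # assumed six-dot SE: dots 1-6 raised

def erode(c: int) -> int:
    return c & SE                    # meet with the SE: clears dots 7, 8

def dilate(c: int) -> int:
    return (c | SE) & 0xFF           # join with the SE: raises dots 1-6

cell = 0b11001011                    # dots 1, 2, 4, 7, 8
assert erode(cell) == 0b00001011     # dots 7, 8 removed
assert dilate(cell) == 0b11111111    # dots 3, 5, 6 added
# anti-extensive / extensive: erode(x) <= x <= dilate(x) in the lattice order
assert erode(cell) & cell == erode(cell) and cell & dilate(cell) == cell
```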
342
+ def compute_loss(self, embeddings: torch.Tensor,
343
+ token_ids: torch.Tensor,
344
+ embedding_layer: nn.Module) -> torch.Tensor:
345
+ """
346
+ Compute morphological equivariance loss.
347
+
348
+ We want: embed(morph(x)) ≈ morph_embed(embed(x))
349
+
350
+ Since we can't directly apply morphology to embeddings,
351
+ we use a proxy: embeddings of morphologically related tokens
352
+ should be similar.
353
+ """
354
+ if not self.config.use_morphological_regularization:
355
+ return torch.tensor(0.0, device=embeddings.device)
356
+
357
+ # Get eroded and dilated token IDs
358
+ eroded_ids = self.apply_morphology(token_ids, 'erode')
359
+ dilated_ids = self.apply_morphology(token_ids, 'dilate')
360
+
361
+ # Get embeddings
362
+ eroded_emb = embedding_layer(eroded_ids)
363
+ dilated_emb = embedding_layer(dilated_ids)
364
+
365
+ # Regularization: encourage margin between eroded and dilated distances
366
+ # Always-on version: penalize deviation from ideal ordering
367
+ # Ideal: dist_to_eroded < dist_to_original < dist_to_dilated
368
+
369
+ dist_to_eroded = F.mse_loss(embeddings, eroded_emb)
370
+ dist_to_dilated = F.mse_loss(embeddings, dilated_emb)
371
+
372
+ # Always-on: encourage margin (eroded should be closer than dilated)
373
+ # Use squared difference for smooth gradient
374
+ margin_loss = (dist_to_eroded - dist_to_dilated + 0.1).pow(2)
375
+
376
+ # Also add coherence loss: embeddings should be close to their morphological neighbors
377
+ coherence_loss = dist_to_eroded + dist_to_dilated
378
+
379
+ loss = margin_loss + 0.1 * coherence_loss
380
+
381
+ return loss * self.config.morphological_weight
382
+
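In scalar form, the loss combines a squared margin term with a coherence term; a pure-Python restatement for intuition (the helper name is illustrative, not from the codebase):

```python
def morph_reg_loss(d_eroded: float, d_dilated: float, margin: float = 0.1) -> float:
    """Squared-margin + coherence form of the morphological regularizer."""
    margin_loss = (d_eroded - d_dilated + margin) ** 2   # prefers d_eroded ~ d_dilated - margin
    coherence = d_eroded + d_dilated                     # keep both neighbors close overall
    return margin_loss + 0.1 * coherence

# the ordering "eroded neighbor closer than dilated" is rewarded
assert morph_reg_loss(0.2, 0.3) < morph_reg_loss(0.3, 0.2)
```

Unlike a hinge margin, the squared form also penalizes the eroded neighbor being *too much* closer than the dilated one, keeping the gradient smooth and always on.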
383
+
384
+ # =============================================================================
385
+ # Modality Embedding
386
+ # =============================================================================
387
+
388
+ class ModalityEmbedding(nn.Module):
389
+ """
390
+ Embeddings for different modalities.
391
+
392
+ Adds a learned embedding based on the detected modality
393
+ of each token sequence.
394
+ """
395
+
396
+ # Modality header tokens (from v5)
397
+ MODALITY_HEADERS = {
398
+ 'TEXT': (0xFF, 0x01), # ⣿⠁
399
+ 'IMAGE': (0xFF, 0x03), # ⣿⠃
400
+ 'AUDIO': (0xFF, 0x07), # ⣿⠇
401
+ 'BINARY': (0xFF, 0x0F), # ⣿⠏
402
+ 'VIDEO': (0xFF, 0x17), # ⣿⠗
403
+ }
404
+
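The header tuples map directly to the Unicode comments beside them: the Braille Patterns block starts at U+2800, and a cell's 8-bit value is its offset into the block. A quick standalone check:

```python
def to_braille(cells) -> str:
    """Render 8-bit cell values as Unicode Braille (block starts at U+2800)."""
    return "".join(chr(0x2800 + c) for c in cells)

assert to_braille((0xFF, 0x01)) == "⣿⠁"   # TEXT header
assert to_braille((0xFF, 0x03)) == "⣿⠃"   # IMAGE header
```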
405
+ def __init__(self, config: LatticeConfig):
406
+ super().__init__()
407
+ self.embedding = nn.Embedding(config.num_modalities, config.hidden_size)
408
+
409
+ def detect_modality(self, token_ids: torch.Tensor) -> torch.Tensor:
410
+ """Detect modality from token sequence (simplified)."""
411
+ # For now, return 0 (TEXT) for all - would need tokenizer to decode
412
+ return torch.zeros(token_ids.shape[0], dtype=torch.long, device=token_ids.device)
413
+
414
+ def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
415
+ modality_ids = self.detect_modality(token_ids)
416
+ return self.embedding(modality_ids).unsqueeze(1) # (B, 1, H)
417
+
418
+
419
+ # =============================================================================
420
+ # Full Model
421
+ # =============================================================================
422
+
423
+ class FeedForward(nn.Module):
424
+ def __init__(self, config: LatticeConfig):
425
+ super().__init__()
426
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
427
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
428
+ self.dropout = nn.Dropout(config.dropout)
429
+
430
+ def forward(self, x):
431
+ return self.fc2(self.dropout(F.gelu(self.fc1(x))))
432
+
433
+
434
+ class LatticeTransformerBlock(nn.Module):
435
+ """Transformer block with lattice-aware attention."""
436
+
437
+ def __init__(self, config: LatticeConfig):
438
+ super().__init__()
439
+ self.ln1 = nn.LayerNorm(config.hidden_size)
440
+ self.attn = LatticeAttention(config)
441
+ self.ln2 = nn.LayerNorm(config.hidden_size)
442
+ self.ff = FeedForward(config)
443
+ self.dropout = nn.Dropout(config.dropout)
444
+
445
+ def forward(self, x: torch.Tensor, mask: torch.Tensor = None,
446
+ token_ids: torch.Tensor = None) -> torch.Tensor:
447
+ x = x + self.dropout(self.attn(self.ln1(x), mask, token_ids))
448
+ x = x + self.dropout(self.ff(self.ln2(x)))
449
+ return x
450
+
451
+
452
class Braille256LatticeModel(nn.Module):
    """
    braille256-v6: Lattice-Aware Multimodal Braille Model

    Key innovations over v5:
    1. LatticeAttention: Hamming-based similarity in attention
    2. LatticeEmbedding: Structure-aware token embeddings
    3. MorphologicalRegularizer: Equivariance regularization
    4. ModalityEmbedding: Explicit modality awareness
    """

    def __init__(self, config: LatticeConfig):
        super().__init__()
        self.config = config

        # Lattice-aware embeddings
        self.token_embedding = LatticeEmbedding(config)
        self.position_embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.modality_embedding = ModalityEmbedding(config)
        self.dropout = nn.Dropout(config.dropout)

        # Transformer layers with lattice attention
        self.layers = nn.ModuleList([
            LatticeTransformerBlock(config) for _ in range(config.num_layers)
        ])

        self.ln_f = nn.LayerNorm(config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Morphological regularizer
        self.morph_regularizer = MorphologicalRegularizer(config)

        # Weight tying
        self.lm_head.weight = self.token_embedding.embedding.weight

        self.apply(self._init_weights)

        # Log architecture
        total_params = sum(p.numel() for p in self.parameters())
        logger.info(f"Braille256-v6 Lattice Model: {total_params:,} parameters")
        logger.info(f"  Lattice attention: {config.use_lattice_attention}")
        logger.info(f"  Lattice embeddings: {config.use_lattice_embeddings}")
        logger.info(f"  Morphological regularization: {config.use_morphological_regularization}")

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, input_ids: torch.Tensor,
                labels: torch.Tensor = None) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        B, T = input_ids.shape

        # Embeddings
        positions = torch.arange(T, device=input_ids.device).unsqueeze(0)
        tok_emb = self.token_embedding(input_ids, training=self.training)
        pos_emb = self.position_embedding(positions)
        mod_emb = self.modality_embedding(input_ids)

        x = tok_emb + pos_emb + mod_emb
        x = self.dropout(x)

        # Causal mask
        mask = torch.tril(torch.ones(T, T, device=input_ids.device)).unsqueeze(0).unsqueeze(0)

        # Transformer layers
        for layer in self.layers:
            x = layer(x, mask, input_ids)

        x = self.ln_f(x)
        logits = self.lm_head(x)

        # Compute losses
        lm_loss = None
        morph_loss = None

        if labels is not None:
            lm_loss = F.cross_entropy(
                logits.view(-1, self.config.vocab_size),
                labels.view(-1),
                ignore_index=-100
            )

            # Morphological regularization
            morph_loss = self.morph_regularizer.compute_loss(
                tok_emb, input_ids, self.token_embedding
            )

        return logits, lm_loss, morph_loss

    def generate(self, input_ids: torch.Tensor, max_length: int = 100,
                 temperature: float = 1.0, top_k: int = 50) -> torch.Tensor:
        self.eval()
        with torch.no_grad():
            for _ in range(max_length):
                if input_ids.shape[1] >= self.config.max_position_embeddings:
                    break

                logits, _, _ = self(input_ids)
                logits = logits[:, -1, :] / temperature

                if top_k > 0:
                    v, _ = torch.topk(logits, top_k)
                    logits[logits < v[:, [-1]]] = float('-inf')

                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                input_ids = torch.cat([input_ids, next_token], dim=1)

        return input_ids


# =============================================================================
# Dataset (same as v5)
# =============================================================================


class MultimodalBrailleDataset(Dataset):
    def __init__(self, corpus_path: str, tokenizer_path: str,
                 max_length: int = 512, max_tokens: int = 10_000_000):
        self.max_length = max_length

        self.sp = spm.SentencePieceProcessor()
        self.sp.load(tokenizer_path)

        logger.info(f"Loading corpus from {corpus_path}...")
        with open(corpus_path, 'r', encoding='utf-8') as f:
            text = f.read()

        if len(text) > max_tokens * 3:
            logger.info(f"Limiting corpus from {len(text):,} to ~{max_tokens:,} tokens worth")
            text = text[:max_tokens * 3]

        logger.info(f"Tokenizing {len(text):,} characters...")
        self.tokens = self.sp.encode(text)
        if len(self.tokens) > max_tokens:
            self.tokens = self.tokens[:max_tokens]
        logger.info(f"Got {len(self.tokens):,} tokens")

        self.examples = []
        stride = max_length // 2
        for i in range(0, len(self.tokens) - max_length, stride):
            self.examples.append(i)

        logger.info(f"Created {len(self.examples):,} training examples")

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        start = self.examples[idx]
        tokens = self.tokens[start:start + self.max_length + 1]

        input_ids = torch.tensor(tokens[:-1], dtype=torch.long)
        labels = torch.tensor(tokens[1:], dtype=torch.long)

        return input_ids, labels

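The dataset above slices the token stream into overlapping windows with a stride of half the context length. A minimal standalone sketch of the same windowing arithmetic, with an illustrative token count (10,000 tokens is an assumption for the example, not the actual corpus size):

```python
# Illustrative numbers, not from the real corpus: 10_000 tokens, 256-token windows.
num_tokens = 10_000
max_length = 256
stride = max_length // 2  # 50% overlap between consecutive windows

# Same loop as MultimodalBrailleDataset.__init__: one example per window start.
starts = list(range(0, num_tokens - max_length, stride))

print(len(starts))           # number of training examples
print(starts[0], starts[1])  # consecutive windows overlap by max_length - stride tokens
```

Each window of `max_length + 1` tokens then yields shifted `input_ids`/`labels` pairs in `__getitem__`.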
# =============================================================================
# Haptic Evaluation
# =============================================================================


class HapticEvaluator:
    """
    Evaluate model outputs for haptic/tactile quality.

    Metrics:
    1. Lattice coherence: How well outputs respect lattice structure
    2. Morphological stability: Consistency under erosion/dilation
    3. Modality preservation: Cross-modal semantic consistency
    """

    def __init__(self, config: LatticeConfig):
        self.config = config
        self.se = StructuringElement.six_dot()

    def lattice_coherence(self, token_ids: torch.Tensor) -> float:
        """
        Measure how well token sequences respect lattice structure.

        High coherence = adjacent tokens have low Hamming distance.
        """
        if token_ids.shape[-1] < 2:
            return 1.0

        total_dist = 0
        count = 0

        for i in range(token_ids.shape[-1] - 1):
            t1 = token_ids[..., i].item() if token_ids[..., i].numel() == 1 else token_ids[0, i].item()
            t2 = token_ids[..., i+1].item() if token_ids[..., i+1].numel() == 1 else token_ids[0, i+1].item()

            if t1 < 256 and t2 < 256:
                # Hamming distance
                dist = bin(t1 ^ t2).count('1')
                total_dist += dist
                count += 1

        if count == 0:
            return 1.0

        # Normalize: 0 = max coherence, 8 = min coherence
        avg_dist = total_dist / count
        return 1.0 - (avg_dist / 8.0)

    def morphological_stability(self, token_ids: torch.Tensor) -> float:
        """
        Measure stability under morphological operations.

        High stability = erosion and dilation don't change meaning drastically.
        """
        if token_ids.numel() == 0:
            return 1.0

        original = token_ids.clone()

        # Apply erosion
        eroded = original.clone()
        mask = original < 256
        eroded[mask] = original[mask] & self.se.cell.value

        # Apply dilation
        dilated = original.clone()
        dilated[mask] = (original[mask] | self.se.cell.value) & 0xFF

        # Measure how much changed
        erode_change = (original[mask] != eroded[mask]).float().mean().item() if mask.any() else 0
        dilate_change = (original[mask] != dilated[mask]).float().mean().item() if mask.any() else 0

        # Stability = 1 - average change
        return 1.0 - (erode_change + dilate_change) / 2

    def evaluate(self, model: nn.Module, dataloader: DataLoader,
                 device: torch.device, num_samples: int = 100) -> Dict[str, float]:
        """Run full haptic evaluation."""
        model.eval()

        coherence_scores = []
        stability_scores = []

        with torch.no_grad():
            for i, (input_ids, _) in enumerate(dataloader):
                if i >= num_samples:
                    break

                input_ids = input_ids.to(device)

                # Generate some tokens
                generated = model.generate(input_ids[:, :10], max_length=50)

                coherence_scores.append(self.lattice_coherence(generated))
                stability_scores.append(self.morphological_stability(generated))

        # Cast to plain floats so the results stay JSON-serializable downstream
        return {
            'lattice_coherence': float(np.mean(coherence_scores)),
            'morphological_stability': float(np.mean(stability_scores)),
            'haptic_score': float(np.mean(coherence_scores)) * 0.5 + float(np.mean(stability_scores)) * 0.5
        }


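The coherence metric above maps the mean adjacent-pair Hamming distance into [0, 1], where 1.0 means neighboring cells differ by zero dots and 0.0 means all eight dots flip at every step. A standalone sketch of the same computation on plain integers (toy token values, assuming 8-dot cells are the IDs below 256):

```python
def lattice_coherence(tokens):
    """Mean adjacent Hamming distance over 8-dot cells (IDs < 256), mapped to [0, 1]."""
    dists = [bin(a ^ b).count('1') for a, b in zip(tokens, tokens[1:])
             if a < 256 and b < 256]
    if not dists:
        return 1.0  # no comparable cell pairs, same fallback as the evaluator
    return 1.0 - (sum(dists) / len(dists)) / 8.0

print(lattice_coherence([0b000001, 0b000011, 0b000111]))  # → 0.875 (one dot flips per step)
print(lattice_coherence([0x00, 0xFF]))                    # → 0.0 (all eight dots flip)
```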
# =============================================================================
# Training
# =============================================================================


def train(
    corpus_path: str,
    tokenizer_path: str,
    output_dir: str,
    max_steps: int = 10000,
    batch_size: int = 16,
    learning_rate: float = 3e-4,
    gradient_accumulation: int = 2,
    save_steps: int = 1000,
    eval_steps: int = 500,
    use_lattice_attention: bool = True,
    use_lattice_embeddings: bool = True,
    use_morphological_regularization: bool = True,
):
    """Train the lattice-aware model."""

    os.makedirs(output_dir, exist_ok=True)

    # Device
    if torch.backends.mps.is_available():
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    logger.info(f"Using device: {device}")

    # Load tokenizer
    sp = spm.SentencePieceProcessor()
    sp.load(tokenizer_path)
    vocab_size = sp.get_piece_size()

    # Config
    config = LatticeConfig(
        vocab_size=vocab_size,
        use_lattice_attention=use_lattice_attention,
        use_lattice_embeddings=use_lattice_embeddings,
        use_morphological_regularization=use_morphological_regularization,
    )

    # Save config
    with open(os.path.join(output_dir, "config.json"), 'w') as f:
        json.dump(config.to_dict(), f, indent=2)

    # Model
    model = Braille256LatticeModel(config)
    model.to(device)

    # Dataset
    dataset = MultimodalBrailleDataset(
        corpus_path, tokenizer_path,
        max_length=256, max_tokens=2_000_000
    )
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    # Evaluator
    evaluator = HapticEvaluator(config)

    # Optimizer with increased weight decay to preserve lattice structure
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.05)

    # LR scheduler: linear warmup, then cosine decay
    def lr_lambda(step):
        warmup_steps = 500
        if step < warmup_steps:
            return step / warmup_steps
        decay_steps = max_steps - warmup_steps
        progress = (step - warmup_steps) / decay_steps
        return 0.5 * (1 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # Mixed precision - only for CUDA; MPS AMP has issues with custom ops
    use_amp = device.type == 'cuda'
    scaler = torch.amp.GradScaler('cuda') if use_amp else None

    # torch.compile disabled for MPS - causes slow compilation overhead.
    # Enable only for CUDA.
    compiled = False
    if device.type == 'cuda':
        try:
            model = torch.compile(model, mode="reduce-overhead")
            compiled = True
        except Exception as e:
            logger.warning(f"torch.compile not available: {e}")

    # Training loop
    print("\n" + "=" * 70)
    print("⣿ braille256-v6: Lattice-Aware Training ⣿")
    print("=" * 70)
    print(f"  Max steps: {max_steps}")
    print(f"  Batch size: {batch_size} x {gradient_accumulation} = {batch_size * gradient_accumulation}")
    print(f"  Learning rate: {learning_rate}")
    print(f"  Lattice attention: {use_lattice_attention}")
    print(f"  Lattice embeddings: {use_lattice_embeddings}")
    print(f"  Morphological regularization: {use_morphological_regularization}")
    print(f"  Mixed precision (AMP): {use_amp}")
    print(f"  torch.compile: {compiled}")
    print(f"  Output: {output_dir}")
    print("=" * 70 + "\n")

    model.train()
    step = 0
    data_iter = iter(dataloader)
    best_haptic_score = 0

    pbar = tqdm(total=max_steps, desc="Training")

    training_log = []

    while step < max_steps:
        optimizer.zero_grad()
        total_lm_loss = 0
        total_morph_loss = 0

        # Staged morphological regularization: high early, decay later.
        # This locks in geometry early while allowing expressivity later.
        if step < 1500:
            morph_weight_scale = 1.0  # Full strength: 0.05
        elif step < 4000:
            morph_weight_scale = 0.4  # Medium: 0.02
        else:
            morph_weight_scale = 0.1  # Low: 0.005

        # Update the model's morph weight dynamically
        model.morph_regularizer.config.morphological_weight = 0.05 * morph_weight_scale

        for _ in range(gradient_accumulation):
            try:
                input_ids, labels = next(data_iter)
            except StopIteration:
                data_iter = iter(dataloader)
                input_ids, labels = next(data_iter)

            input_ids = input_ids.to(device)
            labels = labels.to(device)

            # Mixed precision forward pass
            if use_amp:
                with torch.amp.autocast(device.type):
                    _, lm_loss, morph_loss = model(input_ids, labels)
                    loss = lm_loss + morph_loss
                loss = loss / gradient_accumulation
                scaler.scale(loss).backward()
            else:
                _, lm_loss, morph_loss = model(input_ids, labels)
                loss = lm_loss + morph_loss
                loss = loss / gradient_accumulation
                loss.backward()

            total_lm_loss += lm_loss.item() / gradient_accumulation
            total_morph_loss += (morph_loss.item() / gradient_accumulation) if morph_loss is not None else 0

        if use_amp:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
        scheduler.step()

        step += 1

        pbar.set_postfix(
            lm_loss=f"{total_lm_loss:.4f}",
            morph=f"{total_morph_loss:.4f}",
            lr=f"{scheduler.get_last_lr()[0]:.2e}"
        )
        pbar.update(1)

        # Log
        if step % 100 == 0:
            training_log.append({
                'step': step,
                'lm_loss': total_lm_loss,
                'morph_loss': total_morph_loss,
                'lr': scheduler.get_last_lr()[0]
            })

        # Evaluate
        if step % eval_steps == 0:
            eval_results = evaluator.evaluate(model, dataloader, device, num_samples=20)
            logger.info(f"\nStep {step} Haptic Eval: {eval_results}")

            if eval_results['haptic_score'] > best_haptic_score:
                best_haptic_score = eval_results['haptic_score']
                # Save best model
                best_dir = os.path.join(output_dir, "best")
                os.makedirs(best_dir, exist_ok=True)
                torch.save(model.state_dict(), os.path.join(best_dir, "pytorch_model.bin"))
                logger.info(f"New best haptic score: {best_haptic_score:.4f}")

            model.train()

        # Save checkpoint
        if step % save_steps == 0:
            checkpoint_dir = os.path.join(output_dir, f"checkpoint-{step}")
            os.makedirs(checkpoint_dir, exist_ok=True)
            torch.save(model.state_dict(), os.path.join(checkpoint_dir, "pytorch_model.bin"))
            with open(os.path.join(checkpoint_dir, "config.json"), 'w') as f:
                json.dump(config.to_dict(), f, indent=2)
            logger.info(f"Saved checkpoint at step {step}")

    pbar.close()

    # Save final model
    print("\n" + "=" * 70)
    print("Saving Final Model")
    print("=" * 70)

    final_dir = os.path.join(output_dir, "final")
    os.makedirs(final_dir, exist_ok=True)

    torch.save(model.state_dict(), os.path.join(final_dir, "pytorch_model.bin"))
    with open(os.path.join(final_dir, "config.json"), 'w') as f:
        json.dump(config.to_dict(), f, indent=2)

    # Save training log
    with open(os.path.join(output_dir, "training_log.json"), 'w') as f:
        json.dump(training_log, f, indent=2)

    # Copy tokenizer
    import shutil
    shutil.copy(tokenizer_path, os.path.join(final_dir, "tokenizer.model"))

    # Final evaluation
    final_eval = evaluator.evaluate(model, dataloader, device, num_samples=50)
    print(f"\nFinal Haptic Evaluation:")
    print(f"  Lattice Coherence: {final_eval['lattice_coherence']:.4f}")
    print(f"  Morphological Stability: {final_eval['morphological_stability']:.4f}")
    print(f"  Haptic Score: {final_eval['haptic_score']:.4f}")

    with open(os.path.join(output_dir, "final_eval.json"), 'w') as f:
        json.dump(final_eval, f, indent=2)

    print(f"\nModel saved to: {final_dir}")
    print("\n" + "=" * 70)
    print("⣿ Training Complete! ⣿")
    print("=" * 70)

def main():
    parser = argparse.ArgumentParser(description="Train braille256-v6 lattice-aware model")
    parser.add_argument("--corpus", default="corpus/braille_multimodal_corpus.txt")
    parser.add_argument("--tokenizer", default="tokenizers/braille_8dot_32k/braille_8dot_32k.model")
    parser.add_argument("--output", default="models/braille256_v6_lattice")
    parser.add_argument("--steps", type=int, default=10000)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--no-lattice-attention", action="store_true")
    parser.add_argument("--no-lattice-embeddings", action="store_true")
    parser.add_argument("--no-morph-regularization", action="store_true")

    args = parser.parse_args()

    train(
        corpus_path=args.corpus,
        tokenizer_path=args.tokenizer,
        output_dir=args.output,
        max_steps=args.steps,
        batch_size=args.batch_size,
        learning_rate=args.lr,
        use_lattice_attention=not args.no_lattice_attention,
        use_lattice_embeddings=not args.no_lattice_embeddings,
        use_morphological_regularization=not args.no_morph_regularization,
    )


if __name__ == "__main__":
    main()
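The lattice attention used throughout the script blends standard softmax attention with the Hamming-based similarity described in the model card: `HammingAttn[i,j] = 8 - popcount(token[i] XOR token[j])`. A minimal standalone sketch of that similarity term on raw 8-dot cell values (illustrative only; the in-model `LatticeAttention` also masks non-cell tokens and mixes this with softmax scores via λ):

```python
def hamming_similarity(tokens):
    """Pairwise 8 - popcount(t_i XOR t_j) over 8-dot cells (integers 0..255)."""
    return [[8 - bin(a ^ b).count('1') for b in tokens] for a in tokens]

# Identical cells score 8; cells differing in every dot score 0.
cells = [0b00000001, 0b00000011, 0b11111111]
for row in hamming_similarity(cells):
    print(row)
```

Cells that share more raised dots get higher similarity, so attention is biased toward tactually similar neighbors.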
training_log.json ADDED
@@ -0,0 +1,602 @@
[
  {
    "step": 100,
    "lm_loss": 6.993148565292358,
    "morph_loss": 0.0004997336654923856,
    "lr": 5.9999999999999995e-05
  },
  {
    "step": 200,
    "lm_loss": 3.8235212564468384,
    "morph_loss": 0.0004995590425096452,
    "lr": 0.00011999999999999999
  },
  {
    "step": 300,
    "lm_loss": 2.6851329803466797,
    "morph_loss": 0.0004995536583010107,
    "lr": 0.00017999999999999998
  },
  {
    "step": 400,
    "lm_loss": 2.7323516607284546,
    "morph_loss": 0.0004988558066543192,
    "lr": 0.00023999999999999998
  },
  {
    "step": 500,
    "lm_loss": 2.6041159629821777,
    "morph_loss": 0.000500149471918121,
    "lr": 0.0003
  },
  {
    "step": 600,
    "lm_loss": 2.7856509685516357,
    "morph_loss": 0.0004991319729015231,
    "lr": 0.0002999179886011389
  },
  {
    "step": 700,
    "lm_loss": 1.9858170747756958,
    "morph_loss": 0.0004987868887837976,
    "lr": 0.00029967204408281613
  },
  {
    "step": 800,
    "lm_loss": 2.2885484099388123,
    "morph_loss": 0.0004994409391656518,
    "lr": 0.0002992624353817517
  },
  {
    "step": 900,
    "lm_loss": 1.7579542398452759,
    "morph_loss": 0.000499250425491482,
    "lr": 0.00029868961039904624
  },
  {
    "step": 1000,
    "lm_loss": 2.356551766395569,
    "morph_loss": 0.000498554261866957,
    "lr": 0.00029795419551040833
  },
  {
    "step": 1100,
    "lm_loss": 2.2041295170783997,
    "morph_loss": 0.0004986616258975118,
    "lr": 0.0002970569948812214
  },
  {
    "step": 1200,
    "lm_loss": 1.987478256225586,
    "morph_loss": 0.0004990812740288675,
    "lr": 0.0002959989895872009
  },
  {
    "step": 1300,
    "lm_loss": 1.7949504256248474,
    "morph_loss": 0.0004993200418539345,
    "lr": 0.0002947813365416023
  },
  {
    "step": 1400,
    "lm_loss": 2.577250361442566,
    "morph_loss": 0.0004996150964871049,
    "lr": 0.0002934053672301536
  },
  {
    "step": 1500,
    "lm_loss": 2.1883418560028076,
    "morph_loss": 0.000496969121741131,
    "lr": 0.00029187258625509513
  },
  {
    "step": 1600,
    "lm_loss": 1.70472252368927,
    "morph_loss": 0.0001988508302019909,
    "lr": 0.0002901846696899191
  },
  {
    "step": 1700,
    "lm_loss": 2.3121373057365417,
    "morph_loss": 0.0001993859259528108,
    "lr": 0.0002883434632466077
  },
  {
    "step": 1800,
    "lm_loss": 2.0045361518859863,
    "morph_loss": 0.00019962265650974587,
    "lr": 0.00028635098025737434
  },
  {
    "step": 1900,
    "lm_loss": 1.9210466742515564,
    "morph_loss": 0.0001999259038711898,
    "lr": 0.0002842093994731145
  },
  {
    "step": 2000,
    "lm_loss": 2.045823335647583,
    "morph_loss": 0.00019963263912359253,
    "lr": 0.00028192106268097334
  },
  {
    "step": 2100,
    "lm_loss": 2.363018274307251,
    "morph_loss": 0.00020041707466589287,
    "lr": 0.0002794884721436361
  },
  {
    "step": 2200,
    "lm_loss": 2.146875023841858,
    "morph_loss": 0.00019886076188413426,
    "lr": 0.0002769142878631403
  },
  {
    "step": 2300,
    "lm_loss": 1.7959808111190796,
    "morph_loss": 0.0001991742683458142,
    "lr": 0.000274201324672203
  },
  {
    "step": 2400,
    "lm_loss": 2.1792458295822144,
    "morph_loss": 0.00020018102077301592,
    "lr": 0.0002713525491562421
  },
  {
    "step": 2500,
    "lm_loss": 2.177172005176544,
    "morph_loss": 0.0001997477374970913,
    "lr": 0.00026837107640945905
  },
  {
    "step": 2600,
    "lm_loss": 1.8140788674354553,
    "morph_loss": 0.00019913741562049836,
    "lr": 0.00026526016662852886
  },
  {
    "step": 2700,
    "lm_loss": 2.1109378337860107,
    "morph_loss": 0.00019985359540442005,
    "lr": 0.0002620232215476231
  },
  {
    "step": 2800,
    "lm_loss": 1.8451208472251892,
    "morph_loss": 0.0002002850960707292,
    "lr": 0.00025866378071866334
  },
  {
    "step": 2900,
    "lm_loss": 1.4895858764648438,
    "morph_loss": 0.00019962178339483216,
    "lr": 0.00025518551764087326
  },
  {
    "step": 3000,
    "lm_loss": 1.5707910060882568,
    "morph_loss": 0.00019960849022027105,
    "lr": 0.00025159223574386114
  },
  {
    "step": 3100,
    "lm_loss": 1.5902798175811768,
    "morph_loss": 0.0002000557360588573,
    "lr": 0.00024788786422862526
  },
  {
    "step": 3200,
    "lm_loss": 1.8854023218154907,
    "morph_loss": 0.00019861374312313274,
    "lr": 0.00024407645377103054
  },
  {
    "step": 3300,
    "lm_loss": 1.6806678175926208,
    "morph_loss": 0.00020012820459669456,
    "lr": 0.00024016217209245374
  },
  {
    "step": 3400,
    "lm_loss": 1.827455759048462,
    "morph_loss": 0.00020046472491230816,
    "lr": 0.0002361492994024415
  },
  {
    "step": 3500,
    "lm_loss": 1.9594002962112427,
    "morph_loss": 0.00019966769468737766,
    "lr": 0.00023204222371836402
  },
  {
    "step": 3600,
    "lm_loss": 1.359905481338501,
    "morph_loss": 0.00019929300469812006,
    "lr": 0.00022784543606718227
  },
  {
    "step": 3700,
    "lm_loss": 1.9533899426460266,
    "morph_loss": 0.00020049385057063773,
    "lr": 0.0002235635255745762
  },
  {
    "step": 3800,
    "lm_loss": 1.4309832453727722,
    "morph_loss": 0.00019867864466505125,
    "lr": 0.00021920117444680317
  },
  {
    "step": 3900,
    "lm_loss": 1.2603670358657837,
    "morph_loss": 0.00020007100101793185,
    "lr": 0.0002147631528507739
  },
  {
    "step": 4000,
    "lm_loss": 1.359400749206543,
    "morph_loss": 0.00019884618814103305,
    "lr": 0.0002102543136979454
  },
  {
    "step": 4100,
    "lm_loss": 1.4845112562179565,
    "morph_loss": 4.9905273044714704e-05,
    "lr": 0.0002056795873377331
  },
  {
    "step": 4200,
    "lm_loss": 1.3849643468856812,
    "morph_loss": 4.9970141844823956e-05,
    "lr": 0.00020104397616624645
  },
  {
    "step": 4300,
    "lm_loss": 1.293235957622528,
    "morph_loss": 4.9802090870798565e-05,
    "lr": 0.0001963525491562421
  },
  {
    "step": 4400,
    "lm_loss": 1.8045696020126343,
    "morph_loss": 5.000879900762811e-05,
    "lr": 0.00019161043631427666
  },
  {
    "step": 4500,
    "lm_loss": 1.472623884677887,
    "morph_loss": 4.9775953812059015e-05,
    "lr": 0.00018682282307111987
  },
  {
    "step": 4600,
    "lm_loss": 1.4265462756156921,
    "morph_loss": 4.9740716349333525e-05,
    "lr": 0.00018199494461156203
  },
  {
    "step": 4700,
    "lm_loss": 1.3741839528083801,
    "morph_loss": 5.005610182706732e-05,
    "lr": 0.00017713208014981648
  },
  {
    "step": 4800,
    "lm_loss": 1.5047972798347473,
    "morph_loss": 4.9890166337718256e-05,
    "lr": 0.00017223954715677627
  },
  {
    "step": 4900,
    "lm_loss": 1.9711642265319824,
    "morph_loss": 5.0176731747342274e-05,
    "lr": 0.00016732269554543794
  },
  {
    "step": 5000,
    "lm_loss": 1.3522289395332336,
    "morph_loss": 5.020072603656445e-05,
    "lr": 0.00016238690182084986
  },
  {
    "step": 5100,
    "lm_loss": 1.7605299353599548,
    "morph_loss": 4.991166679246817e-05,
    "lr": 0.00015743756320098332
  },
  {
    "step": 5200,
    "lm_loss": 1.4882904291152954,
    "morph_loss": 4.9983715143753216e-05,
    "lr": 0.00015248009171495378
  },
  {
    "step": 5300,
    "lm_loss": 1.6343830823898315,
    "morph_loss": 5.002636680728756e-05,
    "lr": 0.00014751990828504622
  },
  {
    "step": 5400,
    "lm_loss": 1.1790239810943604,
    "morph_loss": 4.997515679860953e-05,
    "lr": 0.00014256243679901663
  },
  {
    "step": 5500,
    "lm_loss": 1.2620559334754944,
    "morph_loss": 4.979184268449899e-05,
    "lr": 0.00013761309817915014
  },
  {
    "step": 5600,
    "lm_loss": 1.2926940321922302,
    "morph_loss": 4.9720953029464e-05,
    "lr": 0.00013267730445456208
  },
  {
    "step": 5700,
    "lm_loss": 1.6810715198516846,
    "morph_loss": 5.0059263230650686e-05,
    "lr": 0.00012776045284322368
  },
  {
    "step": 5800,
    "lm_loss": 1.3960903882980347,
    "morph_loss": 5.0068707423633896e-05,
    "lr": 0.00012286791985018355
  },
  {
    "step": 5900,
    "lm_loss": 1.5417255759239197,
    "morph_loss": 5.0141115934820846e-05,
    "lr": 0.00011800505538843798
  },
  {
    "step": 6000,
    "lm_loss": 1.3895499110221863,
    "morph_loss": 4.9999320253846236e-05,
    "lr": 0.00011317717692888012
  },
  {
    "step": 6100,
    "lm_loss": 1.578350841999054,
    "morph_loss": 4.975642514182255e-05,
    "lr": 0.00010838956368572334
  },
  {
    "step": 6200,
    "lm_loss": 0.8763712048530579,
    "morph_loss": 4.9885256885318086e-05,
    "lr": 0.0001036474508437579
  },
  {
    "step": 6300,
    "lm_loss": 1.3805139064788818,
    "morph_loss": 4.985421219316777e-05,
    "lr": 9.895602383375353e-05
  },
  {
    "step": 6400,
    "lm_loss": 1.642943263053894,
    "morph_loss": 4.988547880202532e-05,
    "lr": 9.432041266226686e-05
  },
  {
    "step": 6500,
    "lm_loss": 1.2295689284801483,
    "morph_loss": 4.976921445631888e-05,
    "lr": 8.97456863020546e-05
  },
  {
    "step": 6600,
    "lm_loss": 0.9539550840854645,
    "morph_loss": 4.960754813509993e-05,
    "lr": 8.523684714922608e-05
  },
  {
    "step": 6700,
    "lm_loss": 1.4480910301208496,
    "morph_loss": 4.9842370572150685e-05,
    "lr": 8.079882555319683e-05
  },
  {
    "step": 6800,
    "lm_loss": 1.1316336393356323,
    "morph_loss": 4.944742067891639e-05,
    "lr": 7.643647442542382e-05
  },
  {
    "step": 6900,
    "lm_loss": 1.2974263429641724,
    "morph_loss": 4.9362375648343004e-05,
    "lr": 7.215456393281776e-05
  },
  {
    "step": 7000,
    "lm_loss": 1.8624339699745178,
    "morph_loss": 4.9819502237369306e-05,
    "lr": 6.795777628163599e-05
  },
  {
    "step": 7100,
    "lm_loss": 1.2204494774341583,
    "morph_loss": 5.013305417378433e-05,
    "lr": 6.385070059755846e-05
  },
  {
    "step": 7200,
    "lm_loss": 1.5136797428131104,
    "morph_loss": 4.9816295359050855e-05,
    "lr": 5.983782790754623e-05
  },
  {
    "step": 7300,
    "lm_loss": 1.3666653037071228,
    "morph_loss": 4.99475918331882e-05,
    "lr": 5.592354622896944e-05
  },
  {
    "step": 7400,
    "lm_loss": 0.8389511108398438,
    "morph_loss": 4.9609083362156525e-05,
    "lr": 5.211213577137469e-05
  },
  {
    "step": 7500,
    "lm_loss": 1.1075031757354736,
    "morph_loss": 4.941183760820422e-05,
    "lr": 4.840776425613885e-05
  },
  {
    "step": 7600,
    "lm_loss": 1.2579197883605957,
    "morph_loss": 4.980366429663263e-05,
    "lr": 4.481448235912671e-05
  },
  {
    "step": 7700,
    "lm_loss": 1.0307890474796295,
    "morph_loss": 4.9867769121192396e-05,
    "lr": 4.133621928133665e-05
  },
  {
    "step": 7800,
    "lm_loss": 1.0696255564689636,
    "morph_loss": 5.0295926484977826e-05,
    "lr": 3.797677845237696e-05
  },
  {
    "step": 7900,
    "lm_loss": 1.3926631212234497,
    "morph_loss": 5.014805537939537e-05,
    "lr": 3.473983337147118e-05
  },
  {
    "step": 8000,
    "lm_loss": 1.5005779266357422,
    "morph_loss": 5.018570300308056e-05,
    "lr": 3.162892359054098e-05
  },
  {
    "step": 8100,
    "lm_loss": 1.7105327248573303,
    "morph_loss": 5.010495260648895e-05,
    "lr": 2.8647450843757897e-05
  },
  {
    "step": 8200,
    "lm_loss": 1.229815423488617,
    "morph_loss": 4.9758424211177044e-05,
    "lr": 2.5798675327796993e-05
  },
  {
    "step": 8300,
    "lm_loss": 1.1335912346839905,
    "morph_loss": 4.941884253639728e-05,
    "lr": 2.3085712136859668e-05
  },
  {
    "step": 8400,
    "lm_loss": 1.0056449174880981,
    "morph_loss": 4.942774467053823e-05,
    "lr": 2.0511527856363895e-05
  },
  {
    "step": 8500,
    "lm_loss": 1.6661915183067322,
    "morph_loss": 4.997232463210821e-05,
    "lr": 1.8078937319026654e-05
  },
  {
    "step": 8600,
    "lm_loss": 1.1359021067619324,
    "morph_loss": 4.9868809583131224e-05,
    "lr": 1.579060052688548e-05
  },
  {
    "step": 8700,
    "lm_loss": 1.434115707874298,
    "morph_loss": 4.993028596800286e-05,
    "lr": 1.3649019742625623e-05
  },
  {
    "step": 8800,
    "lm_loss": 1.2552986145019531,
    "morph_loss": 4.9796975872595794e-05,
    "lr": 1.1656536753392287e-05
  },
  {
    "step": 8900,
    "lm_loss": 1.217362403869629,
    "morph_loss": 4.991148489352781e-05,
    "lr": 9.815330310080887e-06
  },
  {
    "step": 9000,
    "lm_loss": 1.7755168080329895,
    "morph_loss": 5.013133522879798e-05,
    "lr": 8.127413744904804e-06
  },
  {
    "step": 9100,
    "lm_loss": 1.3130499720573425,
    "morph_loss": 5.007252184441313e-05,
    "lr": 6.594632769846353e-06
  },
  {
    "step": 9200,
    "lm_loss": 1.1731150150299072,
    "morph_loss": 5.0212831411045045e-05,
    "lr": 5.218663458397715e-06
  },
  {
    "step": 9300,
    "lm_loss": 0.9502497613430023,
557
+ "morph_loss": 4.9593836592976004e-05,
558
+ "lr": 4.001010412799138e-06
559
+ },
560
+ {
561
+ "step": 9400,
562
+ "lm_loss": 1.3233891725540161,
563
+ "morph_loss": 4.999847624276299e-05,
564
+ "lr": 2.9430051187785962e-06
565
+ },
566
+ {
567
+ "step": 9500,
568
+ "lm_loss": 1.3283841013908386,
569
+ "morph_loss": 5.0088236093870364e-05,
570
+ "lr": 2.0458044895916513e-06
571
+ },
572
+ {
573
+ "step": 9600,
574
+ "lm_loss": 1.2733866572380066,
575
+ "morph_loss": 4.998848271497991e-05,
576
+ "lr": 1.3103896009537207e-06
577
+ },
578
+ {
579
+ "step": 9700,
580
+ "lm_loss": 1.1967694163322449,
581
+ "morph_loss": 5.023612902732566e-05,
582
+ "lr": 7.375646182482875e-07
583
+ },
584
+ {
585
+ "step": 9800,
586
+ "lm_loss": 1.0563868880271912,
587
+ "morph_loss": 4.972490751242731e-05,
588
+ "lr": 3.2795591718381975e-07
589
+ },
590
+ {
591
+ "step": 9900,
592
+ "lm_loss": 1.7780798077583313,
593
+ "morph_loss": 4.9859330829349346e-05,
594
+ "lr": 8.201139886109264e-08
595
+ },
596
+ {
597
+ "step": 10000,
598
+ "lm_loss": 1.231259286403656,
599
+ "morph_loss": 4.956562952429522e-05,
600
+ "lr": 0.0
601
+ }
602
+ ]