Upload braille256-v6: Lattice-aware multimodal Braille model

Files changed in this commit:

- README.md (+186 -0)
- braille_lattice_theory.py (+1010 -0)
- config.json (+18 -0)
- final_eval.json (+5 -0)
- pytorch_model.bin (+3 -0)
- tokenizer.model (+3 -0)
- train_lattice_v6.py (+990 -0)
- training_log.json (+602 -0)

---

README.md (ADDED, +186 lines)
# braille256-v6: Lattice-Aware Multimodal Braille Model

**The first LLM with explicit dot-lattice structure in its architecture.**

## Model Description

braille256-v6 builds on the multimodal foundation of v5, integrating formal lattice theory into the training pipeline. It is not just a Braille-native model; it is a **lattice-native** model that understands the mathematical structure of Braille at the architectural level.

### Key Innovations

| Feature | Description |
|---------|-------------|
| **Lattice Attention** | Attention scores incorporate Hamming-based similarity between Braille cells |
| **Lattice Embeddings** | Token embeddings initialized to respect the Boolean lattice structure |
| **Morphological Regularization** | Training loss includes an equivariance term under erosion/dilation |
| **Haptic Evaluation** | New metrics for the tactile quality of outputs |

## Architecture

```
Parameters: ~12M
Layers:     4
Heads:      4
Hidden:     256
Vocab:      32,000 (SentencePiece)
Context:    512
```

### Lattice Attention

Standard transformer attention computes:

```
Attention(Q, K, V) = softmax(QK^T / √d) V
```

Lattice attention blends this with Braille-aware similarity:

```
LatticeAttn = (1 - λ) * StandardAttn + λ * HammingAttn

where HammingAttn[i, j] = 8 - popcount(token[i] XOR token[j])
```

This gives the model an inductive bias toward understanding Braille structure.
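The blend above can be sketched in a few lines of plain Python. This is an illustrative reading of the formula, not the training script's actual implementation; `lattice_scores` and its λ default are hypothetical names, and normalization details in the real model may differ.

```python
def popcount(x: int) -> int:
    """Number of set bits in an 8-bit Braille cell."""
    return bin(x).count("1")

def hamming_sim(a: int, b: int) -> int:
    """HammingAttn entry: 8 minus the Hamming distance between two cells."""
    return 8 - popcount(a ^ b)

def lattice_scores(std_scores, token_ids, lam=0.1):
    """Blend standard attention scores with Hamming similarity, entry-wise."""
    n = len(token_ids)
    return [
        [(1 - lam) * std_scores[i][j] + lam * hamming_sim(token_ids[i], token_ids[j])
         for j in range(n)]
        for i in range(n)
    ]
```

For identical tokens the Hamming term contributes its maximum of 8, so repeated cells attract extra attention mass regardless of the learned scores.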
### Lattice Embeddings

For the first 256 tokens (corresponding to Braille cells), embeddings are initialized as:

```
embedding[i] = Σ basis[b] for each raised dot b in cell i
```

This means similar Braille cells (low Hamming distance) start with similar embeddings.
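A minimal sketch of that initialization, assuming one learned basis vector per dot; the function name, dimensions, and random initialization are illustrative, not taken from `train_lattice_v6.py`.

```python
import numpy as np

def init_lattice_embeddings(dim: int = 256, seed: int = 0) -> np.ndarray:
    """Initialize the 256 cell embeddings as sums of 8 dot-basis vectors."""
    rng = np.random.default_rng(seed)
    basis = rng.normal(0, 1 / np.sqrt(8), size=(8, dim))  # one vector per dot
    emb = np.zeros((256, dim))
    for cell in range(256):
        for b in range(8):
            if (cell >> b) & 1:          # dot b is raised in this cell
                emb[cell] += basis[b]
    return emb
```

By construction the embedding map is additive over disjoint dot sets, so cells at Hamming distance 1 differ by exactly one basis vector.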
### Morphological Regularization

Training includes a regularization term:

```
L_morph = ReLU(||emb - erode(emb)|| - ||emb - dilate(emb)||)
```

This encourages embeddings to respect the lattice ordering: `erode(x) ≤ x ≤ dilate(x)`.
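One possible reading of that term, where erosion/dilation act on cell indices (AND/OR with a structuring-element mask) and the hinge compares distances in embedding space. The structuring element and the mean reduction are assumptions; the actual loss in the training script may be defined differently.

```python
import numpy as np

def morph_reg(emb: np.ndarray, se: int = 0b00111111) -> float:
    """Hinge loss pushing each cell's embedding closer to its erosion
    than to its dilation. `se` is an 8-bit structuring-element mask
    (here the six-dot subset, an illustrative choice)."""
    cells = np.arange(256)
    eroded = emb[cells & se]    # embedding of the eroded cell
    dilated = emb[cells | se]   # embedding of the dilated cell
    d_erode = np.linalg.norm(emb - eroded, axis=1)
    d_dilate = np.linalg.norm(emb - dilated, axis=1)
    return float(np.maximum(d_erode - d_dilate, 0.0).mean())  # ReLU hinge
```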
## Theoretical Foundation

This model implements the formal theory from:

**"Theoretical Foundations for 8-Dot Braille-Native LLMs"**

Key theoretical components:

1. **Braille Lattice**: Boolean algebra (B⁸, ∧, ∨, ¬) with 256 elements
2. **Morphological Operators**: Erosion, dilation, opening, closing
3. **Modality-Invariant Representation**: (modality, sequence, embedding) triple
4. **Lattice Metrics**: Hamming distance, Jaccard similarity

See `braille_lattice_theory.py` for the full implementation.
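The two lattice metrics are simple bit operations on cell values, treating each cell as its set of raised dots:

```python
def hamming(a: int, b: int) -> int:
    """Hamming distance: number of dot positions where two cells differ."""
    return bin(a ^ b).count("1")

def jaccard(a: int, b: int) -> float:
    """Jaccard similarity of the raised-dot sets; two empty cells count as 1.0."""
    union = bin(a | b).count("1")
    if union == 0:
        return 1.0
    return bin(a & b).count("1") / union
```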
## Modality Support

| Modality | Header | Status |
|----------|--------|--------|
| TEXT | ⣿⠁ | ✅ Trained |
| IMAGE | ⣿⠃ | ✅ Trained |
| AUDIO | ⣿⠇ | ✅ Trained |
| BINARY | ⣿⠏ | ✅ Trained |
| VIDEO | ⣿⠗ | 🔄 Framework ready |

## Haptic Evaluation Metrics

v6 introduces new evaluation metrics for tactile quality:

| Metric | Description | Target | **Achieved** |
|--------|-------------|--------|--------------|
| **Lattice Coherence** | Adjacent tokens have low Hamming distance | > 0.7 | **0.743** ✅ |
| **Morphological Stability** | Outputs stable under erosion/dilation | > 0.5 | **0.453** |
| **Haptic Score** | Combined tactile quality metric | > 0.5 | **0.598** ✅ |
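Lattice coherence can be sketched as the mean normalized Hamming similarity between adjacent tokens. This definition is an illustration of the idea in the table above, not necessarily the exact formula behind the `0.743` figure in `final_eval.json`.

```python
def lattice_coherence(token_ids) -> float:
    """Mean similarity of adjacent 8-bit tokens, scaled to [0, 1]
    (1.0 = every adjacent pair identical, 0.0 = every pair maximally distant)."""
    if len(token_ids) < 2:
        return 1.0
    sims = [
        1.0 - bin((a ^ b) & 0xFF).count("1") / 8.0
        for a, b in zip(token_ids, token_ids[1:])
    ]
    return sum(sims) / len(sims)
```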
## Training Results

| Metric | Value |
|--------|-------|
| Final Loss | 1.23 |
| Training Steps | 10,000 |
| Training Time | 2h 7m |
| Corpus | Balanced multimodal (25% each: text, image, audio, binary) |
| Corpus Size | 164M chars |
## Usage

```python
import json

import torch
from train_lattice_v6 import Braille256LatticeModel, LatticeConfig

# Load model
with open("config.json") as f:
    config = LatticeConfig.from_dict(json.load(f))
model = Braille256LatticeModel(config)
model.load_state_dict(torch.load("pytorch_model.bin"))

# Generate
input_ids = torch.tensor([[0x28, 0x29, 0x2A]])  # some Braille tokens
output = model.generate(input_ids, max_length=100)
```
## Training

```bash
python train_lattice_v6.py \
    --corpus corpus/braille_multimodal_corpus.txt \
    --tokenizer tokenizers/braille_8dot_32k/braille_8dot_32k.model \
    --output models/braille256_v6_lattice \
    --steps 10000
```

### Training Options

| Flag | Description |
|------|-------------|
| `--no-lattice-attention` | Disable lattice attention (ablation) |
| `--no-lattice-embeddings` | Disable lattice embeddings (ablation) |
| `--no-morph-regularization` | Disable morphological regularization (ablation) |
## Model Family

| Version | Focus | Parameters | Key Feature |
|---------|-------|------------|-------------|
| v1-v3 | 6-dot Braille | ~10M | Basic Braille LM |
| v4 | 8-dot Braille | 29.9M | Full byte encoding |
| v5 | Multimodal | 11.5M | TEXT/IMAGE/AUDIO/BINARY |
| **v6** | **Lattice-aware** | **11.5M** | **Hamming attention, morphological regularization, balanced multimodal corpus** |

## Why Lattice-Aware?

Standard LLMs treat tokens as arbitrary symbols. braille256-v6 knows that:

1. **Braille cells form a lattice**: 256 elements with meet (∧) and join (∨)
2. **Similar cells should have similar representations**: Hamming distance matters
3. **Morphological operations preserve meaning**: erosion/dilation are semantic
4. **Tactile quality is measurable**: haptic metrics evaluate output quality

This makes v6 the first LLM designed for **tactile-first AI**.
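Point 1 is concrete at the character level: because 8-dot Braille occupies the contiguous Unicode block starting at U+2800, meet and join are bitwise AND/OR on code-point offsets. A small standalone illustration:

```python
# Meet (∧) and join (∨) of two Braille characters via the U+2800 offset.
BASE = 0x2800

def meet(a: str, b: str) -> str:
    """Dots raised in both cells (bitwise AND)."""
    return chr(BASE + ((ord(a) - BASE) & (ord(b) - BASE)))

def join(a: str, b: str) -> str:
    """Dots raised in either cell (bitwise OR)."""
    return chr(BASE + ((ord(a) - BASE) | (ord(b) - BASE)))
```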
## Citation

```bibtex
@misc{braille256v6,
  author    = {Barrett, Ryan},
  title     = {braille256-v6: Lattice-Aware Multimodal Braille Model},
  year      = {2024},
  publisher = {HuggingFace},
  url       = {https://huggingface.co/ryanscottbarrett/braille256-v6}
}
```

## License

MIT

## Links

- [braille256-v5](https://huggingface.co/ryanscottbarrett/braille256-v5)
- [braille256-v4](https://huggingface.co/ryanscottbarrett/braille256-v4)
- [Theoretical Paper](docs/BRAILLE_NATIVE_LLM_THEORY.md)
- [Lattice Theory Implementation](src/braille_lattice_theory.py)

---

⣿ *The first LLM where Braille is not just the output format, but the computational substrate.* ⣿
---

braille_lattice_theory.py (ADDED, +1010 lines; truncated in this view)
#!/usr/bin/env python3
"""
Braille Dot-Lattice Theory: Formal Mathematical Framework

This module formalizes the missing theoretical components for 8-dot Braille-native LLMs:

1. DOT-LATTICE MORPHOLOGICAL OPERATORS
   - Boolean algebra on the 8-bit Braille lattice (B⁸, ∧, ∨, ¬, ⊕)
   - Morphological operations: erosion, dilation, opening, closing
   - Dot-wise transformations preserving semantic structure

2. MODALITY-INVARIANT BRAILLE REASONING LOOPS
   - Unified representation across text, image, audio, binary
   - Cross-modal attention mechanisms in Braille space
   - Semantic preservation under modality transformation

Mathematical Foundation:
- The 8-dot Braille cell forms a Boolean lattice (B⁸, ≤) where B = {0, 1}
- Each cell is an 8-dimensional binary vector: c ∈ {0,1}⁸
- The lattice has 2⁸ = 256 elements with meet (∧) and join (∨) operations
- This isomorphism to bytes enables direct computational semantics

Author: Ryan Barrett & Cascade
Date: December 2024
"""

from __future__ import annotations
from dataclasses import dataclass, field
from typing import List, Dict, Tuple, Set, Callable, Optional, Iterator
from enum import Enum, auto
import numpy as np
from functools import reduce
import operator

# =============================================================================
# SECTION 1: BRAILLE LATTICE FUNDAMENTALS
# =============================================================================

# Unicode range for 8-dot Braille
BRAILLE_BASE = 0x2800
BRAILLE_MAX = 0x28FF

# Dot position bit values (standard 8-dot layout)
# Layout:  1 4
#          2 5
#          3 6
#          7 8
DOT_BITS = {
    1: 0b00000001,  # bit 0
    2: 0b00000010,  # bit 1
    3: 0b00000100,  # bit 2
    4: 0b00001000,  # bit 3
    5: 0b00010000,  # bit 4
    6: 0b00100000,  # bit 5
    7: 0b01000000,  # bit 6
    8: 0b10000000,  # bit 7
}

# Inverse mapping
BIT_TO_DOT = {v: k for k, v in DOT_BITS.items()}

@dataclass(frozen=True)
class BrailleCell:
    """
    A single 8-dot Braille cell as an element of the Boolean lattice B⁸.

    The cell is represented as an 8-bit integer where each bit corresponds
    to a dot position. This enables efficient lattice operations.

    Properties:
    - Immutable (frozen dataclass)
    - Hashable (can be used in sets/dicts)
    - Supports all Boolean lattice operations
    """
    value: int  # 0-255, representing the 8 dots as bits

    def __post_init__(self):
        if not 0 <= self.value <= 255:
            raise ValueError(f"BrailleCell value must be 0-255, got {self.value}")

    # --- Lattice Element Properties ---

    @property
    def dots(self) -> Tuple[int, ...]:
        """Return tuple of active dot numbers (1-8)."""
        return tuple(d for d in range(1, 9) if self.has_dot(d))

    @property
    def unicode(self) -> str:
        """Return the Unicode Braille character."""
        return chr(BRAILLE_BASE + self.value)

    @property
    def vector(self) -> np.ndarray:
        """Return as 8-dimensional binary vector."""
        return np.array([(self.value >> i) & 1 for i in range(8)], dtype=np.uint8)

    @property
    def cardinality(self) -> int:
        """Number of raised dots (Hamming weight)."""
        return bin(self.value).count('1')

    @property
    def is_bottom(self) -> bool:
        """Check if this is ⊥ (empty cell, no dots)."""
        return self.value == 0

    @property
    def is_top(self) -> bool:
        """Check if this is ⊤ (all dots raised)."""
        return self.value == 255

    def has_dot(self, dot: int) -> bool:
        """Check if a specific dot (1-8) is raised."""
        return bool(self.value & DOT_BITS[dot])

    # --- Boolean Lattice Operations ---

    def meet(self, other: BrailleCell) -> BrailleCell:
        """
        Lattice meet (∧): Greatest lower bound.
        Equivalent to bitwise AND - keeps only dots present in BOTH cells.

        Semantically: intersection of dot patterns.
        """
        return BrailleCell(self.value & other.value)

    def join(self, other: BrailleCell) -> BrailleCell:
        """
        Lattice join (∨): Least upper bound.
        Equivalent to bitwise OR - raises dots present in EITHER cell.

        Semantically: union of dot patterns.
        """
        return BrailleCell(self.value | other.value)

    def complement(self) -> BrailleCell:
        """
        Lattice complement (¬): Invert all dots.
        Equivalent to bitwise NOT (masked to 8 bits).

        Semantically: tactile negative.
        """
        return BrailleCell((~self.value) & 0xFF)

    def symmetric_difference(self, other: BrailleCell) -> BrailleCell:
        """
        Symmetric difference (⊕): XOR operation.
        Dots present in exactly one of the two cells.

        Semantically: tactile contrast/difference.
        """
        return BrailleCell(self.value ^ other.value)

    def implies(self, other: BrailleCell) -> BrailleCell:
        """
        Material implication (→): ¬self ∨ other.
        In lattice terms: self ≤ other iff (self → other) = ⊤
        """
        return self.complement().join(other)

    # --- Partial Order ---

    def __le__(self, other: BrailleCell) -> bool:
        """Lattice ordering: self ≤ other iff self ∧ other = self."""
        return (self.value & other.value) == self.value

    def __lt__(self, other: BrailleCell) -> bool:
        """Strict ordering: self < other iff self ≤ other and self ≠ other."""
        return self <= other and self.value != other.value

    def __ge__(self, other: BrailleCell) -> bool:
        return other <= self

    def __gt__(self, other: BrailleCell) -> bool:
        return other < self

    # --- Operator Overloads ---

    def __and__(self, other: BrailleCell) -> BrailleCell:
        return self.meet(other)

    def __or__(self, other: BrailleCell) -> BrailleCell:
        return self.join(other)

    def __invert__(self) -> BrailleCell:
        return self.complement()

    def __xor__(self, other: BrailleCell) -> BrailleCell:
        return self.symmetric_difference(other)

    def __repr__(self) -> str:
        return f"BrailleCell({self.unicode}, dots={self.dots}, value={self.value})"

    # --- Constructors ---

    @classmethod
    def from_unicode(cls, char: str) -> BrailleCell:
        """Create from Unicode Braille character."""
        code = ord(char)
        if not BRAILLE_BASE <= code <= BRAILLE_MAX:
            raise ValueError(f"Not a Braille character: {char}")
        return cls(code - BRAILLE_BASE)

    @classmethod
    def from_dots(cls, *dots: int) -> BrailleCell:
        """Create from dot numbers (1-8)."""
        value = 0
        for d in dots:
            if 1 <= d <= 8:
                value |= DOT_BITS[d]
        return cls(value)

    @classmethod
    def from_byte(cls, byte: int) -> BrailleCell:
        """Create from byte value (0-255)."""
        return cls(byte & 0xFF)

    @classmethod
    def from_vector(cls, vec: np.ndarray) -> BrailleCell:
        """Create from 8-dimensional binary vector."""
        value = sum(int(vec[i]) << i for i in range(8))
        return cls(value)

    @classmethod
    def bottom(cls) -> BrailleCell:
        """Return ⊥ (empty cell)."""
        return cls(0)

    @classmethod
    def top(cls) -> BrailleCell:
        """Return ⊤ (all dots raised)."""
        return cls(255)

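The lattice laws the BrailleCell class encodes can be checked with a standalone miniature; `Cell` below is a simplified stand-in for `BrailleCell`, kept self-contained for illustration.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Cell:
    """Simplified 8-bit Braille cell supporting meet, join, complement, order."""
    value: int  # 0-255

    def __and__(self, o): return Cell(self.value & o.value)   # meet (∧)
    def __or__(self, o):  return Cell(self.value | o.value)   # join (∨)
    def __invert__(self): return Cell(~self.value & 0xFF)     # complement (¬)
    def __le__(self, o):  return (self.value & o.value) == self.value

a, b = Cell(0b011), Cell(0b110)          # dot sets {1,2} and {2,3}
assert (a & b) == Cell(0b010)            # meet keeps the shared dot
assert (a | b) == Cell(0b111)            # join raises all dots
assert (a & b) <= a and a <= (a | b)     # lattice ordering holds
assert ~Cell(0) == Cell(255)             # ¬⊥ = ⊤
```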
# =============================================================================
# SECTION 2: DOT-LATTICE MORPHOLOGICAL OPERATORS
# =============================================================================

class MorphologicalOperator(Enum):
    """Morphological operations on the Braille lattice."""
    EROSION = auto()    # Shrink patterns
    DILATION = auto()   # Expand patterns
    OPENING = auto()    # Erosion then dilation (remove small protrusions)
    CLOSING = auto()    # Dilation then erosion (fill small gaps)
    GRADIENT = auto()   # Dilation - Erosion (edge detection)
    TOP_HAT = auto()    # Original - Opening (extract bright details)
    BLACK_HAT = auto()  # Closing - Original (extract dark details)


@dataclass
class StructuringElement:
    """
    A structuring element for morphological operations on Braille cells.

    In classical morphology, the structuring element defines the neighborhood.
    For Braille, we define it as a set of dot positions that form the "kernel".

    Common structuring elements:
    - COLUMN_LEFT: dots 1,2,3,7 (left column)
    - COLUMN_RIGHT: dots 4,5,6,8 (right column)
    - ROW_TOP: dots 1,4 (top row)
    - CROSS: dots 2,4,5 (cross pattern)
    - FULL: all 8 dots
    """
    dots: Set[int]
    name: str = ""

    @property
    def cell(self) -> BrailleCell:
        """Convert to BrailleCell."""
        return BrailleCell.from_dots(*self.dots)

    # Predefined structuring elements
    @classmethod
    def column_left(cls) -> StructuringElement:
        return cls({1, 2, 3, 7}, "COLUMN_LEFT")

    @classmethod
    def column_right(cls) -> StructuringElement:
        return cls({4, 5, 6, 8}, "COLUMN_RIGHT")

    @classmethod
    def row_top(cls) -> StructuringElement:
        return cls({1, 4}, "ROW_TOP")

    @classmethod
    def row_middle(cls) -> StructuringElement:
        return cls({2, 5}, "ROW_MIDDLE")

    @classmethod
    def row_bottom(cls) -> StructuringElement:
        return cls({3, 6}, "ROW_BOTTOM")

    @classmethod
    def row_extension(cls) -> StructuringElement:
        return cls({7, 8}, "ROW_EXTENSION")

    @classmethod
    def cross(cls) -> StructuringElement:
        return cls({2, 4, 5}, "CROSS")

    @classmethod
    def full(cls) -> StructuringElement:
        return cls({1, 2, 3, 4, 5, 6, 7, 8}, "FULL")

    @classmethod
    def six_dot(cls) -> StructuringElement:
        """Traditional 6-dot Braille subset."""
        return cls({1, 2, 3, 4, 5, 6}, "SIX_DOT")

class BrailleMorphology:
|
| 315 |
+
"""
|
| 316 |
+
Morphological operators on the Braille dot-lattice.
|
| 317 |
+
|
| 318 |
+
These operators enable pattern transformation while preserving
|
| 319 |
+
structural relationships in the lattice.
|
| 320 |
+
|
| 321 |
+
Key insight: Morphological operations on Braille cells can be
|
| 322 |
+
computed efficiently using Boolean operations on the underlying
|
| 323 |
+
8-bit representation.
|
| 324 |
+
"""
|
| 325 |
+
|
| 326 |
+
@staticmethod
|
| 327 |
+
def erode(cell: BrailleCell, se: StructuringElement) -> BrailleCell:
|
| 328 |
+
"""
|
| 329 |
+
Erosion: Keep only dots that have ALL structuring element dots present.
|
| 330 |
+
|
| 331 |
+
ε_B(X) = {x : B_x ⊆ X}
|
| 332 |
+
|
| 333 |
+
For single cell: result has dot d iff for all dots s in SE,
|
| 334 |
+
the cell has dot at position (d + s - 1) mod 8 + 1
|
| 335 |
+
|
| 336 |
+
Simplified for single cell: AND with structuring element.
|
| 337 |
+
"""
|
| 338 |
+
return cell & se.cell
|
| 339 |
+
|
| 340 |
+
@staticmethod
|
| 341 |
+
def dilate(cell: BrailleCell, se: StructuringElement) -> BrailleCell:
|
| 342 |
+
"""
|
| 343 |
+
Dilation: Raise dots if ANY structuring element dot is present.
|
| 344 |
+
|
| 345 |
+
δ_B(X) = {x : B_x ∩ X ≠ ∅}
|
| 346 |
+
|
| 347 |
+
For single cell: OR with structuring element.
|
| 348 |
+
"""
|
| 349 |
+
return cell | se.cell
|
| 350 |
+
|
| 351 |
+
@staticmethod
|
| 352 |
+
def opening(cell: BrailleCell, se: StructuringElement) -> BrailleCell:
|
| 353 |
+
"""
|
| 354 |
+
Opening: Erosion followed by dilation.
|
| 355 |
+
|
| 356 |
+
γ_B(X) = δ_B(ε_B(X))
|
| 357 |
+
|
| 358 |
+
Effect: Removes small protrusions, smooths from outside.
|
| 359 |
+
"""
|
| 360 |
+
eroded = BrailleMorphology.erode(cell, se)
|
| 361 |
+
return BrailleMorphology.dilate(eroded, se)
|
| 362 |
+
|
| 363 |
+
@staticmethod
|
| 364 |
+
def closing(cell: BrailleCell, se: StructuringElement) -> BrailleCell:
|
| 365 |
+
"""
|
| 366 |
+
Closing: Dilation followed by erosion.
|
| 367 |
+
|
| 368 |
+
φ_B(X) = ε_B(δ_B(X))
|
| 369 |
+
|
| 370 |
+
Effect: Fills small gaps, smooths from inside.
|
| 371 |
+
"""
|
| 372 |
+
dilated = BrailleMorphology.dilate(cell, se)
|
| 373 |
+
return BrailleMorphology.erode(dilated, se)
|
| 374 |
+
|
| 375 |
+
@staticmethod
|
| 376 |
+
def gradient(cell: BrailleCell, se: StructuringElement) -> BrailleCell:
|
| 377 |
+
"""
|
| 378 |
+
Morphological gradient: Dilation - Erosion (via XOR).
|
| 379 |
+
|
| 380 |
+
ρ_B(X) = δ_B(X) - ε_B(X)
|
| 381 |
+
|
| 382 |
+
Effect: Edge detection - dots that differ between dilation and erosion.
|
| 383 |
+
"""
|
| 384 |
+
dilated = BrailleMorphology.dilate(cell, se)
|
| 385 |
+
eroded = BrailleMorphology.erode(cell, se)
|
| 386 |
+
return dilated ^ eroded
|
| 387 |
+
|
| 388 |
+
@staticmethod
|
| 389 |
+
def top_hat(cell: BrailleCell, se: StructuringElement) -> BrailleCell:
|
| 390 |
+
"""
|
| 391 |
+
Top-hat transform: Original - Opening.
|
| 392 |
+
|
| 393 |
+
Effect: Extracts bright details smaller than structuring element.
|
| 394 |
+
"""
|
| 395 |
+
opened = BrailleMorphology.opening(cell, se)
|
| 396 |
+
return cell ^ opened # Difference via XOR
|
| 397 |
+
|
| 398 |
+
@staticmethod
|
| 399 |
+
def black_hat(cell: BrailleCell, se: StructuringElement) -> BrailleCell:
|
| 400 |
+
"""
|
| 401 |
+
Black-hat transform: Closing - Original.
|
| 402 |
+
|
| 403 |
+
Effect: Extracts dark details smaller than structuring element.
|
| 404 |
+
"""
|
| 405 |
+
closed = BrailleMorphology.closing(cell, se)
|
| 406 |
+
return closed ^ cell # Difference via XOR
|
| 407 |
+
|
| 408 |
+
@staticmethod
|
| 409 |
+
def hit_or_miss(cell: BrailleCell,
|
| 410 |
+
foreground: StructuringElement,
|
| 411 |
+
background: StructuringElement) -> bool:
|
| 412 |
+
"""
|
| 413 |
+
Hit-or-miss transform: Pattern matching.
|
| 414 |
+
|
| 415 |
+
Returns True iff:
|
| 416 |
+
- All foreground dots are present in cell
|
| 417 |
+
- All background dots are absent from cell
|
| 418 |
+
|
| 419 |
+
This is the foundation for pattern recognition in Braille.
|
| 420 |
+
"""
|
| 421 |
+
fg_match = (cell & foreground.cell) == foreground.cell
|
| 422 |
+
bg_match = (cell & background.cell).is_bottom
|
| 423 |
+
return fg_match and bg_match


# =============================================================================
# SECTION 3: BRAILLE SEQUENCE MORPHOLOGY
# =============================================================================

@dataclass
class BrailleSequence:
    """
    A sequence of Braille cells with morphological operations.

    This extends single-cell morphology to sequences, enabling
    operations on Braille text/data streams.
    """
    cells: List[BrailleCell] = field(default_factory=list)

    def __len__(self) -> int:
        return len(self.cells)

    def __getitem__(self, idx: int) -> BrailleCell:
        return self.cells[idx]

    def __iter__(self) -> Iterator[BrailleCell]:
        return iter(self.cells)

    @property
    def unicode(self) -> str:
        """Return as Unicode Braille string."""
        return ''.join(c.unicode for c in self.cells)

    @property
    def bytes(self) -> bytes:
        """Return as byte sequence."""
        return bytes(c.value for c in self.cells)

    def apply(self, op: Callable[[BrailleCell], BrailleCell]) -> BrailleSequence:
        """Apply a cell-wise operation to the sequence."""
        return BrailleSequence([op(c) for c in self.cells])

    def apply_morphology(self,
                         operator: MorphologicalOperator,
                         se: StructuringElement) -> BrailleSequence:
        """Apply morphological operation to each cell."""
        ops = {
            MorphologicalOperator.EROSION: BrailleMorphology.erode,
            MorphologicalOperator.DILATION: BrailleMorphology.dilate,
            MorphologicalOperator.OPENING: BrailleMorphology.opening,
            MorphologicalOperator.CLOSING: BrailleMorphology.closing,
            MorphologicalOperator.GRADIENT: BrailleMorphology.gradient,
            MorphologicalOperator.TOP_HAT: BrailleMorphology.top_hat,
            MorphologicalOperator.BLACK_HAT: BrailleMorphology.black_hat,
        }
        op_func = ops[operator]
        return BrailleSequence([op_func(c, se) for c in self.cells])

    def convolve(self, kernel: List[BrailleCell],
                 op: Callable[[BrailleCell, BrailleCell], BrailleCell] = lambda a, b: a & b) -> BrailleSequence:
        """
        Convolve sequence with a kernel using specified operation.

        This enables sliding-window pattern matching and transformation.
        """
        if not kernel:
            return self

        k_len = len(kernel)
        result = []

        for i in range(len(self.cells)):
            # Apply kernel centered at position i
            acc = BrailleCell.bottom()
            for j, k_cell in enumerate(kernel):
                idx = i - k_len // 2 + j
                if 0 <= idx < len(self.cells):
                    acc = acc | op(self.cells[idx], k_cell)
            result.append(acc)

        return BrailleSequence(result)

    def reduce(self,
               op: Callable[[BrailleCell, BrailleCell], BrailleCell] = lambda a, b: a | b) -> BrailleCell:
        """Reduce sequence to single cell using operation."""
        if not self.cells:
            return BrailleCell.bottom()
        return reduce(op, self.cells)

    @classmethod
    def from_unicode(cls, text: str) -> BrailleSequence:
        """Create from Unicode Braille string."""
        cells = []
        for char in text:
            code = ord(char)
            if BRAILLE_BASE <= code <= BRAILLE_MAX:
                cells.append(BrailleCell(code - BRAILLE_BASE))
        return cls(cells)

    @classmethod
    def from_bytes(cls, data: bytes) -> BrailleSequence:
        """Create from byte sequence (direct mapping)."""
        return cls([BrailleCell(b) for b in data])


# =============================================================================
# SECTION 4: MODALITY-INVARIANT BRAILLE REPRESENTATION
# =============================================================================

class Modality(Enum):
    """Supported modalities for Braille encoding."""
    TEXT = auto()
    IMAGE = auto()
    AUDIO = auto()
    BINARY = auto()
    VIDEO = auto()
    SEMANTIC = auto()  # Abstract semantic content


# Modality headers (from braille256-v5)
MODALITY_HEADERS = {
    Modality.TEXT: BrailleCell.from_dots(1, 2, 3, 4, 5, 6, 7, 8),    # ⣿ + ⠁ = ⣿⠁
    Modality.IMAGE: BrailleCell.from_dots(1, 2, 3, 4, 5, 6, 7, 8),   # ⣿ + ⠃ = ⣿⠃
    Modality.AUDIO: BrailleCell.from_dots(1, 2, 3, 4, 5, 6, 7, 8),   # ⣿ + ⠇ = ⣿⠇
    Modality.BINARY: BrailleCell.from_dots(1, 2, 3, 4, 5, 6, 7, 8),  # ⣿ + ⠏ = ⣿⠏
    Modality.VIDEO: BrailleCell.from_dots(1, 2, 3, 4, 5, 6, 7, 8),   # ⣿ + ⠗ = ⣿⠗
}


@dataclass
class ModalityInvariantRepresentation:
    """
    A modality-invariant representation in Braille space.

    Key insight: All modalities can be encoded as byte sequences,
    and all byte sequences map bijectively to 8-dot Braille.
    Therefore, Braille provides a universal representation space.

    The representation consists of:
    1. Modality header (identifies source modality)
    2. Semantic embedding (modality-invariant meaning)
    3. Raw Braille sequence (the actual data)

    Cross-modal operations preserve semantic content while
    allowing modality-specific transformations.
    """
    modality: Modality
    sequence: BrailleSequence
    semantic_embedding: Optional[np.ndarray] = None  # d-dimensional semantic vector
    metadata: Dict = field(default_factory=dict)

    @property
    def header(self) -> BrailleCell:
        """Get modality header cell."""
        return MODALITY_HEADERS.get(self.modality, BrailleCell.top())

    def to_semantic_space(self, encoder: Callable[[BrailleSequence], np.ndarray]) -> np.ndarray:
        """
        Project Braille sequence to semantic embedding space.

        This is where the LLM's learned embeddings come in.
        The encoder maps Braille tokens to semantic vectors.
        """
        if self.semantic_embedding is None:
            self.semantic_embedding = encoder(self.sequence)
        return self.semantic_embedding

    def transform_modality(self,
                           target: Modality,
                           transformer: Callable[[BrailleSequence, Modality, Modality], BrailleSequence]
                           ) -> ModalityInvariantRepresentation:
        """
        Transform to a different modality while preserving semantics.

        The transformer function handles modality-specific conversion
        while the semantic embedding remains invariant.
        """
        new_sequence = transformer(self.sequence, self.modality, target)
        return ModalityInvariantRepresentation(
            modality=target,
            sequence=new_sequence,
            semantic_embedding=self.semantic_embedding,  # Preserved!
            metadata={**self.metadata, 'source_modality': self.modality}
        )


# =============================================================================
# SECTION 5: BRAILLE REASONING LOOPS
# =============================================================================

@dataclass
class ReasoningState:
    """
    State of a Braille reasoning loop.

    The reasoning loop operates entirely in Braille space:
    1. Input: Braille sequence (any modality)
    2. Transform: Apply morphological/semantic operations
    3. Attend: Cross-modal attention in Braille space
    4. Output: Braille sequence (any modality)

    This enables modality-invariant reasoning where the same
    operations work regardless of input/output modality.
    """
    sequence: BrailleSequence
    attention_weights: Optional[np.ndarray] = None
    hidden_state: Optional[np.ndarray] = None
    step: int = 0

    def apply_attention(self,
                        query: BrailleSequence,
                        key: BrailleSequence,
                        value: BrailleSequence) -> BrailleSequence:
        """
        Cross-modal attention in Braille space.

        Attention is computed on the lattice structure:
        - Query, Key, Value are all Braille sequences
        - Similarity is measured via lattice distance
        - Output is weighted combination in Braille space

        Lattice distance: d(a, b) = |a ⊕ b| (Hamming distance)
        """
        if len(query) == 0 or len(key) == 0:
            return value

        # Compute attention scores based on lattice similarity
        scores = np.zeros((len(query), len(key)))
        for i, q in enumerate(query):
            for j, k in enumerate(key):
                # Similarity = 8 - Hamming distance (higher = more similar)
                diff = q ^ k
                scores[i, j] = 8 - diff.cardinality

        # Softmax normalization
        scores = np.exp(scores - scores.max(axis=1, keepdims=True))
        self.attention_weights = scores / scores.sum(axis=1, keepdims=True)

        # Weighted combination of values
        result = []
        for i in range(len(query)):
            # Combine values weighted by attention
            combined = BrailleCell.bottom()
            for j, v in enumerate(value):
                if self.attention_weights[i, j] > 0.1:  # Threshold
                    combined = combined | v
            result.append(combined)

        return BrailleSequence(result)
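The score computation above is 8 minus the Hamming distance, followed by a row-wise softmax. A standalone NumPy sketch of just the weight matrix, operating on plain 8-bit dot masks (an illustrative reduction, not the ReasoningState API):

```python
import numpy as np

def lattice_attention_weights(query, key):
    """Row-softmax of (8 - Hamming distance) between 8-bit dot masks."""
    q = np.asarray(query)[:, None]                               # shape (nq, 1)
    k = np.asarray(key)[None, :]                                 # shape (1, nk)
    hamming = np.vectorize(lambda x: bin(x).count("1"))(q ^ k)   # popcount of XOR
    scores = 8 - hamming
    e = np.exp(scores - scores.max(axis=1, keepdims=True))       # stable softmax
    return e / e.sum(axis=1, keepdims=True)

w = lattice_attention_weights([0b00001011], [0b00001011, 0b11110100])
assert w.shape == (1, 2)
assert abs(w.sum() - 1.0) < 1e-9
assert w[0, 0] > w[0, 1]  # the identical cell gets the larger weight
```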


class BrailleReasoningLoop:
    """
    A modality-invariant reasoning loop operating in Braille space.

    The loop implements the following cycle:

    1. ENCODE: Any modality → Braille sequence
    2. TRANSFORM: Morphological operations on Braille
    3. ATTEND: Cross-sequence attention in lattice space
    4. REASON: Apply learned transformations (LLM layers)
    5. DECODE: Braille sequence → Any modality

    Key property: Steps 2-4 are MODALITY-INVARIANT.
    The same operations work for text, images, audio, etc.
    """

    def __init__(self,
                 hidden_dim: int = 256,
                 num_heads: int = 8,
                 morphology_se: StructuringElement = None):
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.morphology_se = morphology_se or StructuringElement.six_dot()
        self.state = None

    def encode(self,
               data: bytes,
               modality: Modality) -> ModalityInvariantRepresentation:
        """
        Encode any modality to Braille representation.

        This is the entry point: raw bytes → Braille sequence.
        The modality header is prepended for downstream processing.
        """
        sequence = BrailleSequence.from_bytes(data)
        return ModalityInvariantRepresentation(
            modality=modality,
            sequence=sequence
        )

    def transform(self,
                  rep: ModalityInvariantRepresentation,
                  operator: MorphologicalOperator = MorphologicalOperator.OPENING
                  ) -> ModalityInvariantRepresentation:
        """
        Apply morphological transformation.

        This step is modality-invariant: the same operation
        works regardless of whether the input is text, image, etc.
        """
        transformed = rep.sequence.apply_morphology(operator, self.morphology_se)
        return ModalityInvariantRepresentation(
            modality=rep.modality,
            sequence=transformed,
            semantic_embedding=rep.semantic_embedding,
            metadata=rep.metadata
        )

    def attend(self,
               query_rep: ModalityInvariantRepresentation,
               context_rep: ModalityInvariantRepresentation
               ) -> ModalityInvariantRepresentation:
        """
        Cross-modal attention between two representations.

        This enables reasoning across modalities:
        - Text attending to image
        - Audio attending to text
        - Any modality attending to any other

        The attention operates in Braille lattice space.
        """
        if self.state is None:
            self.state = ReasoningState(sequence=query_rep.sequence)

        attended = self.state.apply_attention(
            query=query_rep.sequence,
            key=context_rep.sequence,
            value=context_rep.sequence
        )

        return ModalityInvariantRepresentation(
            modality=query_rep.modality,
            sequence=attended,
            semantic_embedding=query_rep.semantic_embedding,
            metadata={**query_rep.metadata, 'attended_modality': context_rep.modality}
        )

    def reason(self,
               rep: ModalityInvariantRepresentation,
               transform_fn: Callable[[BrailleSequence], BrailleSequence] = None
               ) -> ModalityInvariantRepresentation:
        """
        Apply learned reasoning transformation.

        In a full LLM, this would be the transformer layers.
        Here we provide a hook for custom transformations.

        The key insight: reasoning happens in Braille space,
        making it inherently modality-invariant.
        """
        if transform_fn is None:
            # Default: identity with morphological smoothing
            transform_fn = lambda seq: seq.apply_morphology(
                MorphologicalOperator.CLOSING,
                self.morphology_se
            )

        reasoned = transform_fn(rep.sequence)

        return ModalityInvariantRepresentation(
            modality=rep.modality,
            sequence=reasoned,
            semantic_embedding=rep.semantic_embedding,
            metadata=rep.metadata
        )

    def decode(self,
               rep: ModalityInvariantRepresentation,
               target_modality: Modality = None
               ) -> bytes:
        """
        Decode Braille representation to bytes.

        This is the exit point: Braille sequence → raw bytes.
        The target modality determines any post-processing.
        """
        return rep.sequence.bytes

    def full_loop(self,
                  input_data: bytes,
                  input_modality: Modality,
                  context_data: bytes = None,
                  context_modality: Modality = None,
                  output_modality: Modality = None
                  ) -> bytes:
        """
        Execute a complete reasoning loop.

        Input → Encode → Transform → Attend → Reason → Decode → Output

        All intermediate steps are modality-invariant.
        """
        # Encode input
        rep = self.encode(input_data, input_modality)

        # Transform
        rep = self.transform(rep)

        # Attend to context if provided
        if context_data is not None:
            context_rep = self.encode(
                context_data,
                context_modality or input_modality
            )
            rep = self.attend(rep, context_rep)

        # Reason
        rep = self.reason(rep)

        # Decode
        return self.decode(rep, output_modality or input_modality)


# =============================================================================
# SECTION 6: LATTICE DISTANCE METRICS
# =============================================================================

class BrailleMetrics:
    """
    Distance and similarity metrics on the Braille lattice.

    These metrics enable:
    - Semantic similarity measurement
    - Clustering in Braille space
    - Loss functions for training
    """

    @staticmethod
    def hamming_distance(a: BrailleCell, b: BrailleCell) -> int:
        """
        Hamming distance: number of differing dots.

        d_H(a, b) = |a ⊕ b| = popcount(a XOR b)

        Range: [0, 8]
        """
        return (a ^ b).cardinality

    @staticmethod
    def jaccard_similarity(a: BrailleCell, b: BrailleCell) -> float:
        """
        Jaccard similarity: intersection over union.

        J(a, b) = |a ∧ b| / |a ∨ b|

        Range: [0, 1]
        """
        intersection = (a & b).cardinality
        union = (a | b).cardinality
        if union == 0:
            return 1.0  # Both empty
        return intersection / union

    @staticmethod
    def lattice_distance(a: BrailleCell, b: BrailleCell) -> int:
        """
        Lattice distance: length of shortest path in Hasse diagram.

        For Boolean lattice: d_L(a, b) = |a ⊕ b| (same as Hamming)
        """
        return BrailleMetrics.hamming_distance(a, b)
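Both metrics above reduce to cheap bit operations. A standalone sketch on 8-bit dot masks (hypothetical helpers mirroring, but not importing, BrailleMetrics):

```python
def hamming(a: int, b: int) -> int:
    """Number of differing dots: popcount of XOR."""
    return bin(a ^ b).count("1")

def jaccard(a: int, b: int) -> float:
    """Intersection over union of set dots; both-empty defined as 1.0."""
    union = bin(a | b).count("1")
    return 1.0 if union == 0 else bin(a & b).count("1") / union

a, b = 0b00001011, 0b00011010        # dots {1, 2, 4} and {2, 4, 5}
assert hamming(a, b) == 2            # dots 1 and 5 differ
assert abs(jaccard(a, b) - 0.5) < 1e-9   # |{2, 4}| / |{1, 2, 4, 5}|
```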

    @staticmethod
    def semantic_distance(a: BrailleCell, b: BrailleCell,
                          embeddings: Dict[int, np.ndarray] = None) -> float:
        """
        Semantic distance using learned embeddings.

        If embeddings are provided, uses cosine distance in embedding space.
        Otherwise, falls back to normalized Hamming distance.
        """
        if embeddings is not None and a.value in embeddings and b.value in embeddings:
            vec_a = embeddings[a.value]
            vec_b = embeddings[b.value]
            cos_sim = np.dot(vec_a, vec_b) / (np.linalg.norm(vec_a) * np.linalg.norm(vec_b))
            return 1.0 - cos_sim
        else:
            return BrailleMetrics.hamming_distance(a, b) / 8.0

    @staticmethod
    def sequence_distance(a: BrailleSequence, b: BrailleSequence,
                          cell_metric: Callable[[BrailleCell, BrailleCell], float] = None
                          ) -> float:
        """
        Distance between two Braille sequences.

        Uses dynamic time warping or simple alignment depending on lengths.
        """
        if cell_metric is None:
            cell_metric = lambda x, y: BrailleMetrics.hamming_distance(x, y) / 8.0

        if len(a) == 0 and len(b) == 0:
            return 0.0
        if len(a) == 0 or len(b) == 0:
            return 1.0

        # Simple aligned distance for equal lengths
        if len(a) == len(b):
            total = sum(cell_metric(a[i], b[i]) for i in range(len(a)))
            return total / len(a)

        # DTW for unequal lengths
        n, m = len(a), len(b)
        dtw = np.full((n + 1, m + 1), np.inf)
        dtw[0, 0] = 0

        for i in range(1, n + 1):
            for j in range(1, m + 1):
                cost = cell_metric(a[i-1], b[j-1])
                dtw[i, j] = cost + min(dtw[i-1, j], dtw[i, j-1], dtw[i-1, j-1])

        return dtw[n, m] / max(n, m)
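For unequal-length sequences, the code above fills a standard DTW table with a normalized-Hamming cell cost. A self-contained sketch of the same recurrence, with plain ints standing in for BrailleCell:

```python
import numpy as np

def dtw_distance(a, b):
    """DTW over 8-bit dot masks with normalized Hamming cell cost, scaled by max length."""
    cost = lambda x, y: bin(x ^ y).count("1") / 8.0
    n, m = len(a), len(b)
    dtw = np.full((n + 1, m + 1), np.inf)
    dtw[0, 0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            # Match, insertion, or deletion - take the cheapest predecessor
            dtw[i, j] = cost(a[i - 1], b[j - 1]) + min(dtw[i - 1, j], dtw[i, j - 1], dtw[i - 1, j - 1])
    return dtw[n, m] / max(n, m)

assert dtw_distance([3, 5], [3, 3, 5]) == 0.0   # a repeated cell aligns at zero cost
assert dtw_distance([0], [255]) == 1.0          # all 8 dots differ
```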


# =============================================================================
# SECTION 7: DEMONSTRATION AND TESTING
# =============================================================================

def demonstrate_lattice_operations():
    """Demonstrate the Braille lattice operations."""
    print("=" * 60)
    print("BRAILLE DOT-LATTICE THEORY DEMONSTRATION")
    print("=" * 60)

    # Create some cells
    cell_a = BrailleCell.from_dots(1, 2, 4)  # ⠋
    cell_b = BrailleCell.from_dots(2, 4, 5)  # ⠚

    print(f"\n1. BASIC LATTICE OPERATIONS")
    print(f"   Cell A: {cell_a}")
    print(f"   Cell B: {cell_b}")
    print(f"   A ∧ B (meet): {cell_a & cell_b}")
    print(f"   A ∨ B (join): {cell_a | cell_b}")
    print(f"   ¬A (complement): {~cell_a}")
    print(f"   A ⊕ B (xor): {cell_a ^ cell_b}")
    print(f"   A ≤ B: {cell_a <= cell_b}")

    # Morphological operations
    print(f"\n2. MORPHOLOGICAL OPERATIONS")
    se = StructuringElement.column_left()
    print(f"   Structuring element: {se.name} = dots {se.dots}")
    print(f"   Erosion(A, SE): {BrailleMorphology.erode(cell_a, se)}")
    print(f"   Dilation(A, SE): {BrailleMorphology.dilate(cell_a, se)}")
    print(f"   Opening(A, SE): {BrailleMorphology.opening(cell_a, se)}")
    print(f"   Closing(A, SE): {BrailleMorphology.closing(cell_a, se)}")
    print(f"   Gradient(A, SE): {BrailleMorphology.gradient(cell_a, se)}")

    # Sequence operations
    print(f"\n3. SEQUENCE OPERATIONS")
    text = "Hello"
    seq = BrailleSequence.from_bytes(text.encode())
    print(f"   Input text: '{text}'")
    print(f"   As Braille: {seq.unicode}")
    print(f"   Dilated: {seq.apply_morphology(MorphologicalOperator.DILATION, se).unicode}")
    print(f"   Eroded: {seq.apply_morphology(MorphologicalOperator.EROSION, se).unicode}")

    # Distance metrics
    print(f"\n4. LATTICE METRICS")
    print(f"   Hamming(A, B): {BrailleMetrics.hamming_distance(cell_a, cell_b)}")
    print(f"   Jaccard(A, B): {BrailleMetrics.jaccard_similarity(cell_a, cell_b):.3f}")
    print(f"   Lattice(A, B): {BrailleMetrics.lattice_distance(cell_a, cell_b)}")

    # Modality-invariant reasoning
    print(f"\n5. MODALITY-INVARIANT REASONING LOOP")
    loop = BrailleReasoningLoop()

    input_text = b"Test input"
    context = b"Context data"

    output = loop.full_loop(
        input_data=input_text,
        input_modality=Modality.TEXT,
        context_data=context,
        context_modality=Modality.TEXT
    )

    print(f"   Input: {input_text}")
    print(f"   Context: {context}")
    print(f"   Output: {output}")
    print(f"   (Output differs due to morphological transformations)")

    print("\n" + "=" * 60)
    print("THEORETICAL FRAMEWORK COMPLETE")
    print("=" * 60)


if __name__ == "__main__":
    demonstrate_lattice_operations()
config.json
ADDED
@@ -0,0 +1,18 @@
{
    "vocab_size": 32000,
    "hidden_size": 256,
    "num_layers": 4,
    "num_heads": 4,
    "intermediate_size": 1024,
    "max_position_embeddings": 512,
    "dropout": 0.1,
    "use_lattice_attention": true,
    "lattice_attention_weight": 0.4,
    "use_morphological_regularization": true,
    "morphological_weight": 0.005000000000000001,
    "use_lattice_embeddings": true,
    "structuring_element": "six_dot",
    "embedding_dropout": 0.15,
    "modality_embedding_dim": 32,
    "num_modalities": 5
}
final_eval.json
ADDED
@@ -0,0 +1,5 @@
{
    "lattice_coherence": 0.7428203609814639,
    "morphological_stability": 0.4530333221703768,
    "haptic_score": 0.5979268415759204
}
|
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7721b549cb7724d1f110f0221b98044a1472b95ab01ac4a586dd2ae2cbfa0704
size 47273908
|
tokenizer.model
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ec5b8b6fbd8985a97c74d377a83f58ff59f3860d02a343eb15146da467da40ae
size 1155082
|
train_lattice_v6.py
ADDED
@@ -0,0 +1,990 @@
#!/usr/bin/env python3
"""
Train braille256-v6: Lattice-Aware Multimodal Braille Model

This is the first LLM with explicit dot-lattice structure in its architecture:
1. Lattice-aware attention (Hamming-based similarity)
2. Morphological regularization (erosion/dilation as inductive bias)
3. Lattice-structured embeddings (respecting Boolean algebra)
4. Modality-invariant reasoning loops

Building on v5's multimodal foundation, v6 integrates the formal theory
from braille_lattice_theory.py into the training pipeline.

Author: Ryan Barrett & Cascade
Date: December 2024
"""

import os
import sys
import json
import math
import logging
import argparse
from dataclasses import dataclass, field
from typing import Optional, List, Tuple, Dict
from enum import Enum

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import numpy as np

import sentencepiece as spm

# Import lattice theory
from braille_lattice_theory import (
    BrailleCell, BrailleMorphology, BrailleSequence,
    StructuringElement, MorphologicalOperator, BrailleMetrics,
    BRAILLE_BASE, BRAILLE_MAX
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# =============================================================================
# Configuration
# =============================================================================

@dataclass
class LatticeConfig:
    """Configuration for the lattice-aware model."""
    # Model architecture
    vocab_size: int = 32000
    hidden_size: int = 256
    num_layers: int = 4
    num_heads: int = 4
    intermediate_size: int = 1024
    max_position_embeddings: int = 512
    dropout: float = 0.1

    # Lattice-specific settings
    use_lattice_attention: bool = True
    lattice_attention_weight: float = 0.4  # Blend with standard attention (increased)
    use_morphological_regularization: bool = True
    morphological_weight: float = 0.05  # Regularization strength (increased 5x)
    use_lattice_embeddings: bool = True
    structuring_element: str = "six_dot"  # Which structuring element to use
    embedding_dropout: float = 0.15  # Dropout on embeddings to prevent overfitting

    # Modality settings
    modality_embedding_dim: int = 32
    num_modalities: int = 5  # TEXT, IMAGE, AUDIO, BINARY, VIDEO

    def to_dict(self):
        return dict(self.__dict__)

    @classmethod
    def from_dict(cls, d):
        return cls(**{k: v for k, v in d.items() if k in cls.__dataclass_fields__})


# =============================================================================
# Lattice-Aware Attention
# =============================================================================

class LatticeAttention(nn.Module):
    """
    Attention mechanism that incorporates Braille lattice structure.

    Key innovation: combines standard softmax attention with lattice-based
    similarity computed via Hamming distance on the underlying Braille cells.

    For tokens that map to Braille cells, we compute:
        lattice_sim(a, b) = 8 - popcount(a XOR b)

    This is then blended with standard QK^T attention.
    """

    def __init__(self, config: LatticeConfig):
        super().__init__()
        self.config = config
        self.num_heads = config.num_heads
        self.head_dim = config.hidden_size // config.num_heads
        self.lattice_weight = config.lattice_attention_weight

        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.out_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.dropout)

        # Learnable lattice attention temperature
        self.lattice_temperature = nn.Parameter(torch.ones(1))

        # Precompute Hamming distance matrix for all 256 Braille cells
        self._precompute_hamming_matrix()

    def _precompute_hamming_matrix(self):
        """Precompute pairwise Hamming distances for efficiency."""
        hamming = torch.zeros(256, 256)
        for i in range(256):
            for j in range(256):
                # XOR the two cells and count the differing bits
                hamming[i, j] = bin(i ^ j).count('1')

        # Convert to similarity: 8 - hamming (range [0, 8])
        self.register_buffer('lattice_similarity', 8 - hamming)

    def _get_braille_values(self, token_ids: torch.Tensor) -> torch.Tensor:
        """
        Extract Braille cell values from token IDs (vectorized).

        Tokens that decode to Braille characters keep their cell value;
        all others become -1 and are masked out of the lattice attention.
        """
        # Vectorized: tokens < 256 are treated as Braille cells
        return torch.where(
            token_ids < 256,
            token_ids,
            torch.full_like(token_ids, -1)
        )

    def compute_lattice_attention(self, braille_values: torch.Tensor) -> torch.Tensor:
        """
        Compute attention scores based on lattice similarity (fully vectorized).

        Returns attention logits of shape (B, T, T).
        """
        B, T = braille_values.shape

        # Mask for valid Braille values
        valid_mask = (braille_values >= 0).float()

        # Clamp to valid range for indexing
        safe_values = braille_values.clamp(0, 255).long()

        # Vectorized lookup via advanced indexing:
        # lattice_similarity is (256, 256); select one similarity row per token
        sim_rows = self.lattice_similarity[safe_values.view(-1)]  # (B*T, 256)
        sim_rows = sim_rows.view(B, T, 256)

        # Column indices: indices[b, i, j] = safe_values[b, j]
        indices = safe_values.unsqueeze(1).expand(B, T, T)

        # Gather: lattice_attn[b, i, j] = sim_rows[b, i, safe_values[b, j]]
        # (an earlier version transposed the index tensor, which collapsed the
        # scores to sim(token_i, token_i); gathering directly is correct)
        lattice_attn = torch.gather(sim_rows, 2, indices)

        # Apply temperature (clamped so a learned non-positive value cannot flip signs)
        lattice_attn = lattice_attn / self.lattice_temperature.clamp(min=1e-6)

        # Mask invalid positions
        valid_2d = valid_mask.unsqueeze(2) * valid_mask.unsqueeze(1)  # (B, T, T)
        lattice_attn = lattice_attn * valid_2d

        return lattice_attn

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None,
                token_ids: torch.Tensor = None) -> torch.Tensor:
        B, T, C = x.shape

        # Standard multi-head projections
        q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        # Standard QK^T attention
        standard_attn = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Lattice attention (if enabled and token_ids provided)
        if self.config.use_lattice_attention and token_ids is not None:
            braille_values = self._get_braille_values(token_ids)
            lattice_attn = self.compute_lattice_attention(braille_values)

            # Expand for heads: (B, T, T) -> (B, num_heads, T, T)
            lattice_attn = lattice_attn.unsqueeze(1).expand(-1, self.num_heads, -1, -1)

            # Blend standard and lattice attention
            attn = (1 - self.lattice_weight) * standard_attn + self.lattice_weight * lattice_attn
        else:
            attn = standard_attn

        # Apply causal mask
        if mask is not None:
            attn = attn.masked_fill(mask == 0, float('-inf'))

        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        out = (attn @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(out)


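The gather-based lookup above reduces, per batch element, to indexing a 256×256 similarity table with the row and column cell values. A minimal pure-Python sketch of that logic (illustrative helper names, not part of the training script):

```python
def lattice_similarity_table():
    """sim[a][b] = 8 - popcount(a XOR b), for all 256 x 256 cell pairs."""
    return [[8 - bin(a ^ b).count('1') for b in range(256)] for a in range(256)]

def pairwise_lattice_scores(cells):
    """For a cell sequence, score[i][j] = sim(cells[i], cells[j])."""
    sim = lattice_similarity_table()
    return [[sim[a][b] for b in cells] for a in cells]
```

Identical cells score 8, complementary cells score 0, and the tensorized `gather` in `compute_lattice_attention` produces exactly this pairwise matrix per batch element.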
# =============================================================================
# Lattice-Aware Embeddings
# =============================================================================

class LatticeEmbedding(nn.Module):
    """
    Token embeddings that respect Braille lattice structure.

    Key insight: initialize embeddings so that similar Braille cells
    (low Hamming distance) have similar embeddings.

    This provides an inductive bias that helps the model learn
    patterns in the lattice structure.
    """

    def __init__(self, config: LatticeConfig):
        super().__init__()
        self.config = config

        # Standard embedding
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)

        # Lattice structure embedding (for the first 256 tokens = Braille cells)
        self.lattice_embedding = nn.Embedding(256, config.hidden_size)

        # Initialize lattice embeddings with structure
        self._init_lattice_structure()

        # Learnable blend weight
        self.lattice_blend = nn.Parameter(torch.tensor(0.1))

    def _init_lattice_structure(self):
        """Initialize embeddings to reflect lattice structure."""
        with torch.no_grad():
            # Each Braille cell is an 8-bit vector; map each bit to a
            # learned direction in embedding space.

            # Create 8 basis vectors (one per dot)
            basis = torch.randn(8, self.config.hidden_size) * 0.1

            for i in range(256):
                # Embedding is the sum of basis vectors for raised dots
                emb = torch.zeros(self.config.hidden_size)
                for b in range(8):
                    if (i >> b) & 1:
                        emb += basis[b]

                self.lattice_embedding.weight[i] = emb

    def forward(self, token_ids: torch.Tensor, training: bool = True) -> torch.Tensor:
        # Standard embedding
        std_emb = self.embedding(token_ids)

        if self.config.use_lattice_embeddings:
            # For tokens < 256, blend in the lattice embedding
            mask = (token_ids < 256).float().unsqueeze(-1)
            safe_ids = token_ids.clamp(0, 255)
            lat_emb = self.lattice_embedding(safe_ids)

            # Blend: standard + lattice_blend * lattice (Braille tokens only)
            std_emb = std_emb + mask * self.lattice_blend * lat_emb

        # Apply embedding dropout during training to prevent overfitting
        if training and self.config.embedding_dropout > 0:
            std_emb = F.dropout(std_emb, p=self.config.embedding_dropout, training=True)

        return std_emb


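The structured init has a useful consequence: because each cell's embedding is the sum of the basis vectors of its raised dots, two cells differing in k dots differ by exactly k basis vectors, so embedding distance tracks Hamming distance at initialization. A pure-Python sketch under that assumption (illustrative names, tiny hidden size):

```python
import random

def init_lattice_embeddings(hidden=4, seed=0):
    """Build the 256-cell table: embedding(cell) = sum of basis[b] over raised dots b."""
    rng = random.Random(seed)
    basis = [[rng.gauss(0, 0.1) for _ in range(hidden)] for _ in range(8)]
    table = []
    for cell in range(256):
        emb = [0.0] * hidden
        for b in range(8):
            if (cell >> b) & 1:
                emb = [e + v for e, v in zip(emb, basis[b])]
        table.append(emb)
    return table, basis
```

For example, cell 0b11 (dots 1 and 2) lands exactly at `basis[0] + basis[1]`, one basis vector away from cell 0b01.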
# =============================================================================
# Morphological Regularization
# =============================================================================

class MorphologicalRegularizer(nn.Module):
    """
    Regularization based on morphological operations.

    Encourages the model to learn representations that are
    consistent under morphological transformations (erosion, dilation).

    Ideal loss: ||f(erode(x)) - erode(f(x))||^2 + ||f(dilate(x)) - dilate(f(x))||^2

    This is a form of equivariance regularization.
    """

    def __init__(self, config: LatticeConfig):
        super().__init__()
        self.config = config

        # Get structuring element
        se_map = {
            'six_dot': StructuringElement.six_dot(),
            'column_left': StructuringElement.column_left(),
            'column_right': StructuringElement.column_right(),
            'full': StructuringElement.full(),
        }
        self.se = se_map.get(config.structuring_element, StructuringElement.six_dot())
        self.se_value = self.se.cell.value

    def apply_morphology(self, token_ids: torch.Tensor,
                         op: str = 'erode') -> torch.Tensor:
        """Apply a morphological operation to token IDs."""
        result = token_ids.clone()
        mask = token_ids < 256  # Only apply to Braille tokens

        if op == 'erode':
            # Erosion: AND with the structuring element
            result[mask] = token_ids[mask] & self.se_value
        elif op == 'dilate':
            # Dilation: OR with the structuring element
            result[mask] = (token_ids[mask] | self.se_value) & 0xFF

        return result

    def compute_loss(self, embeddings: torch.Tensor,
                     token_ids: torch.Tensor,
                     embedding_layer: nn.Module) -> torch.Tensor:
        """
        Compute the morphological equivariance loss.

        We want: embed(morph(x)) ≈ morph_embed(embed(x))

        Since we cannot apply morphology directly to embeddings, we use a
        proxy: embeddings of morphologically related tokens should be similar,
        with eroded neighbors slightly closer than dilated ones.
        """
        if not self.config.use_morphological_regularization:
            return torch.tensor(0.0, device=embeddings.device)

        # Get eroded and dilated token IDs
        eroded_ids = self.apply_morphology(token_ids, 'erode')
        dilated_ids = self.apply_morphology(token_ids, 'dilate')

        # Get their embeddings
        eroded_emb = embedding_layer(eroded_ids)
        dilated_emb = embedding_layer(dilated_ids)

        dist_to_eroded = F.mse_loss(embeddings, eroded_emb)
        dist_to_dilated = F.mse_loss(embeddings, dilated_emb)

        # Always-on margin term: push the eroded neighbor to be closer than
        # the dilated one by a margin of 0.1 (squared for smooth gradients)
        margin_loss = (dist_to_eroded - dist_to_dilated + 0.1).pow(2)

        # Coherence term: embeddings should stay close to both morphological neighbors
        coherence_loss = dist_to_eroded + dist_to_dilated

        loss = margin_loss + 0.1 * coherence_loss

        return loss * self.config.morphological_weight


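On cell values the two operations are plain bit masks, which makes the lattice ordering easy to check: erosion can only clear dots and dilation can only raise them, so `erode(c) ⊑ c ⊑ dilate(c)` in the bitwise-subset order. A pure-Python sketch, assuming the "six_dot" structuring element corresponds to the lower six bits (0x3F) — an assumption mirroring the six-dot cell, since `StructuringElement.six_dot()` is defined elsewhere:

```python
SIX_DOT_SE = 0x3F  # assumed six-dot structuring element: dots 1-6 raised

def erode_cell(cell, se=SIX_DOT_SE):
    """Erosion on a single cell: AND with the structuring element."""
    return cell & se

def dilate_cell(cell, se=SIX_DOT_SE):
    """Dilation on a single cell: OR with the structuring element."""
    return (cell | se) & 0xFF

def is_subset(a, b):
    """True when every raised dot of a is also raised in b."""
    return a & b == a
```

This ordering (eroded below, dilated above) is exactly what the margin term in `compute_loss` asks the embedding space to reflect.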
# =============================================================================
# Modality Embedding
# =============================================================================

class ModalityEmbedding(nn.Module):
    """
    Embeddings for different modalities.

    Adds a learned embedding based on the detected modality
    of each token sequence.
    """

    # Modality header tokens (from v5)
    MODALITY_HEADERS = {
        'TEXT': (0xFF, 0x01),    # ⣿⠁
        'IMAGE': (0xFF, 0x03),   # ⣿⠃
        'AUDIO': (0xFF, 0x07),   # ⣿⠇
        'BINARY': (0xFF, 0x0F),  # ⣿⠏
        'VIDEO': (0xFF, 0x17),   # ⣿⠗
    }

    def __init__(self, config: LatticeConfig):
        super().__init__()
        self.embedding = nn.Embedding(config.num_modalities, config.hidden_size)

    def detect_modality(self, token_ids: torch.Tensor) -> torch.Tensor:
        """Detect modality from the token sequence (simplified)."""
        # For now, return 0 (TEXT) for all sequences; real detection would
        # need the tokenizer to decode the header cells.
        return torch.zeros(token_ids.shape[0], dtype=torch.long, device=token_ids.device)

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        modality_ids = self.detect_modality(token_ids)
        return self.embedding(modality_ids).unsqueeze(1)  # (B, 1, H)


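`detect_modality` is stubbed to TEXT, but the v5 header convention above already determines what a full implementation would do: read the first two cell values and map them to a modality index. A pure-Python sketch under the assumption that modalities are indexed TEXT=0 through VIDEO=4 (the index order is hypothetical, chosen only to match `num_modalities = 5`):

```python
# Header pairs from the v5 convention; the 0-4 index assignment is an assumption.
HEADER_TO_INDEX = {
    (0xFF, 0x01): 0,  # TEXT
    (0xFF, 0x03): 1,  # IMAGE
    (0xFF, 0x07): 2,  # AUDIO
    (0xFF, 0x0F): 3,  # BINARY
    (0xFF, 0x17): 4,  # VIDEO
}

def detect_modality_from_cells(cells):
    """Return the modality index for a cell sequence, defaulting to TEXT (0)."""
    if len(cells) >= 2:
        return HEADER_TO_INDEX.get((cells[0], cells[1]), 0)
    return 0
```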
# =============================================================================
# Full Model
# =============================================================================

class FeedForward(nn.Module):
    def __init__(self, config: LatticeConfig):
        super().__init__()
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        return self.fc2(self.dropout(F.gelu(self.fc1(x))))


class LatticeTransformerBlock(nn.Module):
    """Transformer block with lattice-aware attention."""

    def __init__(self, config: LatticeConfig):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.hidden_size)
        self.attn = LatticeAttention(config)
        self.ln2 = nn.LayerNorm(config.hidden_size)
        self.ff = FeedForward(config)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None,
                token_ids: torch.Tensor = None) -> torch.Tensor:
        x = x + self.dropout(self.attn(self.ln1(x), mask, token_ids))
        x = x + self.dropout(self.ff(self.ln2(x)))
        return x


class Braille256LatticeModel(nn.Module):
    """
    braille256-v6: Lattice-Aware Multimodal Braille Model

    Key innovations over v5:
    1. LatticeAttention: Hamming-based similarity in attention
    2. LatticeEmbedding: Structure-aware token embeddings
    3. MorphologicalRegularizer: Equivariance regularization
    4. ModalityEmbedding: Explicit modality awareness
    """

    def __init__(self, config: LatticeConfig):
        super().__init__()
        self.config = config

        # Lattice-aware embeddings
        self.token_embedding = LatticeEmbedding(config)
        self.position_embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.modality_embedding = ModalityEmbedding(config)
        self.dropout = nn.Dropout(config.dropout)

        # Transformer layers with lattice attention
        self.layers = nn.ModuleList([
            LatticeTransformerBlock(config) for _ in range(config.num_layers)
        ])

        self.ln_f = nn.LayerNorm(config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Morphological regularizer
        self.morph_regularizer = MorphologicalRegularizer(config)

        # Weight tying
        self.lm_head.weight = self.token_embedding.embedding.weight

        self.apply(self._init_weights)
        # apply() above re-initializes every nn.Embedding, which would wipe the
        # structured init of the lattice embedding -- restore it afterwards
        self.token_embedding._init_lattice_structure()

        # Log architecture
        total_params = sum(p.numel() for p in self.parameters())
        logger.info(f"Braille256-v6 Lattice Model: {total_params:,} parameters")
        logger.info(f"  Lattice attention: {config.use_lattice_attention}")
        logger.info(f"  Lattice embeddings: {config.use_lattice_embeddings}")
        logger.info(f"  Morphological regularization: {config.use_morphological_regularization}")

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, input_ids: torch.Tensor,
                labels: torch.Tensor = None) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        B, T = input_ids.shape

        # Embeddings
        positions = torch.arange(T, device=input_ids.device).unsqueeze(0)
        tok_emb = self.token_embedding(input_ids, training=self.training)
        pos_emb = self.position_embedding(positions)
        mod_emb = self.modality_embedding(input_ids)

        x = tok_emb + pos_emb + mod_emb
        x = self.dropout(x)

        # Causal mask, shape (1, 1, T, T) for broadcasting over batch and heads
        mask = torch.tril(torch.ones(T, T, device=input_ids.device)).unsqueeze(0).unsqueeze(0)

        # Transformer layers
        for layer in self.layers:
            x = layer(x, mask, input_ids)

        x = self.ln_f(x)
        logits = self.lm_head(x)

        # Compute losses
        lm_loss = None
        morph_loss = None

        if labels is not None:
            lm_loss = F.cross_entropy(
                logits.view(-1, self.config.vocab_size),
                labels.view(-1),
                ignore_index=-100
            )

            # Morphological regularization
            morph_loss = self.morph_regularizer.compute_loss(
                tok_emb, input_ids, self.token_embedding
            )

        return logits, lm_loss, morph_loss

    def generate(self, input_ids: torch.Tensor, max_length: int = 100,
                 temperature: float = 1.0, top_k: int = 50) -> torch.Tensor:
        self.eval()
        with torch.no_grad():
            for _ in range(max_length):
                if input_ids.shape[1] >= self.config.max_position_embeddings:
                    break

                logits, _, _ = self(input_ids)
                logits = logits[:, -1, :] / temperature

                if top_k > 0:
                    v, _ = torch.topk(logits, top_k)
                    logits[logits < v[:, [-1]]] = float('-inf')

                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                input_ids = torch.cat([input_ids, next_token], dim=1)

        return input_ids


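The top-k step inside `generate` masks every logit below the k-th largest to negative infinity before the softmax, so sampling is restricted to the k most likely tokens. A pure-Python sketch of that filter (illustrative helper, equivalent in spirit to the `torch.topk` masking above):

```python
def top_k_filter(logits, k):
    """Keep the k largest logits; set the rest to -inf (ties at the cutoff kept)."""
    if k <= 0 or k >= len(logits):
        return list(logits)
    threshold = sorted(logits, reverse=True)[k - 1]
    return [x if x >= threshold else float('-inf') for x in logits]
```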
# =============================================================================
# Dataset (same as v5)
# =============================================================================

class MultimodalBrailleDataset(Dataset):
    def __init__(self, corpus_path: str, tokenizer_path: str,
                 max_length: int = 512, max_tokens: int = 10_000_000):
        self.max_length = max_length

        self.sp = spm.SentencePieceProcessor()
        self.sp.load(tokenizer_path)

        logger.info(f"Loading corpus from {corpus_path}...")
        with open(corpus_path, 'r', encoding='utf-8') as f:
            text = f.read()

        if len(text) > max_tokens * 3:
            logger.info(f"Limiting corpus from {len(text):,} characters to ~{max_tokens:,} tokens' worth")
            text = text[:max_tokens * 3]

        logger.info(f"Tokenizing {len(text):,} characters...")
        self.tokens = self.sp.encode(text)
        if len(self.tokens) > max_tokens:
            self.tokens = self.tokens[:max_tokens]
        logger.info(f"Got {len(self.tokens):,} tokens")

        # Overlapping windows: each example starts half a window after the last
        self.examples = []
        stride = max_length // 2
        for i in range(0, len(self.tokens) - max_length, stride):
            self.examples.append(i)

        logger.info(f"Created {len(self.examples):,} training examples")

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        start = self.examples[idx]
        tokens = self.tokens[start:start + self.max_length + 1]

        input_ids = torch.tensor(tokens[:-1], dtype=torch.long)
        labels = torch.tensor(tokens[1:], dtype=torch.long)

        return input_ids, labels


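The chunking above slides a window of `max_length` tokens with stride `max_length // 2`, so consecutive examples overlap by half a window. A pure-Python sketch of the start-offset computation (illustrative name):

```python
def window_starts(num_tokens, max_length):
    """Start offsets for overlapping windows with stride = max_length // 2."""
    stride = max_length // 2
    return list(range(0, num_tokens - max_length, stride))
```

For a 1000-token corpus and 256-token windows, this yields starts 0, 128, 256, ..., each full window fitting inside the corpus.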
# =============================================================================
# Haptic Evaluation
# =============================================================================

class HapticEvaluator:
    """
    Evaluate model outputs for haptic/tactile quality.

    Metrics:
    1. Lattice coherence: how well outputs respect lattice structure
    2. Morphological stability: consistency under erosion/dilation
    3. Modality preservation: cross-modal semantic consistency
    """

    def __init__(self, config: LatticeConfig):
        self.config = config
        self.se = StructuringElement.six_dot()

    def lattice_coherence(self, token_ids: torch.Tensor) -> float:
        """
        Measure how well token sequences respect lattice structure.

        High coherence = adjacent tokens have low Hamming distance.
        """
        if token_ids.shape[-1] < 2:
            return 1.0

        total_dist = 0
        count = 0

        for i in range(token_ids.shape[-1] - 1):
            # Use the first batch item when the input is batched
            t1 = token_ids[..., i].item() if token_ids[..., i].numel() == 1 else token_ids[0, i].item()
            t2 = token_ids[..., i + 1].item() if token_ids[..., i + 1].numel() == 1 else token_ids[0, i + 1].item()

            if t1 < 256 and t2 < 256:
                # Hamming distance between adjacent cells
                total_dist += bin(t1 ^ t2).count('1')
                count += 1

        if count == 0:
            return 1.0

        # Normalize: average distance 0 = max coherence, 8 = min coherence
        avg_dist = total_dist / count
        return 1.0 - (avg_dist / 8.0)

    def morphological_stability(self, token_ids: torch.Tensor) -> float:
        """
        Measure stability under morphological operations.

        High stability = erosion and dilation don't change the sequence drastically.
        """
        if token_ids.numel() == 0:
            return 1.0

        original = token_ids.clone()
        mask = original < 256

        # Apply erosion
        eroded = original.clone()
        eroded[mask] = original[mask] & self.se.cell.value

        # Apply dilation
        dilated = original.clone()
        dilated[mask] = (original[mask] | self.se.cell.value) & 0xFF

        # Measure how much changed
        erode_change = (original[mask] != eroded[mask]).float().mean().item() if mask.any() else 0.0
        dilate_change = (original[mask] != dilated[mask]).float().mean().item() if mask.any() else 0.0

        # Stability = 1 - average change
        return 1.0 - (erode_change + dilate_change) / 2

    def evaluate(self, model: nn.Module, dataloader: DataLoader,
                 device: torch.device, num_samples: int = 100) -> Dict[str, float]:
        """Run the full haptic evaluation over up to num_samples batches."""
        model.eval()

        coherence_scores = []
        stability_scores = []

        with torch.no_grad():
            for i, (input_ids, _) in enumerate(dataloader):
                if i >= num_samples:
                    break

                input_ids = input_ids.to(device)

                # Generate some tokens from a short prompt
                generated = model.generate(input_ids[:, :10], max_length=50)

                coherence_scores.append(self.lattice_coherence(generated))
                stability_scores.append(self.morphological_stability(generated))

        return {
            'lattice_coherence': float(np.mean(coherence_scores)),
            'morphological_stability': float(np.mean(stability_scores)),
            'haptic_score': float(0.5 * np.mean(coherence_scores) + 0.5 * np.mean(stability_scores)),
        }


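The coherence metric is simple enough to restate without tensors: average the Hamming distance over adjacent Braille-cell pairs and map it to [0, 1], where 1 means identical neighbors and 0 means every neighbor is the bitwise complement. A pure-Python sketch (illustrative helper name):

```python
def lattice_coherence(cells):
    """1 - (mean adjacent Hamming distance) / 8, over Braille-cell pairs only."""
    pairs = [(a, b) for a, b in zip(cells, cells[1:]) if a < 256 and b < 256]
    if not pairs:
        return 1.0
    avg = sum(bin(a ^ b).count('1') for a, b in pairs) / len(pairs)
    return 1.0 - avg / 8.0
```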
# =============================================================================
# Training
# =============================================================================

def train(
    corpus_path: str,
    tokenizer_path: str,
    output_dir: str,
    max_steps: int = 10000,
    batch_size: int = 16,
    learning_rate: float = 3e-4,
    gradient_accumulation: int = 2,
    save_steps: int = 1000,
    eval_steps: int = 500,
    use_lattice_attention: bool = True,
    use_lattice_embeddings: bool = True,
    use_morphological_regularization: bool = True,
):
    """Train the lattice-aware model."""

    os.makedirs(output_dir, exist_ok=True)

    # Device
    if torch.backends.mps.is_available():
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    logger.info(f"Using device: {device}")

    # Load tokenizer
    sp = spm.SentencePieceProcessor()
    sp.load(tokenizer_path)
    vocab_size = sp.get_piece_size()

    # Config
    config = LatticeConfig(
        vocab_size=vocab_size,
        use_lattice_attention=use_lattice_attention,
        use_lattice_embeddings=use_lattice_embeddings,
        use_morphological_regularization=use_morphological_regularization,
    )

    # Save config
    with open(os.path.join(output_dir, "config.json"), 'w') as f:
        json.dump(config.to_dict(), f, indent=2)

    # Model
    model = Braille256LatticeModel(config)
    model.to(device)

    # Dataset
    dataset = MultimodalBrailleDataset(
        corpus_path, tokenizer_path,
        max_length=256, max_tokens=2_000_000
    )
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    # Evaluator
    evaluator = HapticEvaluator(config)

    # Optimizer with increased weight decay to preserve lattice structure
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.05)

    # LR scheduler: linear warmup, then cosine decay
    def lr_lambda(step):
        warmup_steps = 500
        if step < warmup_steps:
            return step / warmup_steps
        decay_steps = max_steps - warmup_steps
        # Clamp progress so the multiplier never goes negative past max_steps
        progress = min(1.0, (step - warmup_steps) / decay_steps)
        return 0.5 * (1 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # Mixed precision: CUDA only; MPS AMP has issues with custom ops
    use_amp = device.type == 'cuda'
    scaler = torch.amp.GradScaler('cuda') if use_amp else None

    # torch.compile disabled for MPS (slow compilation overhead); enable only for CUDA
    compiled = False
    if device.type == 'cuda':
        try:
            model = torch.compile(model, mode="reduce-overhead")
            compiled = True
        except Exception as e:
            logger.warning(f"torch.compile not available: {e}")

    # Training loop
    print("\n" + "=" * 70)
    print("⣿ braille256-v6: Lattice-Aware Training ⣿")
    print("=" * 70)
    print(f"  Max steps: {max_steps}")
    print(f"  Batch size: {batch_size} x {gradient_accumulation} = {batch_size * gradient_accumulation}")
    print(f"  Learning rate: {learning_rate}")
    print(f"  Lattice attention: {use_lattice_attention}")
    print(f"  Lattice embeddings: {use_lattice_embeddings}")
    print(f"  Morphological regularization: {use_morphological_regularization}")
    print(f"  Mixed precision (AMP): {use_amp}")
    print(f"  torch.compile: {compiled}")
    print(f"  Output: {output_dir}")
    print("=" * 70 + "\n")

    model.train()
    step = 0
+
data_iter = iter(dataloader)
|
| 823 |
+
best_haptic_score = 0
|
| 824 |
+
|
| 825 |
+
pbar = tqdm(total=max_steps, desc="Training")
|
| 826 |
+
|
| 827 |
+
training_log = []
|
| 828 |
+
|
| 829 |
+
while step < max_steps:
|
| 830 |
+
optimizer.zero_grad()
|
| 831 |
+
total_lm_loss = 0
|
| 832 |
+
total_morph_loss = 0
|
| 833 |
+
|
| 834 |
+
# Staged morphological regularization: high early, decay later
|
| 835 |
+
# This locks in geometry early while allowing expressivity later
|
| 836 |
+
if step < 1500:
|
| 837 |
+
morph_weight_scale = 1.0 # Full strength: 0.05
|
| 838 |
+
elif step < 4000:
|
| 839 |
+
morph_weight_scale = 0.4 # Medium: 0.02
|
| 840 |
+
else:
|
| 841 |
+
morph_weight_scale = 0.1 # Low: 0.005
|
| 842 |
+
|
| 843 |
+
# Update the model's morph weight dynamically
|
| 844 |
+
model.morph_regularizer.config.morphological_weight = 0.05 * morph_weight_scale
|
| 845 |
+
|
| 846 |
+
for _ in range(gradient_accumulation):
|
| 847 |
+
try:
|
| 848 |
+
input_ids, labels = next(data_iter)
|
| 849 |
+
except StopIteration:
|
| 850 |
+
data_iter = iter(dataloader)
|
| 851 |
+
input_ids, labels = next(data_iter)
|
| 852 |
+
|
| 853 |
+
input_ids = input_ids.to(device)
|
| 854 |
+
labels = labels.to(device)
|
| 855 |
+
|
| 856 |
+
# Mixed precision forward pass
|
| 857 |
+
if use_amp:
|
| 858 |
+
with torch.amp.autocast(device.type):
|
| 859 |
+
_, lm_loss, morph_loss = model(input_ids, labels)
|
| 860 |
+
loss = lm_loss + morph_loss
|
| 861 |
+
loss = loss / gradient_accumulation
|
| 862 |
+
scaler.scale(loss).backward()
|
| 863 |
+
else:
|
| 864 |
+
_, lm_loss, morph_loss = model(input_ids, labels)
|
| 865 |
+
loss = lm_loss + morph_loss
|
| 866 |
+
loss = loss / gradient_accumulation
|
| 867 |
+
loss.backward()
|
| 868 |
+
|
| 869 |
+
total_lm_loss += lm_loss.item() / gradient_accumulation
|
| 870 |
+
total_morph_loss += morph_loss.item() / gradient_accumulation if morph_loss else 0
|
| 871 |
+
|
| 872 |
+
if use_amp:
|
| 873 |
+
scaler.unscale_(optimizer)
|
| 874 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
| 875 |
+
scaler.step(optimizer)
|
| 876 |
+
scaler.update()
|
| 877 |
+
else:
|
| 878 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
| 879 |
+
optimizer.step()
|
| 880 |
+
scheduler.step()
|
| 881 |
+
|
| 882 |
+
step += 1
|
| 883 |
+
|
| 884 |
+
pbar.set_postfix(
|
| 885 |
+
lm_loss=f"{total_lm_loss:.4f}",
|
| 886 |
+
morph=f"{total_morph_loss:.4f}",
|
| 887 |
+
lr=f"{scheduler.get_last_lr()[0]:.2e}"
|
| 888 |
+
)
|
| 889 |
+
pbar.update(1)
|
| 890 |
+
|
| 891 |
+
# Log
|
| 892 |
+
if step % 100 == 0:
|
| 893 |
+
training_log.append({
|
| 894 |
+
'step': step,
|
| 895 |
+
'lm_loss': total_lm_loss,
|
| 896 |
+
'morph_loss': total_morph_loss,
|
| 897 |
+
'lr': scheduler.get_last_lr()[0]
|
| 898 |
+
})
|
| 899 |
+
|
| 900 |
+
# Evaluate
|
| 901 |
+
if step % eval_steps == 0:
|
| 902 |
+
eval_results = evaluator.evaluate(model, dataloader, device, num_samples=20)
|
| 903 |
+
logger.info(f"\nStep {step} Haptic Eval: {eval_results}")
|
| 904 |
+
|
| 905 |
+
if eval_results['haptic_score'] > best_haptic_score:
|
| 906 |
+
best_haptic_score = eval_results['haptic_score']
|
| 907 |
+
# Save best model
|
| 908 |
+
best_dir = os.path.join(output_dir, "best")
|
| 909 |
+
os.makedirs(best_dir, exist_ok=True)
|
| 910 |
+
torch.save(model.state_dict(), os.path.join(best_dir, "pytorch_model.bin"))
|
| 911 |
+
logger.info(f"New best haptic score: {best_haptic_score:.4f}")
|
| 912 |
+
|
| 913 |
+
model.train()
|
| 914 |
+
|
| 915 |
+
# Save checkpoint
|
| 916 |
+
if step % save_steps == 0:
|
| 917 |
+
checkpoint_dir = os.path.join(output_dir, f"checkpoint-{step}")
|
| 918 |
+
os.makedirs(checkpoint_dir, exist_ok=True)
|
| 919 |
+
torch.save(model.state_dict(), os.path.join(checkpoint_dir, "pytorch_model.bin"))
|
| 920 |
+
with open(os.path.join(checkpoint_dir, "config.json"), 'w') as f:
|
| 921 |
+
json.dump(config.to_dict(), f, indent=2)
|
| 922 |
+
logger.info(f"Saved checkpoint at step {step}")
|
| 923 |
+
|
| 924 |
+
pbar.close()
|
| 925 |
+
|
| 926 |
+
# Save final model
|
| 927 |
+
print("\n" + "=" * 70)
|
| 928 |
+
print("Saving Final Model")
|
| 929 |
+
print("=" * 70)
|
| 930 |
+
|
| 931 |
+
final_dir = os.path.join(output_dir, "final")
|
| 932 |
+
os.makedirs(final_dir, exist_ok=True)
|
| 933 |
+
|
| 934 |
+
torch.save(model.state_dict(), os.path.join(final_dir, "pytorch_model.bin"))
|
| 935 |
+
with open(os.path.join(final_dir, "config.json"), 'w') as f:
|
| 936 |
+
json.dump(config.to_dict(), f, indent=2)
|
| 937 |
+
|
| 938 |
+
# Save training log
|
| 939 |
+
with open(os.path.join(output_dir, "training_log.json"), 'w') as f:
|
| 940 |
+
json.dump(training_log, f, indent=2)
|
| 941 |
+
|
| 942 |
+
# Copy tokenizer
|
| 943 |
+
import shutil
|
| 944 |
+
shutil.copy(tokenizer_path, os.path.join(final_dir, "tokenizer.model"))
|
| 945 |
+
|
| 946 |
+
# Final evaluation
|
| 947 |
+
final_eval = evaluator.evaluate(model, dataloader, device, num_samples=50)
|
| 948 |
+
print(f"\nFinal Haptic Evaluation:")
|
| 949 |
+
print(f" Lattice Coherence: {final_eval['lattice_coherence']:.4f}")
|
| 950 |
+
print(f" Morphological Stability: {final_eval['morphological_stability']:.4f}")
|
| 951 |
+
print(f" Haptic Score: {final_eval['haptic_score']:.4f}")
|
| 952 |
+
|
| 953 |
+
with open(os.path.join(output_dir, "final_eval.json"), 'w') as f:
|
| 954 |
+
json.dump(final_eval, f, indent=2)
|
| 955 |
+
|
| 956 |
+
print(f"\nModel saved to: {final_dir}")
|
| 957 |
+
print("\n" + "=" * 70)
|
| 958 |
+
print("⣿ Training Complete! ⣿")
|
| 959 |
+
print("=" * 70)
|
| 960 |
+
|
| 961 |
+
|
| 962 |
+
def main():
|
| 963 |
+
parser = argparse.ArgumentParser(description="Train braille256-v6 lattice-aware model")
|
| 964 |
+
parser.add_argument("--corpus", default="corpus/braille_multimodal_corpus.txt")
|
| 965 |
+
parser.add_argument("--tokenizer", default="tokenizers/braille_8dot_32k/braille_8dot_32k.model")
|
| 966 |
+
parser.add_argument("--output", default="models/braille256_v6_lattice")
|
| 967 |
+
parser.add_argument("--steps", type=int, default=10000)
|
| 968 |
+
parser.add_argument("--batch-size", type=int, default=16)
|
| 969 |
+
parser.add_argument("--lr", type=float, default=3e-4)
|
| 970 |
+
parser.add_argument("--no-lattice-attention", action="store_true")
|
| 971 |
+
parser.add_argument("--no-lattice-embeddings", action="store_true")
|
| 972 |
+
parser.add_argument("--no-morph-regularization", action="store_true")
|
| 973 |
+
|
| 974 |
+
args = parser.parse_args()
|
| 975 |
+
|
| 976 |
+
train(
|
| 977 |
+
corpus_path=args.corpus,
|
| 978 |
+
tokenizer_path=args.tokenizer,
|
| 979 |
+
output_dir=args.output,
|
| 980 |
+
max_steps=args.steps,
|
| 981 |
+
batch_size=args.batch_size,
|
| 982 |
+
learning_rate=args.lr,
|
| 983 |
+
use_lattice_attention=not args.no_lattice_attention,
|
| 984 |
+
use_lattice_embeddings=not args.no_lattice_embeddings,
|
| 985 |
+
use_morphological_regularization=not args.no_morph_regularization,
|
| 986 |
+
)
|
| 987 |
+
|
| 988 |
+
|
| 989 |
+
if __name__ == "__main__":
|
| 990 |
+
main()
|
training_log.json
ADDED
@@ -0,0 +1,602 @@
[
  {"step": 100, "lm_loss": 6.993148565292358, "morph_loss": 0.0004997336654923856, "lr": 5.9999999999999995e-05},
  {"step": 200, "lm_loss": 3.8235212564468384, "morph_loss": 0.0004995590425096452, "lr": 0.00011999999999999999},
  {"step": 300, "lm_loss": 2.6851329803466797, "morph_loss": 0.0004995536583010107, "lr": 0.00017999999999999998},
  {"step": 400, "lm_loss": 2.7323516607284546, "morph_loss": 0.0004988558066543192, "lr": 0.00023999999999999998},
  {"step": 500, "lm_loss": 2.6041159629821777, "morph_loss": 0.000500149471918121, "lr": 0.0003},
  {"step": 600, "lm_loss": 2.7856509685516357, "morph_loss": 0.0004991319729015231, "lr": 0.0002999179886011389},
  {"step": 700, "lm_loss": 1.9858170747756958, "morph_loss": 0.0004987868887837976, "lr": 0.00029967204408281613},
  {"step": 800, "lm_loss": 2.2885484099388123, "morph_loss": 0.0004994409391656518, "lr": 0.0002992624353817517},
  {"step": 900, "lm_loss": 1.7579542398452759, "morph_loss": 0.000499250425491482, "lr": 0.00029868961039904624},
  {"step": 1000, "lm_loss": 2.356551766395569, "morph_loss": 0.000498554261866957, "lr": 0.00029795419551040833},
  {"step": 1100, "lm_loss": 2.2041295170783997, "morph_loss": 0.0004986616258975118, "lr": 0.0002970569948812214},
  {"step": 1200, "lm_loss": 1.987478256225586, "morph_loss": 0.0004990812740288675, "lr": 0.0002959989895872009},
  {"step": 1300, "lm_loss": 1.7949504256248474, "morph_loss": 0.0004993200418539345, "lr": 0.0002947813365416023},
  {"step": 1400, "lm_loss": 2.577250361442566, "morph_loss": 0.0004996150964871049, "lr": 0.0002934053672301536},
  {"step": 1500, "lm_loss": 2.1883418560028076, "morph_loss": 0.000496969121741131, "lr": 0.00029187258625509513},
  {"step": 1600, "lm_loss": 1.70472252368927, "morph_loss": 0.0001988508302019909, "lr": 0.0002901846696899191},
  {"step": 1700, "lm_loss": 2.3121373057365417, "morph_loss": 0.0001993859259528108, "lr": 0.0002883434632466077},
  {"step": 1800, "lm_loss": 2.0045361518859863, "morph_loss": 0.00019962265650974587, "lr": 0.00028635098025737434},
  {"step": 1900, "lm_loss": 1.9210466742515564, "morph_loss": 0.0001999259038711898, "lr": 0.0002842093994731145},
  {"step": 2000, "lm_loss": 2.045823335647583, "morph_loss": 0.00019963263912359253, "lr": 0.00028192106268097334},
  {"step": 2100, "lm_loss": 2.363018274307251, "morph_loss": 0.00020041707466589287, "lr": 0.0002794884721436361},
  {"step": 2200, "lm_loss": 2.146875023841858, "morph_loss": 0.00019886076188413426, "lr": 0.0002769142878631403},
  {"step": 2300, "lm_loss": 1.7959808111190796, "morph_loss": 0.0001991742683458142, "lr": 0.000274201324672203},
  {"step": 2400, "lm_loss": 2.1792458295822144, "morph_loss": 0.00020018102077301592, "lr": 0.0002713525491562421},
  {"step": 2500, "lm_loss": 2.177172005176544, "morph_loss": 0.0001997477374970913, "lr": 0.00026837107640945905},
  {"step": 2600, "lm_loss": 1.8140788674354553, "morph_loss": 0.00019913741562049836, "lr": 0.00026526016662852886},
  {"step": 2700, "lm_loss": 2.1109378337860107, "morph_loss": 0.00019985359540442005, "lr": 0.0002620232215476231},
  {"step": 2800, "lm_loss": 1.8451208472251892, "morph_loss": 0.0002002850960707292, "lr": 0.00025866378071866334},
  {"step": 2900, "lm_loss": 1.4895858764648438, "morph_loss": 0.00019962178339483216, "lr": 0.00025518551764087326},
  {"step": 3000, "lm_loss": 1.5707910060882568, "morph_loss": 0.00019960849022027105, "lr": 0.00025159223574386114},
  {"step": 3100, "lm_loss": 1.5902798175811768, "morph_loss": 0.0002000557360588573, "lr": 0.00024788786422862526},
  {"step": 3200, "lm_loss": 1.8854023218154907, "morph_loss": 0.00019861374312313274, "lr": 0.00024407645377103054},
  {"step": 3300, "lm_loss": 1.6806678175926208, "morph_loss": 0.00020012820459669456, "lr": 0.00024016217209245374},
  {"step": 3400, "lm_loss": 1.827455759048462, "morph_loss": 0.00020046472491230816, "lr": 0.0002361492994024415},
  {"step": 3500, "lm_loss": 1.9594002962112427, "morph_loss": 0.00019966769468737766, "lr": 0.00023204222371836402},
  {"step": 3600, "lm_loss": 1.359905481338501, "morph_loss": 0.00019929300469812006, "lr": 0.00022784543606718227},
  {"step": 3700, "lm_loss": 1.9533899426460266, "morph_loss": 0.00020049385057063773, "lr": 0.0002235635255745762},
  {"step": 3800, "lm_loss": 1.4309832453727722, "morph_loss": 0.00019867864466505125, "lr": 0.00021920117444680317},
  {"step": 3900, "lm_loss": 1.2603670358657837, "morph_loss": 0.00020007100101793185, "lr": 0.0002147631528507739},
  {"step": 4000, "lm_loss": 1.359400749206543, "morph_loss": 0.00019884618814103305, "lr": 0.0002102543136979454},
  {"step": 4100, "lm_loss": 1.4845112562179565, "morph_loss": 4.9905273044714704e-05, "lr": 0.0002056795873377331},
  {"step": 4200, "lm_loss": 1.3849643468856812, "morph_loss": 4.9970141844823956e-05, "lr": 0.00020104397616624645},
  {"step": 4300, "lm_loss": 1.293235957622528, "morph_loss": 4.9802090870798565e-05, "lr": 0.0001963525491562421},
  {"step": 4400, "lm_loss": 1.8045696020126343, "morph_loss": 5.000879900762811e-05, "lr": 0.00019161043631427666},
  {"step": 4500, "lm_loss": 1.472623884677887, "morph_loss": 4.9775953812059015e-05, "lr": 0.00018682282307111987},
  {"step": 4600, "lm_loss": 1.4265462756156921, "morph_loss": 4.9740716349333525e-05, "lr": 0.00018199494461156203},
  {"step": 4700, "lm_loss": 1.3741839528083801, "morph_loss": 5.005610182706732e-05, "lr": 0.00017713208014981648},
  {"step": 4800, "lm_loss": 1.5047972798347473, "morph_loss": 4.9890166337718256e-05, "lr": 0.00017223954715677627},
  {"step": 4900, "lm_loss": 1.9711642265319824, "morph_loss": 5.0176731747342274e-05, "lr": 0.00016732269554543794},
  {"step": 5000, "lm_loss": 1.3522289395332336, "morph_loss": 5.020072603656445e-05, "lr": 0.00016238690182084986},
  {"step": 5100, "lm_loss": 1.7605299353599548, "morph_loss": 4.991166679246817e-05, "lr": 0.00015743756320098332},
  {"step": 5200, "lm_loss": 1.4882904291152954, "morph_loss": 4.9983715143753216e-05, "lr": 0.00015248009171495378},
  {"step": 5300, "lm_loss": 1.6343830823898315, "morph_loss": 5.002636680728756e-05, "lr": 0.00014751990828504622},
  {"step": 5400, "lm_loss": 1.1790239810943604, "morph_loss": 4.997515679860953e-05, "lr": 0.00014256243679901663},
  {"step": 5500, "lm_loss": 1.2620559334754944, "morph_loss": 4.979184268449899e-05, "lr": 0.00013761309817915014},
  {"step": 5600, "lm_loss": 1.2926940321922302, "morph_loss": 4.9720953029464e-05, "lr": 0.00013267730445456208},
  {"step": 5700, "lm_loss": 1.6810715198516846, "morph_loss": 5.0059263230650686e-05, "lr": 0.00012776045284322368},
  {"step": 5800, "lm_loss": 1.3960903882980347, "morph_loss": 5.0068707423633896e-05, "lr": 0.00012286791985018355},
  {"step": 5900, "lm_loss": 1.5417255759239197, "morph_loss": 5.0141115934820846e-05, "lr": 0.00011800505538843798},
  {"step": 6000, "lm_loss": 1.3895499110221863, "morph_loss": 4.9999320253846236e-05, "lr": 0.00011317717692888012},
  {"step": 6100, "lm_loss": 1.578350841999054, "morph_loss": 4.975642514182255e-05, "lr": 0.00010838956368572334},
  {"step": 6200, "lm_loss": 0.8763712048530579, "morph_loss": 4.9885256885318086e-05, "lr": 0.0001036474508437579},
  {"step": 6300, "lm_loss": 1.3805139064788818, "morph_loss": 4.985421219316777e-05, "lr": 9.895602383375353e-05},
  {"step": 6400, "lm_loss": 1.642943263053894, "morph_loss": 4.988547880202532e-05, "lr": 9.432041266226686e-05},
  {"step": 6500, "lm_loss": 1.2295689284801483, "morph_loss": 4.976921445631888e-05, "lr": 8.97456863020546e-05},
  {"step": 6600, "lm_loss": 0.9539550840854645, "morph_loss": 4.960754813509993e-05, "lr": 8.523684714922608e-05},
  {"step": 6700, "lm_loss": 1.4480910301208496, "morph_loss": 4.9842370572150685e-05, "lr": 8.079882555319683e-05},
  {"step": 6800, "lm_loss": 1.1316336393356323, "morph_loss": 4.944742067891639e-05, "lr": 7.643647442542382e-05},
  {"step": 6900, "lm_loss": 1.2974263429641724, "morph_loss": 4.9362375648343004e-05, "lr": 7.215456393281776e-05},
  {"step": 7000, "lm_loss": 1.8624339699745178, "morph_loss": 4.9819502237369306e-05, "lr": 6.795777628163599e-05},
  {"step": 7100, "lm_loss": 1.2204494774341583, "morph_loss": 5.013305417378433e-05, "lr": 6.385070059755846e-05},
  {"step": 7200, "lm_loss": 1.5136797428131104, "morph_loss": 4.9816295359050855e-05, "lr": 5.983782790754623e-05},
  {"step": 7300, "lm_loss": 1.3666653037071228, "morph_loss": 4.99475918331882e-05, "lr": 5.592354622896944e-05},
  {"step": 7400, "lm_loss": 0.8389511108398438, "morph_loss": 4.9609083362156525e-05, "lr": 5.211213577137469e-05},
  {"step": 7500, "lm_loss": 1.1075031757354736, "morph_loss": 4.941183760820422e-05, "lr": 4.840776425613885e-05},
  {"step": 7600, "lm_loss": 1.2579197883605957, "morph_loss": 4.980366429663263e-05, "lr": 4.481448235912671e-05},
  {"step": 7700, "lm_loss": 1.0307890474796295, "morph_loss": 4.9867769121192396e-05, "lr": 4.133621928133665e-05},
  {"step": 7800, "lm_loss": 1.0696255564689636, "morph_loss": 5.0295926484977826e-05, "lr": 3.797677845237696e-05},
  {"step": 7900, "lm_loss": 1.3926631212234497, "morph_loss": 5.014805537939537e-05, "lr": 3.473983337147118e-05},
  {"step": 8000, "lm_loss": 1.5005779266357422, "morph_loss": 5.018570300308056e-05, "lr": 3.162892359054098e-05},
  {"step": 8100, "lm_loss": 1.7105327248573303, "morph_loss": 5.010495260648895e-05, "lr": 2.8647450843757897e-05},
  {"step": 8200, "lm_loss": 1.229815423488617, "morph_loss": 4.9758424211177044e-05, "lr": 2.5798675327796993e-05},
  {"step": 8300, "lm_loss": 1.1335912346839905, "morph_loss": 4.941884253639728e-05, "lr": 2.3085712136859668e-05},
  {"step": 8400, "lm_loss": 1.0056449174880981, "morph_loss": 4.942774467053823e-05, "lr": 2.0511527856363895e-05},
  {"step": 8500, "lm_loss": 1.6661915183067322, "morph_loss": 4.997232463210821e-05, "lr": 1.8078937319026654e-05},
  {"step": 8600, "lm_loss": 1.1359021067619324, "morph_loss": 4.9868809583131224e-05, "lr": 1.579060052688548e-05},
  {"step": 8700, "lm_loss": 1.434115707874298, "morph_loss": 4.993028596800286e-05, "lr": 1.3649019742625623e-05},
  {"step": 8800, "lm_loss": 1.2552986145019531, "morph_loss": 4.9796975872595794e-05, "lr": 1.1656536753392287e-05},
  {"step": 8900, "lm_loss": 1.217362403869629, "morph_loss": 4.991148489352781e-05, "lr": 9.815330310080887e-06},
  {"step": 9000, "lm_loss": 1.7755168080329895, "morph_loss": 5.013133522879798e-05, "lr": 8.127413744904804e-06},
  {"step": 9100, "lm_loss": 1.3130499720573425, "morph_loss": 5.007252184441313e-05, "lr": 6.594632769846353e-06},
  {"step": 9200, "lm_loss": 1.1731150150299072, "morph_loss": 5.0212831411045045e-05, "lr": 5.218663458397715e-06},
  {"step": 9300, "lm_loss": 0.9502497613430023, "morph_loss": 4.9593836592976004e-05, "lr": 4.001010412799138e-06},
  {"step": 9400, "lm_loss": 1.3233891725540161, "morph_loss": 4.999847624276299e-05, "lr": 2.9430051187785962e-06},
  {"step": 9500, "lm_loss": 1.3283841013908386, "morph_loss": 5.0088236093870364e-05, "lr": 2.0458044895916513e-06},
  {"step": 9600, "lm_loss": 1.2733866572380066, "morph_loss": 4.998848271497991e-05, "lr": 1.3103896009537207e-06},
  {"step": 9700, "lm_loss": 1.1967694163322449, "morph_loss": 5.023612902732566e-05, "lr": 7.375646182482875e-07},
  {"step": 9800, "lm_loss": 1.0563868880271912, "morph_loss": 4.972490751242731e-05, "lr": 3.2795591718381975e-07},
  {"step": 9900, "lm_loss": 1.7780798077583313, "morph_loss": 4.9859330829349346e-05, "lr": 8.201139886109264e-08},
  {"step": 10000, "lm_loss": 1.231259286403656, "morph_loss": 4.956562952429522e-05, "lr": 0.0}
]
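A log of this shape (a JSON array of `{step, lm_loss, morph_loss, lr}` records, as written by `train()` above) is easy to summarize after a run. The helper below is a hypothetical sketch, not part of the repo; it inlines a two-entry excerpt of the log for illustration, but would normally be fed `json.load(open("training_log.json"))`.

```python
import json

def summarize(log):
    """Return (first, best, last) lm_loss values from a training log."""
    losses = [entry["lm_loss"] for entry in log]
    return losses[0], min(losses), losses[-1]

# Two-entry excerpt of the log above, inlined for illustration.
sample = json.loads("""
[
  {"step": 100, "lm_loss": 6.993148565292358, "morph_loss": 0.0004997336654923856, "lr": 5.9999999999999995e-05},
  {"step": 10000, "lm_loss": 1.231259286403656, "morph_loss": 4.956562952429522e-05, "lr": 0.0}
]
""")

first, best, last = summarize(sample)
print(f"lm_loss: {first:.3f} -> {last:.3f} (best {best:.3f})")
```

Over the full log, the same summary shows lm_loss falling from ~6.99 at step 100 to ~1.23 at step 10000, with the best logged value (~0.84) around step 7400.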