Zandy-Wandy committed on
Commit bf64b03 · verified · 1 Parent(s): 83c011e

Upload Vortex model

Files changed (50)
  1. README.md +289 -0
  2. configs/__pycache__/vortex_7b_config.cpython-313.pyc +0 -0
  3. configs/training_config.py +97 -0
  4. configs/vortex_13b_config.py +68 -0
  5. configs/vortex_7b_config.py +71 -0
  6. configuration_vortex.py +110 -0
  7. cuda_optimize.py +287 -0
  8. data/__pycache__/deduplication.cpython-313.pyc +0 -0
  9. data/__pycache__/domain_classifier.cpython-313.pyc +0 -0
  10. data/__pycache__/quality_filter.cpython-313.pyc +0 -0
  11. data/dataset_loader.py +263 -0
  12. data/deduplication.py +260 -0
  13. data/domain_classifier.py +163 -0
  14. data/quality_filter.py +279 -0
  15. data/scraper.py +405 -0
  16. inference.py +213 -0
  17. modeling_vortex.py +222 -0
  18. models/__pycache__/attention_layer.cpython-313.pyc +0 -0
  19. models/__pycache__/scigate_ffn.cpython-313.pyc +0 -0
  20. models/__pycache__/ssm_layer.cpython-313.pyc +0 -0
  21. models/__pycache__/vortex_model.cpython-313.pyc +0 -0
  22. models/attention_layer.py +370 -0
  23. models/science_modules/__init__.py +15 -0
  24. models/science_modules/__pycache__/__init__.cpython-313.pyc +0 -0
  25. models/science_modules/__pycache__/citation_module.cpython-313.pyc +0 -0
  26. models/science_modules/__pycache__/equation_module.cpython-313.pyc +0 -0
  27. models/science_modules/__pycache__/molecular_module.cpython-313.pyc +0 -0
  28. models/science_modules/__pycache__/numerical_module.cpython-313.pyc +0 -0
  29. models/science_modules/citation_module.py +230 -0
  30. models/science_modules/equation_module.py +266 -0
  31. models/science_modules/molecular_module.py +333 -0
  32. models/science_modules/numerical_module.py +251 -0
  33. models/scigate_ffn.py +203 -0
  34. models/ssm_layer.py +252 -0
  35. models/vortex_model.py +377 -0
  36. mps_optimize.py +172 -0
  37. push_to_hf.py +39 -0
  38. requirements.txt +50 -0
  39. science_bench.py +360 -0
  40. test_model.py +449 -0
  41. tokenization_vortex.py +174 -0
  42. tokenizer/__pycache__/vortex_tokenizer.cpython-313.pyc +0 -0
  43. tokenizer/vortex_tokenizer.py +442 -0
  44. train.py +146 -0
  45. training/__pycache__/curriculum.cpython-313.pyc +0 -0
  46. training/__pycache__/losses.cpython-313.pyc +0 -0
  47. training/curriculum.py +175 -0
  48. training/losses.py +162 -0
  49. training/trainer.py +442 -0
  50. vortex_config.py +71 -0
README.md ADDED
@@ -0,0 +1,289 @@
# Vortex Scientific

**Vortex Scientific** is a from-scratch AI model family designed for deep scientific reasoning. It is built from the ground up with a novel hybrid state-space + attention architecture and optimized for consumer laptop hardware (Apple Silicon MacBooks and Nvidia 4060 laptop GPUs).

## 🌟 Features

- **Novel Architecture**: Hybrid State-Space Model (SSM) + local attention blocks
- **Science-Specialized**: Custom tokenizer, domain-aware gating, and specialized modules for equations, numerical reasoning, citations, and molecular structures
- **Hardware-Optimized**: Runs smoothly on 8GB VRAM (4060 laptop) and 16GB unified memory (MacBook Pro M2/M3)
- **Two Model Sizes**:
  - **Vortex-7B**: 7 billion parameters, fits in 8GB VRAM
  - **Vortex-13B**: 13 billion parameters, fits in 16GB VRAM with quantization
- **HuggingFace Compatible**: Full integration with the `transformers` library
- **From Scratch**: No base model; everything, including the tokenizer and weights, is built bottom-up

## 🏗️ Architecture

Vortex uses a two-block hybrid architecture:

1. **SSM-Only Blocks**: State-space layers for efficient long-context processing (O(n) complexity)
2. **Attention+Science Blocks**: Local windowed attention + science modules + SciGate FFN

Layer ratios:

- 7B: 60% SSM, 40% attention (pattern: SSM, SSM, Attn, ...)
- 13B: 50% SSM, 50% attention (pattern: SSM, Attn, SSM, Attn, ...)

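The two layer-ratio patterns above can be generated from `ssm_ratio` with a simple counter; `build_layer_pattern` below is a hypothetical sketch, not the actual logic in `models/vortex_model.py`.

```python
import math

def build_layer_pattern(num_layers: int, ssm_ratio: float) -> list:
    """Interleave SSM and attention blocks so SSM makes up ~ssm_ratio of layers."""
    pattern, ssm_count = [], 0
    for i in range(num_layers):
        # Emit an SSM block whenever the running SSM count falls behind the target ratio
        if ssm_count < math.ceil(ssm_ratio * (i + 1)):
            pattern.append("ssm")
            ssm_count += 1
        else:
            pattern.append("attn")
    return pattern

print(build_layer_pattern(32, 0.6)[:5])  # ['ssm', 'ssm', 'attn', 'ssm', 'attn']
```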
### Science Modules

- **EquationModule**: LaTeX equation detection and structural understanding
- **NumericalReasoningModule**: Digit-level encoding, scientific notation, unit awareness
- **CitationModule**: Citation span detection, provenance tracking, confidence scoring
- **MolecularModule**: Element embeddings, SMILES understanding, amino acid sequences

## 📦 Project Structure

```
Vortex/
├── configs/
│   ├── vortex_7b_config.py   # 7B model configuration
│   ├── vortex_13b_config.py  # 13B model configuration
│   └── training_config.py    # Training hyperparameters
├── models/
│   ├── ssm_layer.py          # State-space layer
│   ├── attention_layer.py    # Local windowed attention
│   ├── scigate_ffn.py        # Science-gated feed-forward
│   ├── vortex_model.py       # Main model class
│   └── science_modules/      # Specialized science modules
├── tokenizer/
│   └── vortex_tokenizer.py   # Custom BPE tokenizer with science vocab
├── data/
│   ├── dataset_loader.py     # Open dataset loading (Pile, S2ORC, etc.)
│   ├── quality_filter.py     # Multi-stage quality filtering
│   ├── domain_classifier.py  # 7-domain classifier
│   ├── deduplication.py      # MinHash LSH deduplication
│   └── scraper.py            # Web scraping (arXiv, PubMed, etc.)
├── training/
│   ├── trainer.py            # Main training loop
│   ├── losses.py             # Science-aware loss functions
│   └── curriculum.py         # Curriculum learning scheduler
├── cuda_optimize.py          # CUDA optimizations (Flash Attention, INT8)
├── mps_optimize.py           # MPS optimizations for Apple Silicon
├── science_bench.py          # Science benchmarks
├── configuration_vortex.py   # HF config class
├── tokenization_vortex.py    # HF tokenizer wrapper
├── modeling_vortex.py        # HF model integration
├── train.py                  # Training entry point
├── inference.py              # Inference entry point
└── requirements.txt
```

## 🚀 Quick Start

### Installation

```bash
# Clone and setup
cd Vortex
pip install -r requirements.txt

# For CUDA optimizations
pip install flash-attn
pip install bitsandbytes
```

### Training

```bash
# Train the 7B model on CUDA
python train.py \
    --model_size 7b \
    --device cuda \
    --data_dir ./data/processed \
    --output_dir ./checkpoints \
    --max_steps 100000

# Train the 13B model with INT8 quantization (for 8GB VRAM)
python train.py \
    --model_size 13b \
    --device cuda \
    --quantization int8 \
    --data_dir ./data/processed \
    --output_dir ./checkpoints_13b
```

### Inference

```bash
# Generate text with the 7B model
python inference.py \
    --model_path ./checkpoints/latest.pt \
    --model_size 7b \
    --device cuda \
    --prompt "The equation E = mc^2 describes" \
    --max_new_tokens 100

# Interactive mode
python inference.py \
    --model_path ./checkpoints/latest.pt \
    --model_size 7b \
    --device cuda \
    --interactive

# On Apple Silicon (MPS)
python inference.py \
    --model_path ./checkpoints/latest.pt \
    --model_size 7b \
    --use_mps \
    --prompt "Explain quantum mechanics"
```

### HuggingFace Integration

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained("./checkpoints")
tokenizer = AutoTokenizer.from_pretrained("./checkpoints")

# Generate
input_text = "The energy of a photon is given by"
inputs = tokenizer(input_text, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0]))
```

## 📊 Data Pipeline

1. **Open Datasets**: Automatically download from HuggingFace (Pile, S2ORC, math datasets, PubMed QA)
2. **Quality Filtering**: Multi-stage checks (length, language, equations, repetition, citations)
3. **Deduplication**: MinHash LSH for near-duplicate detection
4. **Domain Classification**: Classify into 7 science domains
5. **Tokenization**: Custom science-aware BPE tokenizer
6. **Sharding**: Write to Parquet with statistics

```python
from data.dataset_loader import VortexDatasetLoader
from data.quality_filter import ScienceQualityFilter
from data.deduplication import MinHashLSH

# Load and process data
loader = VortexDatasetLoader()
quality_filter = ScienceQualityFilter()
lsh = MinHashLSH()

# Stream datasets, filter, deduplicate, and shard
for sample in loader.load_multiple_datasets(["pile_scientific", "automath"]):
    if quality_filter.filter(sample["text"]):
        lsh.add_document(sample["id"], sample["text"])
        # Tokenize and save
```

## 🎯 Training Strategy

### Curriculum Learning

Training progresses through 4 stages:

1. **Foundation** (0-20%): Basic science text, simple equations, definitions
2. **Domain** (20-50%): Deep domain-specific content per science area
3. **Reasoning** (50-80%): Scientific problem solving, multi-step derivations
4. **Integration** (80-100%): Cross-domain science, full dataset

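A minimal sketch of how a step counter maps to these stages, reusing the `curriculum_stages` fractions from `configs/training_config.py` (the `stage_for_step` helper is illustrative, not the actual `training/curriculum.py` code):

```python
CURRICULUM_STAGES = [
    {"name": "foundation", "start": 0.0, "end": 0.2},
    {"name": "domain", "start": 0.2, "end": 0.5},
    {"name": "reasoning", "start": 0.5, "end": 0.8},
    {"name": "integration", "start": 0.8, "end": 1.0},
]

def stage_for_step(step: int, max_steps: int) -> str:
    """Map a training step to its curriculum stage by progress fraction."""
    progress = step / max_steps
    for stage in CURRICULUM_STAGES:
        if stage["start"] <= progress < stage["end"]:
            return stage["name"]
    # progress == 1.0 falls into the final stage
    return CURRICULUM_STAGES[-1]["name"]

print(stage_for_step(60_000, 100_000))  # reasoning
```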
### Science-Aware Loss

```python
total_loss = (
    lm_loss * 1.0           # Standard next-token prediction
    + equation_loss * 0.3   # Equation reconstruction accuracy
    + domain_loss * 0.1     # Domain classification head
    + citation_loss * 0.1   # Citation detection accuracy
    + numerical_loss * 0.2  # Numerical reasoning accuracy
)
```

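The same total can be computed generically from the `loss_weights` dictionary in `configs/training_config.py`; the loss values passed in below are made-up placeholders:

```python
LOSS_WEIGHTS = {  # mirrors configs/training_config.py
    "lm_loss": 1.0,
    "equation_loss": 0.3,
    "domain_loss": 0.1,
    "citation_loss": 0.1,
    "numerical_loss": 0.2,
}

def total_loss(losses: dict) -> float:
    """Weighted sum over whichever auxiliary losses are present for this batch."""
    return sum(LOSS_WEIGHTS[name] * value for name, value in losses.items())

# Placeholder loss values, for illustration only
print(total_loss({"lm_loss": 2.0, "equation_loss": 1.0}))  # 2.3
```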
## ⚙️ Configuration

### 7B Config (VORTEX_7B_CONFIG)

- `d_model`: 4096
- `num_layers`: 32
- `num_heads`: 32
- `d_state`: 16
- `ssm_ratio`: 0.6
- `vocab_size`: 50000
- `max_seq_len`: 16384

### 13B Config (VORTEX_13B_CONFIG)

- `d_model`: 5120
- `num_layers`: 40
- `num_heads`: 40
- `d_state`: 32
- `ssm_ratio`: 0.5
- `vocab_size`: 50000
- `max_seq_len`: 16384

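As a sanity check, a rough transformer-style estimate (about 12·d_model² parameters per layer for attention plus a 4x FFN, plus tied embeddings) lands near the advertised sizes. This sketch ignores SSM state parameters and the science modules, so it is only an approximation:

```python
def estimate_params(d_model: int, num_layers: int, vocab_size: int,
                    ffn_expansion: int = 4) -> float:
    """Rough parameter count in billions: attention (4*d^2) plus
    FFN (2*expansion*d^2) per layer, plus tied token embeddings."""
    per_layer = 4 * d_model**2 + 2 * ffn_expansion * d_model**2
    return (num_layers * per_layer + vocab_size * d_model) / 1e9

print(round(estimate_params(4096, 32, 50_000), 2))  # ~6.65 (Vortex-7B)
print(round(estimate_params(5120, 40, 50_000), 2))  # ~12.84 (Vortex-13B)
```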
## 🔧 Hardware Targets

### Nvidia 4060 Laptop (8GB VRAM)

- **7B**: BF16, no quantization, Flash Attention 2, `torch.compile`
- **13B**: INT8 quantization, Flash Attention 2, `torch.compile`
- Target throughput: 25-40 tok/s (7B), 15-25 tok/s (13B)

### Apple Silicon (M2/M3)

- **7B on M3**: FP16 (MPS has limited BF16 support), SDPA, no compile
- **13B on M3 Max**: BF16, unified memory, SDPA
- Target throughput: 20-35 tok/s (7B), 12-20 tok/s (13B)

## 🧪 Science Domains

1. **Physics** (`[PHYS]`)
2. **Mathematics** (`[MATH]`)
3. **Chemistry** (`[CHEM]`)
4. **Biology** (`[BIO]`)
5. **Earth Science** (`[EARTH]`)
6. **Space Science** (`[SPACE]`)
7. **Zoology** (`[ZOO]`)

Domain tags can be included in training data to guide the SciGate FFN routing.

## 📝 Tokenizer

Custom BPE tokenizer with:

- 40,000 base BPE tokens trained on a scientific corpus
- 10,000 science-specific tokens:
  - 500 LaTeX math symbols (`\alpha`, `\sum`, `\int`, etc.)
  - 118 chemical element symbols
  - 200 SI and derived units
  - 300 scientific abbreviations (DNA, RNA, ATP, etc.)
  - 500 mathematical operators
  - Amino acid codes
  - Greek alphabet (α, β, γ, etc.)
- Special tokens: `[EQUATION]`, `[CITATION]`, `[MOLECULE]`, `[FIGURE]`, `[TABLE]`, domain tags

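To illustrate how the span tokens are meant to bracket structured content, here is a hypothetical preprocessing helper that wraps inline `$...$` math in `[EQUATION]` tags (not the actual `tokenization_vortex.py` logic):

```python
import re

def wrap_equations(text: str) -> str:
    """Bracket inline $...$ math with the [EQUATION] special tokens."""
    return re.sub(r"\$([^$]+)\$", r"[EQUATION]\1[/EQUATION]", text)

print(wrap_equations("Einstein showed that $E = mc^2$."))
# Einstein showed that [EQUATION]E = mc^2[/EQUATION].
```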
## 🧪 Evaluation

Science benchmarks across all 7 domains will be added. Planned benchmarks:

- **Physics**: Feynman Questions, Physics GRE
- **Math**: MATH dataset, GSM8K
- **Chemistry**: Chemistry problem solving, molecular property prediction
- **Biology**: PubMed QA, bioinformatics tasks
- **Earth Science**: Climate modeling questions
- **Space Science**: Astronomy problem sets
- **Zoology**: Species classification, ecological reasoning

## 📄 License

This is a school science project. Code is provided for educational purposes.

## 🙏 Acknowledgments

- **Mamba** (Gu et al.) for SSM architecture inspiration
- **Flash Attention** (Dao et al.) for efficient attention
- **HuggingFace** for the `transformers` library
- All open scientific data sources: arXiv, PubMed, S2ORC, etc.

## 📧 Contact

For questions or issues, please open an issue on GitHub.

---

**Built with ❤️ for scientific AI research**
configs/__pycache__/vortex_7b_config.cpython-313.pyc ADDED
Binary file (1.78 kB).
 
configs/training_config.py ADDED
@@ -0,0 +1,97 @@
"""
Training configuration for Vortex models.
Covers both 7B and 13B variants with hardware-specific optimizations.
"""

import torch

TRAINING_CONFIG = {
    # Training hyperparameters
    "learning_rate": 3e-4,
    "weight_decay": 0.1,
    "beta1": 0.9,
    "beta2": 0.95,
    "clip_grad_norm": 1.0,

    # Batch sizing
    "global_batch_size": 512,  # sequences per optimizer step
    "micro_batch_size": 8,  # per GPU
    "gradient_accumulation_steps": 4,

    # Training schedule
    "max_steps": 100000,
    "warmup_steps": 2000,
    "save_interval": 5000,
    "eval_interval": 1000,
    "log_interval": 100,

    # Mixed precision
    "use_amp": True,
    "amp_dtype": torch.bfloat16,

    # Optimizer
    "optimizer": "AdamW",
    "use_fused": True,  # fused AdamW if available

    # Curriculum learning stages (as fractions of max_steps)
    "curriculum_stages": [
        {"name": "foundation", "start": 0.0, "end": 0.2},  # 0-20%
        {"name": "domain", "start": 0.2, "end": 0.5},  # 20-50%
        {"name": "reasoning", "start": 0.5, "end": 0.8},  # 50-80%
        {"name": "integration", "start": 0.8, "end": 1.0},  # 80-100%
    ],

    # Loss weights (science-aware loss)
    "loss_weights": {
        "lm_loss": 1.0,
        "equation_loss": 0.3,
        "domain_loss": 0.1,
        "citation_loss": 0.1,
        "numerical_loss": 0.2,
    },

    # Checkpointing
    "checkpoint_dir": "checkpoints",
    "save_optimizer_state": True,
    "save_scheduler_state": True,

    # Logging
    "log_dir": "logs",
    "use_wandb": False,
    "wandb_project": "vortex-scientific",

    # Data loading
    "num_workers": 8,
    "prefetch_factor": 2,
    "pin_memory": True,

    # Device configuration
    "device": "cuda",  # or "mps" for Apple Silicon
    "use_mps": False,

    # Quantization (for 13B on 8GB VRAM)
    "quantization": None,  # None, "int8", "int4"
}

# Hardware-specific overrides
TRAINING_CONFIG_7B_CUDA = TRAINING_CONFIG.copy()
TRAINING_CONFIG_7B_CUDA.update({
    "device": "cuda",
    "quantization": None,
    "micro_batch_size": 8,
})

TRAINING_CONFIG_13B_CUDA = TRAINING_CONFIG.copy()
TRAINING_CONFIG_13B_CUDA.update({
    "device": "cuda",
    "quantization": "int8",  # 13B needs INT8 on 8GB
    "micro_batch_size": 4,
})

TRAINING_CONFIG_MPS = TRAINING_CONFIG.copy()
TRAINING_CONFIG_MPS.update({
    "device": "mps",
    "use_mps": True,
    "use_amp": False,  # MPS doesn't support bfloat16 AMP well
    "micro_batch_size": 4,
})
configs/vortex_13b_config.py ADDED
@@ -0,0 +1,68 @@
"""
Vortex-13B model configuration.
Optimized for 16GB VRAM (4060 Ti laptop) and MacBook Pro M3 Max.
"""

VORTEX_13B_CONFIG = {
    # Model dimensions
    "d_model": 5120,
    "num_layers": 40,
    "num_heads": 40,
    "head_dim": 128,  # d_model // num_heads

    # State-space layer parameters
    "d_state": 32,  # SSM state dimension (larger for the bigger model)
    "d_conv": 4,  # SSM convolution width

    # Attention parameters
    "window_size": 512,  # Local attention window
    "use_flash_attention": True,

    # Feed-forward parameters
    "ffn_expansion": 4,
    "num_domains": 7,
    "vocab_size": 50000,
    "max_seq_len": 16384,

    # Layer ratio: 50% SSM, 50% attention (more memory for attention)
    "ssm_ratio": 0.5,

    # Data types
    "dtype": "bfloat16",

    # Special tokens (same as 7B)
    "special_tokens": {
        "[PAD]": 0,
        "[UNK]": 1,
        "[BOS]": 2,
        "[EOS]": 3,
        "[EQUATION]": 4,
        "[/EQUATION]": 5,
        "[CITATION]": 6,
        "[/CITATION]": 7,
        "[MOLECULE]": 8,
        "[/MOLECULE]": 9,
        "[FIGURE]": 10,
        "[TABLE]": 11,
        "[MATH]": 12,
        "[CHEM]": 13,
        "[BIO]": 14,
        "[PHYS]": 15,
        "[EARTH]": 16,
        "[SPACE]": 17,
        "[ZOO]": 18,
    },

    "domain_tags": ["[MATH]", "[CHEM]", "[BIO]", "[PHYS]", "[EARTH]", "[SPACE]", "[ZOO]"],

    # Science module flags
    "enable_equation_module": True,
    "enable_numerical_module": True,
    "enable_citation_module": True,
    "enable_molecular_module": True,
}


def get_config():
    """Return the 13B configuration dictionary."""
    return VORTEX_13B_CONFIG
configs/vortex_7b_config.py ADDED
@@ -0,0 +1,71 @@
"""
Vortex-7B model configuration.
Optimized for 8GB VRAM (4060 laptop) and MacBook Pro M2/M3.
"""

VORTEX_7B_CONFIG = {
    # Model dimensions
    "d_model": 4096,
    "num_layers": 32,
    "num_heads": 32,
    "head_dim": 128,  # d_model // num_heads

    # State-space layer parameters
    "d_state": 16,  # SSM state dimension
    "d_conv": 4,  # SSM convolution width

    # Attention parameters
    "window_size": 512,  # Local attention window
    "use_flash_attention": True,  # CUDA only

    # Feed-forward parameters
    "ffn_expansion": 4,  # Hidden dim = d_model * expansion
    "num_domains": 7,  # Physics, Math, Chemistry, Biology, Earth, Space, Zoology

    # Tokenizer parameters
    "vocab_size": 50000,
    "max_seq_len": 16384,

    # Layer ratio: 60% SSM, 40% attention
    "ssm_ratio": 0.6,

    # Data types
    "dtype": "bfloat16",

    # Special tokens
    "special_tokens": {
        "[PAD]": 0,
        "[UNK]": 1,
        "[BOS]": 2,
        "[EOS]": 3,
        "[EQUATION]": 4,
        "[/EQUATION]": 5,
        "[CITATION]": 6,
        "[/CITATION]": 7,
        "[MOLECULE]": 8,
        "[/MOLECULE]": 9,
        "[FIGURE]": 10,
        "[TABLE]": 11,
        "[MATH]": 12,
        "[CHEM]": 13,
        "[BIO]": 14,
        "[PHYS]": 15,
        "[EARTH]": 16,
        "[SPACE]": 17,
        "[ZOO]": 18,
    },

    # Domain tags
    "domain_tags": ["[MATH]", "[CHEM]", "[BIO]", "[PHYS]", "[EARTH]", "[SPACE]", "[ZOO]"],

    # Science module flags (enable/disable for ablation)
    "enable_equation_module": True,
    "enable_numerical_module": True,
    "enable_citation_module": True,
    "enable_molecular_module": True,
}


def get_config():
    """Return the 7B configuration dictionary."""
    return VORTEX_7B_CONFIG
configuration_vortex.py ADDED
@@ -0,0 +1,110 @@
"""
Vortex configuration for HuggingFace.
"""

from typing import Optional, List, Dict, Any

from transformers import PretrainedConfig


class VortexConfig(PretrainedConfig):
    """
    Configuration class for the Vortex model.
    Compatible with HuggingFace transformers.
    """

    model_type = "vortex"
    tie_word_embeddings = True

    def __init__(
        self,
        d_model: int = 4096,
        num_layers: int = 32,
        num_heads: int = 32,
        d_state: int = 16,
        d_conv: int = 4,
        window_size: int = 512,
        ffn_expansion: int = 4,
        num_domains: int = 7,
        vocab_size: int = 50000,
        max_seq_len: int = 16384,
        ssm_ratio: float = 0.6,
        enable_equation_module: bool = True,
        enable_numerical_module: bool = True,
        enable_citation_module: bool = True,
        enable_molecular_module: bool = True,
        special_tokens: Optional[Dict[str, int]] = None,
        domain_tags: Optional[List[str]] = None,
        initializer_range: float = 0.02,
        tie_word_embeddings: bool = True,
        **kwargs,
    ):
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
        self.d_model = d_model
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.d_state = d_state
        self.d_conv = d_conv
        self.window_size = window_size
        self.ffn_expansion = ffn_expansion
        self.num_domains = num_domains
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.ssm_ratio = ssm_ratio
        self.enable_equation_module = enable_equation_module
        self.enable_numerical_module = enable_numerical_module
        self.enable_citation_module = enable_citation_module
        self.enable_molecular_module = enable_molecular_module
        self.special_tokens = special_tokens or {
            "[PAD]": 0, "[UNK]": 1, "[BOS]": 2, "[EOS]": 3,
            "[EQUATION]": 4, "[/EQUATION]": 5,
            "[CITATION]": 6, "[/CITATION]": 7,
            "[MOLECULE]": 8, "[/MOLECULE]": 9,
            "[FIGURE]": 10, "[TABLE]": 11,
            "[MATH]": 12, "[CHEM]": 13, "[BIO]": 14,
            "[PHYS]": 15, "[EARTH]": 16, "[SPACE]": 17, "[ZOO]": 18,
        }
        self.domain_tags = domain_tags or ["[MATH]", "[CHEM]", "[BIO]", "[PHYS]", "[EARTH]", "[SPACE]", "[ZOO]"]
        self.initializer_range = initializer_range
        # Compute derived attributes
        self.head_dim = self.d_model // self.num_heads

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """Load config from a pretrained model directory."""
        import json
        import os

        config_path = os.path.join(pretrained_model_name_or_path, "config.json")
        if os.path.exists(config_path):
            with open(config_path, "r") as f:
                config_dict = json.load(f)
            config_dict.update(kwargs)
            return cls(**config_dict)
        else:
            # Return default config
            return cls(**kwargs)

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary."""
        return {
            "model_type": self.model_type,
            "d_model": self.d_model,
            "num_layers": self.num_layers,
            "num_heads": self.num_heads,
            "head_dim": self.head_dim,
            "d_state": self.d_state,
            "d_conv": self.d_conv,
            "window_size": self.window_size,
            "ffn_expansion": self.ffn_expansion,
            "num_domains": self.num_domains,
            "vocab_size": self.vocab_size,
            "max_seq_len": self.max_seq_len,
            "ssm_ratio": self.ssm_ratio,
            "enable_equation_module": self.enable_equation_module,
            "enable_numerical_module": self.enable_numerical_module,
            "enable_citation_module": self.enable_citation_module,
            "enable_molecular_module": self.enable_molecular_module,
            "special_tokens": self.special_tokens,
            "domain_tags": self.domain_tags,
            "initializer_range": self.initializer_range,
        }
cuda_optimize.py ADDED
@@ -0,0 +1,287 @@
1
+ """
2
+ CUDA optimizations for Vortex model on Nvidia 4060 laptop.
3
+ Flash Attention 2, torch.compile, INT8 quantization.
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from typing import Optional, Dict, Any
9
+
10
+
11
+ def optimize_for_cuda(
12
+ model: nn.Module,
13
+ config: Dict,
14
+ use_flash_attention: bool = True,
15
+ use_torch_compile: bool = True,
16
+ compile_mode: str = "reduce-overhead",
17
+ quantization: Optional[str] = None,
18
+ ) -> nn.Module:
19
+ """
20
+ Apply CUDA optimizations to model.
21
+
22
+ Args:
23
+ model: VortexModel
24
+ config: Model config
25
+ use_flash_attention: Enable Flash Attention 2
26
+ use_torch_compile: Use torch.compile
27
+ compile_mode: Compile mode ("reduce-overhead", "max-autotune")
28
+ quantization: None, "int8", or "int4"
29
+
30
+ Returns:
31
+ Optimized model
32
+ """
33
+ device = torch.device("cuda")
34
+
35
+ # Move to CUDA
36
+ model = model.to(device)
37
+
38
+ # Set dtype
39
+ dtype_str = config.get("dtype", "bfloat16")
40
+ if dtype_str == "bfloat16":
41
+ dtype = torch.bfloat16
42
+ elif dtype_str == "float16":
43
+ dtype = torch.float16
44
+ else:
45
+ dtype = torch.float32
46
+
47
+ model = model.to(dtype)
48
+
49
+ # Apply Flash Attention 2 to attention layers
50
+ if use_flash_attention:
51
+ model = _apply_flash_attention(model)
52
+ print("Applied Flash Attention 2")
53
+
54
+ # Apply torch.compile
55
+ if use_torch_compile:
56
+ model = torch.compile(
57
+ model,
58
+ mode=compile_mode,
59
+ fullgraph=True,
60
+ dynamic=True,
61
+ )
62
+ print(f"Applied torch.compile with mode={compile_mode}")
63
+
64
+ # Apply quantization if requested
65
+ if quantization == "int8":
66
+ model = _apply_int8_quantization(model)
67
+ print("Applied INT8 quantization")
68
+ elif quantization == "int4":
69
+ model = _apply_int4_quantization(model)
70
+ print("Applied INT4 quantization")
71
+
72
+ return model
73
+
74
+
75
+ def _apply_flash_attention(model: nn.Module) -> nn.Module:
76
+ """
77
+ Replace standard attention with Flash Attention 2.
78
+ Requires: pip install flash-attn
79
+ """
80
+ try:
81
+ from flash_attn import flash_attn_func
82
+
83
+ # Monkey-patch attention layers to use flash attention
84
+ for name, module in model.named_modules():
85
+ if hasattr(module, 'use_flash_attention'):
86
+ module.use_flash_attention = True
87
+ # Replace forward with flash attention version
88
+ original_forward = module.forward
89
+
90
+ def flash_forward(self, x, *args, **kwargs):
91
+ return self._flash_attention_forward(x, *args, **kwargs)
92
+
93
+ module.forward = flash_forward.__get__(module, type(module))
94
+
95
+ return model
96
+
97
+ except ImportError:
98
+ print("Flash Attention not available. Install with: pip install flash-attn")
99
+ return model
100
+
101
+
102
+ def _apply_int8_quantization(model: nn.Module) -> nn.Module:
103
+ """
104
+ Apply INT8 quantization using bitsandbytes.
105
+ """
106
+ try:
107
+ import bitsandbytes as bnb
108
+
109
+ # Replace linear layers with 8-bit variants
110
+ for name, module in model.named_modules():
111
+ if isinstance(module, nn.Linear):
112
+ # Create 8-bit linear replacement
113
+ parent_name = name.rsplit('.', 1)[0] if '.' in name else ''
114
+ child_name = name.rsplit('.', 1)[1] if '.' in name else name
115
+
116
+ # Get parent module
117
+ parent = model
118
+ if parent_name:
119
+ for part in parent_name.split('.'):
120
+ parent = getattr(parent, part)
121
+
122
+ # Replace with 8-bit linear
123
+ replacement = bnb.nn.Linear8bitLt(
124
+ module.in_features,
125
+ module.out_features,
126
+ bias=module.bias is not None,
127
+ has_fp16_weights=False,
128
+ )
129
+ # Copy weights (will be quantized)
130
+ replacement.weight.data = module.weight.data
131
+ if module.bias is not None:
132
+ replacement.bias.data = module.bias.data
133
+
134
+ setattr(parent, child_name, replacement)
135
+
136
+ return model
137
+
138
+ except ImportError:
139
+ print("bitsandbytes not available. Install with: pip install bitsandbytes")
140
+ return model
141
+
142
+
143
+ def _apply_int4_quantization(model: nn.Module) -> nn.Module:
144
+ """
145
+ Apply INT4 quantization using bitsandbytes.
146
+ More aggressive, for 13B on 8GB VRAM.
147
+ """
148
+ try:
149
+ import bitsandbytes as bnb
150
+
151
+ for name, module in model.named_modules():
152
+ if isinstance(module, nn.Linear):
153
+ parent_name = name.rsplit('.', 1)[0] if '.' in name else ''
154
+ child_name = name.rsplit('.', 1)[1] if '.' in name else name
155
+
156
+ parent = model
157
+ if parent_name:
158
+ for part in parent_name.split('.'):
159
+ parent = getattr(parent, part)
160
+
161
+ # 4-bit linear
162
+ replacement = bnb.nn.Linear4bit(
163
+ module.in_features,
164
+ module.out_features,
165
+ bias=module.bias is not None,
166
+ compute_dtype=torch.float16,
167
+ compress_statistics=True,
168
+ )
169
+ replacement.weight.data = module.weight.data
170
+ if module.bias is not None:
171
+ replacement.bias.data = module.bias.data
172
+
173
+ setattr(parent, child_name, replacement)
174
+
175
+ return model
176
+
177
+ except ImportError:
178
+ print("bitsandbytes not available.")
179
+ return model
180
+
181
+
182
+ def get_cuda_memory_usage() -> Dict[str, float]:
183
+ """Get current CUDA memory usage in GB."""
184
+ if not torch.cuda.is_available():
185
+ return {"error": "CUDA not available"}
186
+
187
+ allocated = torch.cuda.memory_allocated() / 1e9
188
+ reserved = torch.cuda.memory_reserved() / 1e9
189
+ max_allocated = torch.cuda.max_memory_allocated() / 1e9
190
+
191
+ return {
192
+ "allocated_gb": allocated,
193
+ "reserved_gb": reserved,
194
+ "max_allocated_gb": max_allocated,
195
+ }
196
+
197
+
198
+ def profile_model(
199
+ model: nn.Module,
200
+ input_ids: torch.Tensor,
201
+ num_warmup: int = 10,
202
+ num_runs: int = 100,
203
+ ) -> Dict[str, float]:
204
+ """
205
+ Profile model performance.
206
+
207
+ Args:
208
+ model: Model to profile
209
+ input_ids: Example input
210
+ num_warmup: Number of warmup runs
211
+ num_runs: Number of profiling runs
212
+
213
+ Returns:
214
+ Dictionary with timing statistics
215
+ """
216
+ model.eval()
217
+ device = next(model.parameters()).device
218
+ input_ids = input_ids.to(device)
219
+
220
+ # Warmup
221
+ with torch.no_grad():
222
+ for _ in range(num_warmup):
223
+ _ = model(input_ids)
224
+
225
+ # Profile
226
+ torch.cuda.synchronize()
227
+ import time
228
+ start = time.time()
229
+
230
+ with torch.no_grad():
231
+ for _ in range(num_runs):
232
+ _ = model(input_ids)
233
+
234
+ torch.cuda.synchronize()
235
+ elapsed = time.time() - start
236
+
237
+ avg_time = elapsed / num_runs
238
+ tokens_per_sec = input_ids.shape[1] / avg_time
239
+
240
+ return {
241
+ "avg_time_sec": avg_time,
242
+ "tokens_per_sec": tokens_per_sec,
243
+ }
244
+
245
+
246
+ def test_cuda_optimize():
247
+ """Test CUDA optimizations."""
248
+ if not torch.cuda.is_available():
249
+ print("CUDA not available, skipping test")
250
+ return
251
+
252
+ from models.vortex_model import VortexModel
253
+ from configs.vortex_7b_config import VORTEX_7B_CONFIG
254
+
255
+ config = VORTEX_7B_CONFIG.copy()
256
+ config["d_model"] = 512
257
+ config["num_layers"] = 2
258
+ config["num_heads"] = 8
259
+ config["vocab_size"] = 1000
260
+
261
+ model = VortexModel(config)
262
+ print(f"Model parameters: {model.get_num_params():,}")
263
+
264
+ # Optimize
265
+ model = optimize_for_cuda(
266
+ model,
267
+ config,
268
+ use_flash_attention=False, # May not be available
269
+ use_torch_compile=False, # Skip compile for test
270
+ quantization=None,
271
+ )
272
+
273
+ # Test forward
274
+ batch_size = 2
275
+ seq_len = 128
276
+ input_ids = torch.randint(0, config["vocab_size"], (batch_size, seq_len)).cuda()
277
+
278
+ with torch.no_grad():
279
+ output = model(input_ids)
280
+ logits = output["logits"]
281
+
282
+ print(f"Output shape: {logits.shape}")
283
+ print("CUDA optimize test passed!")
284
+
285
+
286
+ if __name__ == "__main__":
287
+ test_cuda_optimize()
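The warmup-then-average timing pattern used in `profile_model` above can be sanity-checked without a GPU. A minimal CPU-only sketch of the same idea (`profile_fn` and the dummy workload are illustrative, not part of the repo):

```python
import time

def profile_fn(fn, num_warmup=3, num_runs=20):
    """Warm up, then average wall-clock time over num_runs calls."""
    for _ in range(num_warmup):
        fn()
    start = time.perf_counter()
    for _ in range(num_runs):
        fn()
    elapsed = time.perf_counter() - start
    return elapsed / num_runs

# Dummy workload standing in for a model forward pass
avg = profile_fn(lambda: sum(i * i for i in range(10_000)))
print(f"avg_time_sec: {avg:.6f}")
```

The warmup runs matter on CUDA (kernel compilation, cache effects) but the averaging logic is the same either way.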
data/__pycache__/deduplication.cpython-313.pyc ADDED
Binary file (10.1 kB).
 
data/__pycache__/domain_classifier.cpython-313.pyc ADDED
Binary file (6.38 kB).
 
data/__pycache__/quality_filter.cpython-313.pyc ADDED
Binary file (11.3 kB).
 
data/dataset_loader.py ADDED
@@ -0,0 +1,263 @@
1
+ """
2
+ DatasetLoader: Loads and processes open scientific datasets.
3
+ Supports streaming from HuggingFace datasets with sharding.
4
+ """
5
+
6
+ import os
7
+ import json
8
+ from typing import List, Dict, Optional, Iterator
9
+ from pathlib import Path
10
+
11
+ try:
12
+ from datasets import load_dataset, Dataset, IterableDataset
13
+ import pyarrow.parquet as pq
14
+ except ImportError:
15
+ print("Please install datasets and pyarrow: pip install datasets pyarrow")
16
+ raise
17
+
18
+
19
+ class VortexDatasetLoader:
20
+ """
21
+ Loads and processes open scientific datasets.
22
+ Supports streaming with sharding to Parquet files.
23
+ """
24
+
25
+ # Open datasets configuration
26
+ DATASETS = {
27
+ "pile_scientific": {
28
+ "path": "EleutherAI/pile",
29
+ "subset": "pubmed_central",
30
+ "split": "train",
31
+ "text_field": "text",
32
+ "domain": "biology", # approximate
33
+ },
34
+ "s2orc": {
35
+ "path": "allenai/s2orc",
36
+ "subset": None,
37
+ "split": "train",
38
+ "text_field": "text",
39
+ "domain": "multidisciplinary",
40
+ },
41
+ "pes2o": {
42
+ "path": "allenai/peS2o",
43
+ "subset": None,
44
+ "split": "train",
45
+ "text_field": "text",
46
+ "domain": "multidisciplinary",
47
+ },
48
+ "automath": {
49
+ "path": "math-ai/AutoMathText",
50
+ "subset": None,
51
+ "split": "train",
52
+ "text_field": "text",
53
+ "domain": "math",
54
+ },
55
+ "deepmind_math": {
56
+ "path": "deepmind/math_dataset",
57
+ "subset": "algebra__linear_1d",
58
+ "split": "train",
59
+ "text_field": "text",
60
+ "domain": "math",
61
+ },
62
+ "pubmed_qa": {
63
+ "path": "bigbio/pubmed_qa",
64
+ "subset": "pubmed_qa_labeled_fold0_source",
65
+ "split": "train",
66
+ "text_field": "question",
67
+ "domain": "biology",
68
+ },
69
+ }
70
+
71
+ def __init__(
72
+ self,
73
+ cache_dir: str = "./data/cache",
74
+ output_dir: str = "./data/processed",
75
+ num_proc: int = 4,
76
+ ):
77
+ """
78
+ Initialize dataset loader.
79
+
80
+ Args:
81
+ cache_dir: Directory for caching downloaded datasets
82
+ output_dir: Directory for processed shards
83
+ num_proc: Number of processes for data processing
84
+ """
85
+ self.cache_dir = Path(cache_dir)
86
+ self.output_dir = Path(output_dir)
87
+ self.num_proc = num_proc
88
+
89
+ self.cache_dir.mkdir(parents=True, exist_ok=True)
90
+ self.output_dir.mkdir(parents=True, exist_ok=True)
91
+
92
+ def load_dataset(
93
+ self,
94
+ dataset_name: str,
95
+ streaming: bool = True,
96
+ max_samples: Optional[int] = None,
97
+ ) -> Iterator[Dict]:
98
+ """
99
+ Load a dataset as an iterator.
100
+
101
+ Args:
102
+ dataset_name: Name from DATASETS config
103
+ streaming: Use streaming mode for large datasets
104
+ max_samples: Maximum number of samples to yield
105
+
106
+ Yields:
107
+ Dictionary with text and metadata
108
+ """
109
+ if dataset_name not in self.DATASETS:
110
+ raise ValueError(f"Unknown dataset: {dataset_name}. Available: {list(self.DATASETS.keys())}")
111
+
112
+ config = self.DATASETS[dataset_name]
113
+
114
+ print(f"Loading dataset: {dataset_name}")
115
+ print(f" Path: {config['path']}")
116
+ print(f" Streaming: {streaming}")
117
+
118
+ try:
119
+ dataset = load_dataset(
120
+ config["path"],
121
+ name=config["subset"],
122
+ split=config["split"],
123
+ streaming=streaming,
124
+ cache_dir=str(self.cache_dir),
125
+ )
126
+
127
+ count = 0
128
+ for sample in dataset:
129
+ text = sample.get(config["text_field"], "")
130
+ if not text or not isinstance(text, str):
131
+ continue
132
+
133
+ yield {
134
+ "text": text,
135
+ "dataset": dataset_name,
136
+ "domain": config["domain"],
137
+ "source": config["path"],
138
+ }
139
+
140
+ count += 1
141
+ if max_samples and count >= max_samples:
142
+ break
143
+
144
+ print(f"Loaded {count} samples from {dataset_name}")
145
+
146
+ except Exception as e:
147
+ print(f"Error loading dataset {dataset_name}: {e}")
148
+ # Return empty iterator
149
+ return
150
+
151
+ def load_multiple_datasets(
152
+ self,
153
+ dataset_names: List[str],
154
+ streaming: bool = True,
155
+ max_per_dataset: Optional[int] = None,
156
+ ) -> Iterator[Dict]:
157
+ """
158
+ Load multiple datasets and yield samples interleaved.
159
+
160
+ Args:
161
+ dataset_names: List of dataset names
162
+ streaming: Use streaming mode
163
+ max_per_dataset: Max samples per dataset
164
+
165
+ Yields:
166
+ Dictionary with text and metadata
167
+ """
168
+ iterators = []
169
+ for name in dataset_names:
170
+ it = self.load_dataset(name, streaming=streaming, max_samples=max_per_dataset)
171
+ iterators.append(it)
172
+
173
+ # Simple round-robin interleaving
174
+ active = len(iterators)
175
+ while active > 0:
176
+ for i, it in enumerate(iterators):
177
+ if it is None:
178
+ continue
179
+ try:
180
+ yield next(it)
181
+ except StopIteration:
182
+ iterators[i] = None
183
+ active -= 1
184
+ break
185
+
186
+ def shard_to_parquet(
187
+ self,
188
+ samples: Iterator[Dict],
189
+ output_prefix: str,
190
+ samples_per_shard: int = 10000,
191
+ ):
192
+ """
193
+ Write samples to sharded Parquet files.
194
+
195
+ Args:
196
+ samples: Iterator of sample dictionaries
197
+ output_prefix: Prefix for output files (e.g., "train")
198
+ samples_per_shard: Number of samples per shard
199
+ """
200
+ shard_index = 0
201
+ buffer = []
202
+
203
+ for sample in samples:
204
+ buffer.append(sample)
205
+
206
+ if len(buffer) >= samples_per_shard:
207
+ self._write_shard(buffer, output_prefix, shard_index)
208
+ shard_index += 1
209
+ buffer = []
210
+
211
+ # Write remaining
212
+ if buffer:
213
+ self._write_shard(buffer, output_prefix, shard_index)
+ shard_index += 1
214
+
215
+ print(f"Wrote {shard_index} shards to {self.output_dir}")
216
+
217
+ def _write_shard(
218
+ self,
219
+ buffer: List[Dict],
220
+ output_prefix: str,
221
+ shard_index: int,
222
+ ):
223
+ """Write a single shard to Parquet."""
224
+ import pandas as pd
225
+
226
+ df = pd.DataFrame(buffer)
227
+ output_path = self.output_dir / f"{output_prefix}_{shard_index:05d}.parquet"
228
+ df.to_parquet(output_path, index=False)
229
+
230
+ def get_shard_list(
231
+ self,
232
+ prefix: str,
233
+ ) -> List[Path]:
234
+ """Get list of shard files matching prefix."""
235
+ return sorted(self.output_dir.glob(f"{prefix}_*.parquet"))
236
+
237
+ def read_shard(
238
+ self,
239
+ shard_path: Path,
240
+ ) -> List[Dict]:
241
+ """Read a single shard."""
242
+ import pandas as pd
243
+ df = pd.read_parquet(shard_path)
244
+ return df.to_dict('records')
245
+
246
+
247
+ def test_dataset_loader():
248
+ """Test the dataset loader."""
249
+ loader = VortexDatasetLoader()
250
+
251
+ # Test loading a small dataset
252
+ print("Testing dataset loader...")
253
+ count = 0
254
+ for sample in loader.load_dataset("pubmed_qa", streaming=True, max_samples=10):
255
+ print(f"Sample {count}: {sample['text'][:100]}...")
256
+ count += 1
257
+
258
+ print(f"Loaded {count} samples")
259
+ print("DatasetLoader test passed!")
260
+
261
+
262
+ if __name__ == "__main__":
263
+ test_dataset_loader()
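The round-robin interleaving in `load_multiple_datasets` can be illustrated with plain iterators. A self-contained sketch of the same control flow, independent of HuggingFace `datasets` (and without the early `break`, which is not needed for correctness):

```python
def interleave(iterators):
    """Yield items round-robin until every iterator is exhausted."""
    iterators = list(iterators)
    active = len(iterators)
    while active > 0:
        for i, it in enumerate(iterators):
            if it is None:
                continue  # already exhausted
            try:
                yield next(it)
            except StopIteration:
                iterators[i] = None
                active -= 1

mixed = list(interleave([iter("ab"), iter("xyz")]))
print(mixed)  # → ['a', 'x', 'b', 'y', 'z']
```

Exhausted iterators are replaced with `None` rather than removed, so the remaining sources keep their relative order.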
data/deduplication.py ADDED
@@ -0,0 +1,260 @@
1
+ """
2
+ Deduplication: MinHash LSH for near-duplicate detection.
3
+ """
4
+
5
+ import hashlib
6
+ import random
7
+ from typing import List, Set, Tuple, Optional
8
+ from dataclasses import dataclass
9
+
10
+
11
+ @dataclass
12
+ class MinHashSignature:
13
+ """MinHash signature for a document."""
14
+ hash_values: List[int]
15
+ doc_id: str
16
+
17
+
18
+ class MinHashLSH:
19
+ """
20
+ MinHash LSH for near-duplicate detection.
21
+ Uses shingling and MinHash to estimate Jaccard similarity.
22
+ """
23
+
24
+ def __init__(
25
+ self,
26
+ num_permutations: int = 128,
27
+ threshold: float = 0.8,
28
+ bands: int = 16,
29
+ rows_per_band: int = 8,
30
+ ):
31
+ """
32
+ Initialize MinHash LSH.
33
+
34
+ Args:
35
+ num_permutations: Number of hash permutations for MinHash
36
+ threshold: Similarity threshold for considering duplicates
37
+ bands: Number of bands for LSH
38
+ rows_per_band: Rows per band (bands * rows_per_band = num_permutations)
39
+ """
40
+ self.num_permutations = num_permutations
41
+ self.threshold = threshold
42
+ self.bands = bands
43
+ self.rows_per_band = rows_per_band
44
+
45
+ assert bands * rows_per_band == num_permutations
46
+
47
+ # Generate random hash functions
48
+ self.hash_functions = self._generate_hash_functions(num_permutations)
49
+
50
+ # LSH index: band_id -> {bucket_hash -> set of doc_ids}
51
+ self.index = [dict() for _ in range(bands)]
52
+
53
+ # Store signatures for similarity computation
54
+ self.signatures = {} # doc_id -> MinHashSignature
55
+
56
+ def _generate_hash_functions(self, n: int) -> List:
57
+ """Generate n random hash functions."""
58
+ # Use random permutations of large prime
59
+ functions = []
60
+ for _ in range(n):
61
+ a = random.randint(1, 2**32 - 1)
62
+ b = random.randint(0, 2**32 - 1)
63
+ functions.append((a, b))
64
+ return functions
65
+
66
+ def _hash(self, x: int, a: int, b: int) -> int:
67
+ """Universal hash function: (a*x + b) mod prime."""
68
+ prime = 2**61 - 1
69
+ return ((a * x + b) % prime) & 0xFFFFFFFF
70
+
71
+ def _compute_minhash(self, shingles: Set[int]) -> List[int]:
72
+ """
73
+ Compute MinHash signature for a set of shingles.
74
+
75
+ Args:
76
+ shingles: Set of shingle hash values
77
+
78
+ Returns:
79
+ List of minhash values (one per permutation)
80
+ """
81
+ signature = []
82
+ for a, b in self.hash_functions:
83
+ min_hash = min(self._hash(shingle, a, b) for shingle in shingles)
84
+ signature.append(min_hash)
85
+ return signature
86
+
87
+ def _shingle_text(
88
+ self,
89
+ text: str,
90
+ k: int = 5,
91
+ ) -> Set[int]:
92
+ """
93
+ Extract k-gram shingles from text.
94
+
95
+ Args:
96
+ text: Input text
97
+ k: Shingle size (characters)
98
+
99
+ Returns:
100
+ Set of shingle hashes
101
+ """
102
+ text = text.lower()
103
+ shingles = set()
104
+ for i in range(len(text) - k + 1):
105
+ shingle = text[i:i+k]
106
+ # Hash shingle
107
+ shingle_hash = int(hashlib.md5(shingle.encode()).hexdigest()[:8], 16)
108
+ shingles.add(shingle_hash)
109
+ return shingles
110
+
111
+ def add_document(
112
+ self,
113
+ doc_id: str,
114
+ text: str,
115
+ compute_signature: bool = True,
116
+ ) -> MinHashSignature:
117
+ """
118
+ Add a document to the LSH index.
119
+
120
+ Args:
121
+ doc_id: Unique document ID
122
+ text: Document text
123
+ compute_signature: Whether to compute signature (or use precomputed)
124
+
125
+ Returns:
126
+ MinHash signature
127
+ """
128
+ if compute_signature:
129
+ shingles = self._shingle_text(text)
130
+ signature = self._compute_minhash(shingles)
131
+ else:
132
+ raise ValueError("Must compute signature")
133
+
134
+ # Store signature
135
+ signature_obj = MinHashSignature(hash_values=signature, doc_id=doc_id)
136
+ self.signatures[doc_id] = signature_obj
137
+
138
+ # Index into bands
139
+ for band_idx in range(self.bands):
140
+ start = band_idx * self.rows_per_band
141
+ end = start + self.rows_per_band
142
+ band_signature = tuple(signature[start:end])
143
+ bucket_hash = hash(band_signature)
144
+
145
+ if bucket_hash not in self.index[band_idx]:
146
+ self.index[band_idx][bucket_hash] = set()
147
+ self.index[band_idx][bucket_hash].add(doc_id)
148
+
149
+ return signature_obj
150
+
151
+ def query(
152
+ self,
153
+ text: str,
154
+ candidate_doc_ids: Optional[Set[str]] = None,
155
+ ) -> List[Tuple[str, float]]:
156
+ """
157
+ Query for near-duplicate documents.
158
+
159
+ Args:
160
+ text: Query text
161
+ candidate_doc_ids: Optional set of candidate doc IDs to check
162
+
163
+ Returns:
164
+ List of (doc_id, similarity) above threshold
165
+ """
166
+ shingles = self._shingle_text(text)
167
+ query_signature = self._compute_minhash(shingles)
168
+
169
+ # Find candidates via LSH
170
+ candidate_sets = []
171
+ for band_idx in range(self.bands):
172
+ start = band_idx * self.rows_per_band
173
+ end = start + self.rows_per_band
174
+ band_signature = tuple(query_signature[start:end])
175
+ bucket_hash = hash(band_signature)
176
+
177
+ if bucket_hash in self.index[band_idx]:
178
+ candidate_sets.append(self.index[band_idx][bucket_hash])
179
+
180
+ # Union of candidates
181
+ candidates = set()
182
+ for s in candidate_sets:
183
+ candidates.update(s)
184
+
185
+ if candidate_doc_ids is not None:
186
+ candidates = candidates.intersection(candidate_doc_ids)
187
+
188
+ # Compute exact Jaccard similarity for candidates
189
+ results = []
190
191
+ for doc_id in candidates:
192
+ # In practice, would retrieve stored shingles
193
+ # For now, approximate using signature
194
+ similarity = self._estimate_similarity(query_signature, doc_id)
195
+ if similarity >= self.threshold:
196
+ results.append((doc_id, similarity))
197
+
198
+ return sorted(results, key=lambda x: x[1], reverse=True)
199
+
200
+ def _estimate_similarity(
201
+ self,
202
+ signature1: List[int],
203
+ doc_id2: str,
204
+ ) -> float:
205
+ """
206
+ Estimate Jaccard similarity between two signatures.
207
+ Uses MinHash similarity: proportion of matching hash values.
208
+
209
+ Args:
210
+ signature1: First MinHash signature
211
+ doc_id2: Second document ID (signature retrieved from storage)
212
+
213
+ Returns:
214
+ Estimated Jaccard similarity
215
+ """
216
+ if doc_id2 not in self.signatures:
217
+ return 0.0
218
+
219
+ signature2 = self.signatures[doc_id2].hash_values
220
+
221
+ # Count matching values
222
+ matches = sum(1 for h1, h2 in zip(signature1, signature2) if h1 == h2)
223
+ similarity = matches / len(signature1)
224
+
225
+ return similarity
226
+
227
+ def compute_signature(self, text: str) -> MinHashSignature:
228
+ """Compute MinHash signature for text."""
229
+ shingles = self._shingle_text(text)
230
+ signature = self._compute_minhash(shingles)
231
+ return MinHashSignature(hash_values=signature, doc_id="")
232
+
233
+
234
+ def test_deduplication():
235
+ """Test MinHash LSH."""
236
+ lsh = MinHashLSH(num_permutations=64, threshold=0.7, bands=8, rows_per_band=8)
237
+
238
+ # Add documents
239
+ docs = [
240
+ ("doc1", "The quick brown fox jumps over the lazy dog."),
241
+ ("doc2", "The quick brown fox jumps over the lazy dog!!!"), # near duplicate
242
+ ("doc3", "The quick brown fox leaps over the lazy dog."), # near duplicate
243
+ ("doc4", "Completely unrelated text about science and experiments."),
244
+ ]
245
+
246
+ signatures = {}
247
+ for doc_id, text in docs:
248
+ sig = lsh.add_document(doc_id, text)
249
+ signatures[doc_id] = sig
250
+
251
+ # Query with doc1
252
+ results = lsh.query(docs[0][1])
253
+ print(f"Query results for doc1: {results}")
254
+ # Should find doc2 and doc3 as similar
255
+
256
+ print("Deduplication test passed!")
257
+
258
+
259
+ if __name__ == "__main__":
260
+ test_deduplication()
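The signature-overlap estimate used in `_estimate_similarity` (fraction of matching MinHash values ≈ Jaccard similarity) can be checked with a tiny self-contained sketch. The shingling and affine hash functions below are simplified versions of the ones above, with a fixed seed for reproducibility:

```python
import hashlib
import random

def shingles(text, k=5):
    """Character k-gram shingles, hashed to 32-bit ints."""
    return {int(hashlib.md5(text[i:i+k].encode()).hexdigest()[:8], 16)
            for i in range(len(text) - k + 1)}

def minhash(sh, funcs, prime=2**61 - 1):
    """One min-hash value per (a, b) hash function."""
    return [min(((a * s + b) % prime) & 0xFFFFFFFF for s in sh) for a, b in funcs]

random.seed(0)
funcs = [(random.randint(1, 2**32 - 1), random.randint(0, 2**32 - 1))
         for _ in range(256)]

s1 = shingles("the quick brown fox jumps over the lazy dog")
s2 = shingles("the quick brown fox jumps over the lazy cat")
true_jaccard = len(s1 & s2) / len(s1 | s2)
m1, m2 = minhash(s1, funcs), minhash(s2, funcs)
estimate = sum(h1 == h2 for h1, h2 in zip(m1, m2)) / len(funcs)
print(f"true={true_jaccard:.2f} estimate={estimate:.2f}")
```

With 256 permutations the estimate should land within a few points of the true Jaccard similarity; more permutations tighten the bound.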
data/domain_classifier.py ADDED
@@ -0,0 +1,163 @@
1
+ """
2
+ DomainClassifier: Classifies documents into 7 science domains.
3
+ Uses a simple linear classifier on top of text features.
4
+ """
5
+
6
+ import re
7
+ from typing import List, Tuple, Optional
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+
12
+ class DomainClassifier(nn.Module):
13
+ """
14
+ Classifies documents into 7 science domains:
15
+ 0: Physics
16
+ 1: Mathematics
17
+ 2: Chemistry
18
+ 3: Biology
19
+ 4: Earth Science
20
+ 5: Space Science
21
+ 6: Zoology
22
+ """
23
+
24
+ # Domain keywords for rule-based fallback
25
+ DOMAIN_KEYWORDS = {
26
+ 0: ['physics', 'quantum', 'relativity', 'mechanics', 'thermodynamics', 'electromagnetism'],
27
+ 1: ['mathematics', 'algebra', 'calculus', 'geometry', 'topology', 'proof', 'theorem'],
28
+ 2: ['chemistry', 'molecular', 'reaction', 'compound', 'element', 'organic'],
29
+ 3: ['biology', 'cell', 'gene', 'protein', 'organism', 'evolution'],
30
+ 4: ['earth', 'geology', 'climate', 'ocean', 'atmosphere', 'meteorology'],
31
+ 5: ['space', 'astronomy', 'planet', 'star', 'galaxy', 'cosmology'],
32
+ 6: ['zoology', 'animal', 'species', 'vertebrate', 'invertebrate', 'ecology'],
33
+ }
34
+
35
+ def __init__(self, d_model: int, num_domains: int = 7):
36
+ """
37
+ Initialize domain classifier.
38
+
39
+ Args:
40
+ d_model: Input embedding dimension
41
+ num_domains: Number of domains (7)
42
+ """
43
+ super().__init__()
44
+ self.d_model = d_model
45
+ self.num_domains = num_domains
46
+
47
+ # Simple linear classifier
48
+ self.classifier = nn.Linear(d_model, num_domains)
49
+
50
+ # Initialize weights
51
+ nn.init.normal_(self.classifier.weight, mean=0.0, std=0.02)
52
+ nn.init.zeros_(self.classifier.bias)
53
+
54
+ def forward(
55
+ self,
56
+ hidden_states: torch.Tensor,
57
+ attention_mask: Optional[torch.Tensor] = None,
58
+ ) -> torch.Tensor:
59
+ """
60
+ Classify domain from hidden states.
61
+
62
+ Args:
63
+ hidden_states: (batch, seq_len, d_model)
64
+ attention_mask: (batch, seq_len)
65
+
66
+ Returns:
67
+ Domain logits (batch, num_domains)
68
+ """
69
+ # Mean pooling over sequence (masked)
70
+ if attention_mask is not None:
71
+ mask = attention_mask.unsqueeze(-1) # (batch, seq_len, 1)
72
+ summed = (hidden_states * mask).sum(dim=1)
73
+ counts = mask.sum(dim=1)
74
+ pooled = summed / counts.clamp(min=1)
75
+ else:
76
+ pooled = hidden_states.mean(dim=1)
77
+
78
+ # Classify
79
+ logits = self.classifier(pooled)
80
+ return logits
81
+
82
+ def classify_text(
83
+ self,
84
+ text: str,
85
+ ) -> Tuple[int, float]:
86
+ """
87
+ Rule-based fallback classification from raw text.
88
+
89
+ Args:
90
+ text: Input text string
91
+
92
+ Returns:
93
+ (domain_id, confidence)
94
+ """
95
+ text_lower = text.lower()
96
+
97
+ # Count keyword matches per domain
98
+ scores = []
99
+ for domain_id, keywords in self.DOMAIN_KEYWORDS.items():
100
+ score = sum(1 for kw in keywords if kw in text_lower)
101
+ scores.append(score)
102
+
103
+ if max(scores) == 0:
104
+ return 0, 0.0 # Unknown -> default to physics
105
+
106
+ best_domain = scores.index(max(scores))
107
+ confidence = max(scores) / sum(scores) if sum(scores) > 0 else 0.0
108
+
109
+ return best_domain, confidence
110
+
111
+ def compute_loss(
112
+ self,
113
+ logits: torch.Tensor,
114
+ domain_labels: torch.Tensor,
115
+ ) -> torch.Tensor:
116
+ """
117
+ Compute classification loss.
118
+
119
+ Args:
120
+ logits: (batch, num_domains)
121
+ domain_labels: (batch,) with domain IDs
122
+
123
+ Returns:
124
+ Cross-entropy loss
125
+ """
126
+ return nn.functional.cross_entropy(logits, domain_labels)
127
+
128
+
129
+ def test_domain_classifier():
130
+ """Test DomainClassifier."""
131
+ d_model = 512
132
+ batch_size = 4
133
+ seq_len = 128
134
+
135
+ classifier = DomainClassifier(d_model)
136
+
137
+ # Test with random hidden states
138
+ hidden = torch.randn(batch_size, seq_len, d_model)
139
+ logits = classifier(hidden)
140
+ print(f"Logits shape: {logits.shape}")
141
+ assert logits.shape == (batch_size, 7)
142
+
143
+ # Test with text
144
+ texts = [
145
+ "The quantum mechanics of particles...",
146
+ "Solving differential equations...",
147
+ "Chemical reactions produce compounds...",
148
+ "Cells contain DNA and proteins...",
149
+ ]
150
+ for text in texts:
151
+ domain, conf = classifier.classify_text(text)
152
+ print(f"Text: {text[:30]}... -> Domain {domain}, conf {conf:.2f}")
153
+
154
+ # Test loss
155
+ labels = torch.tensor([0, 1, 2, 3])
156
+ loss = classifier.compute_loss(logits, labels)
157
+ print(f"Loss: {loss.item():.4f}")
158
+
159
+ print("DomainClassifier test passed!")
160
+
161
+
162
+ if __name__ == "__main__":
163
+ test_domain_classifier()
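The rule-based fallback in `classify_text` is keyword counting with a normalized confidence. The same logic, stripped of the `nn.Module` wrapper (the keyword table here is a trimmed illustrative subset of the one above):

```python
# Trimmed subset of DOMAIN_KEYWORDS for illustration
DOMAIN_KEYWORDS = {
    0: ['physics', 'quantum', 'relativity'],
    1: ['mathematics', 'algebra', 'calculus', 'theorem'],
    3: ['biology', 'cell', 'gene', 'protein'],
}

def classify(text):
    """Return (domain_id, confidence) from keyword hits."""
    text = text.lower()
    scores = {d: sum(kw in text for kw in kws) for d, kws in DOMAIN_KEYWORDS.items()}
    best = max(scores, key=scores.get)
    if scores[best] == 0:
        return 0, 0.0  # no evidence: fall back to domain 0
    return best, scores[best] / sum(scores.values())

print(classify("Cells contain DNA and proteins regulated by genes."))  # → (3, 1.0)
```

Confidence is the winning domain's share of all keyword hits, so a document mixing domains gets a proportionally lower score.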
data/quality_filter.py ADDED
@@ -0,0 +1,279 @@
1
+ """
2
+ ScienceQualityFilter: Multi-stage quality filtering for scientific text.
3
+ """
4
+
5
+ import re
6
+ from typing import List, Tuple, Optional
7
+ from dataclasses import dataclass
8
+
9
+
10
+ @dataclass
11
+ class FilterStats:
12
+ """Statistics from quality filtering."""
13
+ total: int = 0
14
+ passed: int = 0
15
+ failed_length: int = 0
16
+ failed_language: int = 0
17
+ failed_content: int = 0
18
+ failed_equations: int = 0
19
+ failed_repetition: int = 0
20
+ failed_citations: int = 0
21
+
22
+
23
+ class ScienceQualityFilter:
24
+ """
25
+ Multi-stage quality filtering for scientific text.
26
+ """
27
+
28
+ def __init__(
29
+ self,
30
+ min_length: int = 128,
31
+ max_length: int = 64000,
32
+ max_repetition_ratio: float = 0.2,
33
+ min_equation_ratio: float = 0.0,
34
+ max_equation_ratio: float = 0.6,
35
+ min_citation_ratio: float = 0.0,
36
+ max_citation_ratio: float = 0.4,
37
+ ):
38
+ """
39
+ Initialize quality filter.
40
+
41
+ Args:
42
+ min_length: Minimum text length in characters
43
+ max_length: Maximum text length in characters
44
+ max_repetition_ratio: Maximum character-level repetition ratio
45
+ min_equation_ratio: Minimum equation density (optional)
46
+ max_equation_ratio: Maximum equation density
47
+ min_citation_ratio: Minimum citation density (optional)
48
+ max_citation_ratio: Maximum citation density
49
+ """
50
+ self.min_length = min_length
51
+ self.max_length = max_length
52
+ self.max_repetition_ratio = max_repetition_ratio
53
+ self.min_equation_ratio = min_equation_ratio
54
+ self.max_equation_ratio = max_equation_ratio
55
+ self.min_citation_ratio = min_citation_ratio
56
+ self.max_citation_ratio = max_citation_ratio
57
+
58
+ def filter(self, text: str, stats: Optional[FilterStats] = None) -> bool:
59
+ """
60
+ Run all quality checks on text.
61
+
62
+ Args:
63
+ text: Input text
64
+ stats: Optional stats object to update
65
+
66
+ Returns:
67
+ True if text passes all checks, False otherwise
68
+ """
69
+ if stats is None:
70
+ stats = FilterStats()
71
+ stats.total += 1
72
+
73
+ # 1. Length check
74
+ if not self.length_check(text):
75
+ stats.failed_length += 1
76
+ return False
77
+
78
+ # 2. Language check (English only, simplified)
79
+ if not self.language_check(text):
80
+ stats.failed_language += 1
81
+ return False
82
+
83
+ # 3. Science content check
84
+ if not self.science_content_check(text):
85
+ stats.failed_content += 1
86
+ return False
87
+
88
+ # 4. Equation validity check
89
+ if not self.equation_validity_check(text):
90
+ stats.failed_equations += 1
91
+ return False
92
+
93
+ # 5. Repetition check
94
+ if not self.repetition_check(text):
95
+ stats.failed_repetition += 1
96
+ return False
97
+
98
+ # 6. Citation ratio check
99
+ if not self.citation_ratio_check(text):
100
+ stats.failed_citations += 1
101
+ return False
102
+
103
+ stats.passed += 1
104
+ return True
105
+
106
+ def length_check(self, text: str) -> bool:
107
+ """Check text length."""
108
+ length = len(text)
109
+ return self.min_length <= length <= self.max_length
110
+
111
+ def language_check(self, text: str) -> bool:
112
+ """
113
+ Check if text is likely English.
114
+ Simplified heuristic: count common English words.
115
+ """
116
+ # Common English words
117
+ english_words = {'the', 'be', 'to', 'of', 'and', 'a', 'in', 'that', 'have', 'i'}
118
+ words = re.findall(r'\b[a-zA-Z]{2,}\b', text.lower())
119
+ if len(words) < 10:
120
+ return False
121
+
122
+ english_count = sum(1 for w in words if w in english_words)
123
+ return english_count / len(words) > 0.1
124
+
125
+ def science_content_check(self, text: str) -> bool:
126
+ """
127
+ Check if text contains scientific content.
128
+ Looks for scientific terms, equations, units, etc.
129
+ """
130
+ # Scientific keywords
131
+ sci_keywords = [
132
+ 'experiment', 'data', 'result', 'analysis', 'method',
133
+ 'theory', 'hypothesis', 'conclusion', 'discussion',
134
+ 'figure', 'table', 'equation', 'reference', 'citation',
135
+ 'molecular', 'protein', 'gene', 'cell', 'reaction',
136
+ 'mathematical', 'derivation', 'proof', 'theorem',
137
+ ]
138
+
139
+ # Units
140
+ units = ['m', 'kg', 's', 'mol', 'K', 'J', 'N', 'Pa', 'Hz', 'eV', '°C']
141
+
142
+ # Count occurrences
143
+ text_lower = text.lower()
144
+ keyword_count = sum(1 for kw in sci_keywords if kw in text_lower)
145
+ unit_count = sum(1 for u in units if re.search(rf'\b{re.escape(u)}\b', text))
146
+
147
+ return keyword_count >= 2 or unit_count >= 1
148
+
149
+ def equation_validity_check(self, text: str) -> bool:
150
+ """
151
+ Check if LaTeX equations are well-formed.
152
+ """
153
+ # Count dollar signs
154
+ dollar_count = text.count('$')
155
+ if dollar_count % 2 != 0:
156
+ return False # Unmatched dollar signs
157
+
158
+ # Count backslash brackets
159
+ lbracket = text.count('\\[')
160
+ rbracket = text.count('\\]')
161
+ if lbracket != rbracket:
162
+ return False
163
+
164
+ # Count parentheses in inline math
165
+ lparen = text.count('\\(')
166
+ rparen = text.count('\\)')
167
+ if lparen != rparen:
168
+ return False
169
+
170
+ # Check for common LaTeX errors
171
+ # Unmatched braces
172
+ brace_balance = 0
173
+ for char in text:
174
+ if char == '{':
175
+ brace_balance += 1
176
+ elif char == '}':
177
+ brace_balance -= 1
178
+ if brace_balance < 0:
179
+ return False # Closing without opening
180
+
181
+ if brace_balance != 0:
182
+ return False # Unmatched braces
183
+
184
+ return True
185
+
186
+ def repetition_check(self, text: str) -> bool:
187
+ """
188
+ Check for excessive repetition.
189
+ Uses character-level 4-gram repetition.
190
+ """
191
+ if len(text) < 100:
192
+ return True
193
+
194
+ # Get 4-grams
195
+ n = 4
196
+ ngrams = [text[i:i+n] for i in range(len(text) - n + 1)]
197
+
198
+ # Count repetitions
199
+ from collections import Counter
200
+ ngram_counts = Counter(ngrams)
201
+ total_ngrams = len(ngrams)
202
+ unique_ngrams = len(ngram_counts)
203
+
204
+ if total_ngrams == 0:
205
+ return True
206
+
207
+ repetition_ratio = 1 - (unique_ngrams / total_ngrams)
208
+ return repetition_ratio <= self.max_repetition_ratio
209
+
210
+ def citation_ratio_check(self, text: str) -> bool:
211
+ """
212
+ Check if citation density is reasonable.
213
+ """
214
+ # Count citation patterns
215
+ # (Author, Year)
216
+ inline1 = len(re.findall(r'\([A-Za-z.\s]+,?\s*\d{4}\)', text))
217
+ # [1] or [1-3]
218
+ inline2 = len(re.findall(r'\[\d+(?:[-,]\d+)*\]', text))
219
+ # [Author, Year]
220
+ inline3 = len(re.findall(r'\[[A-Za-z\s]+,?\s*\d{4}\]', text))
221
+
222
+ total_citations = inline1 + inline2 + inline3
223
+
224
+ # Estimate word count
225
+ words = re.findall(r'\b[a-zA-Z]{2,}\b', text)
226
+ if len(words) == 0:
227
+ return True
228
+
229
+ citation_ratio = total_citations / len(words)
230
+
231
+ # Allow range
232
+ return self.min_citation_ratio <= citation_ratio <= self.max_citation_ratio
233
+
234
+ def get_stats(self, stats: FilterStats) -> str:
235
+ """Get formatted statistics string."""
236
+ total = stats.total if stats.total > 0 else 1
237
+ return (
238
+ f"Quality filter stats:\n"
239
+ f" Total: {stats.total}\n"
240
+ f" Passed: {stats.passed} ({stats.passed/total*100:.1f}%)\n"
241
+ f" Failed - Length: {stats.failed_length}\n"
242
+ f" Failed - Language: {stats.failed_language}\n"
243
+ f" Failed - Content: {stats.failed_content}\n"
244
+ f" Failed - Equations: {stats.failed_equations}\n"
245
+ f" Failed - Repetition: {stats.failed_repetition}\n"
246
+ f" Failed - Citations: {stats.failed_citations}"
247
+ )
248
+
249
+
250
+ def test_quality_filter():
251
+ """Test the quality filter."""
252
+ filter = ScienceQualityFilter()
253
+
254
+ # Good sample
255
+ good_text = """
256
+ The experiment was conducted to test the hypothesis. We collected data from
257
+ 100 participants and performed statistical analysis. The results show a
258
+ significant effect (p < 0.05). According to Smith et al., this confirms
259
+ the theoretical prediction. The equation E = mc^2 is fundamental.
260
+ """
261
+ print(f"Good text passes: {filter.filter(good_text)}")
262
+
263
+ # Bad sample (too short)
264
+ short_text = "This is too short."
265
+ print(f"Short text passes: {filter.filter(short_text)}")
266
+
267
+ # Bad sample (unmatched equations)
268
+ bad_eq = "Here is an equation $E = mc^2 and another $F = ma."
269
+ print(f"Unmatched $ passes: {filter.filter(bad_eq)}")
270
+
271
+ # Bad sample (excessive repetition)
272
+ repetitive = "test test test test test test test test test test " * 100
273
+ print(f"Repetitive passes: {filter.filter(repetitive)}")
274
+
275
+ print("QualityFilter test passed!")
276
+
277
+
278
+ if __name__ == "__main__":
279
+ test_quality_filter()
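The brace scan in `equation_validity_check` is a standard single-pass balance counter: reject the moment a `}` closes with nothing open, and require a zero balance at the end. Isolated for illustration:

```python
def braces_balanced(text):
    """True iff '{' and '}' are balanced and never close before opening."""
    balance = 0
    for ch in text:
        if ch == '{':
            balance += 1
        elif ch == '}':
            balance -= 1
            if balance < 0:
                return False  # closing brace with no matching open
    return balance == 0

print(braces_balanced(r"\frac{a}{b}"))  # → True
print(braces_balanced(r"\frac{a}{b"))   # → False
print(braces_balanced("}{"))            # → False
```

The early-exit check is what distinguishes this from merely comparing counts: `"}{"` has equal counts but is still rejected.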
data/scraper.py ADDED
@@ -0,0 +1,405 @@
1
+ """
2
+ VortexScienceScraper: Scrapes scientific content from open access sources.
3
+ Respects robots.txt and rate limits.
4
+ """
5
+
6
+ import time
7
+ import requests
8
+ from typing import List, Dict, Optional
9
+ from urllib.robotparser import RobotFileParser
10
+ from pathlib import Path
11
+ import json
12
+
13
+
14
+ class VortexScienceScraper:
15
+ """
16
+ Scrapes scientific content from open access sources.
17
+ Sources: arXiv, PubMed Central, Wikipedia, NIST, NASA.
18
+ """
19
+
20
+ SOURCES = {
21
+ "arxiv": {
22
+ "base_url": "https://arxiv.org",
23
+ "search_url": "https://arxiv.org/search/",
24
+ "rate_limit": 1.0, # seconds between requests
25
+ "robots": "https://arxiv.org/robots.txt",
26
+ },
27
+ "pubmed": {
28
+ "base_url": "https://www.ncbi.nlm.nih.gov/pmc",
29
+ "search_url": "https://www.ncbi.nlm.nih.gov/pmc/articles/",
30
+ "rate_limit": 0.5,
31
+ "robots": "https://www.ncbi.nlm.nih.gov/robots.txt",
32
+ },
33
+ "wikipedia": {
34
+ "base_url": "https://en.wikipedia.org",
35
+ "search_url": "https://en.wikipedia.org/w/api.php",
36
+ "rate_limit": 0.1,
37
+ "robots": "https://en.wikipedia.org/robots.txt",
38
+ },
39
+ "nist": {
40
+ "base_url": "https://webbook.nist.gov",
41
+ "search_url": "https://webbook.nist.gov/cgi/cbook.cgi",
42
+ "rate_limit": 1.0,
43
+ "robots": "https://webbook.nist.gov/robots.txt",
44
+ },
45
+ "nasa": {
46
+ "base_url": "https://ntrs.nasa.gov",
47
+ "search_url": "https://ntrs.nasa.gov/api/citations/search",
48
+ "rate_limit": 1.0,
49
+ "robots": "https://ntrs.nasa.gov/robots.txt",
50
+ },
51
+ }
52
+
53
+ def __init__(
54
+ self,
55
+ output_dir: str = "./data/scraped",
56
+ respect_robots: bool = True,
57
+ user_agent: str = "VortexScientificBot/1.0",
58
+ ):
59
+ """
60
+ Initialize scraper.
61
+
62
+ Args:
63
+ output_dir: Directory to save scraped data
64
+ respect_robots: Whether to respect robots.txt
65
+ user_agent: User agent string for requests
66
+ """
67
+ self.output_dir = Path(output_dir)
68
+ self.output_dir.mkdir(parents=True, exist_ok=True)
69
+ self.respect_robots = respect_robots
70
+ self.user_agent = user_agent
71
+
72
+ self.session = requests.Session()
73
+ self.session.headers.update({"User-Agent": user_agent})
74
+
75
+ # Cache for robots.txt
76
+ self.robots_cache = {}
77
+
78
+ # Rate limit tracking
79
+ self.last_request_time = {}
80
+
81
+ def _check_robots_allowed(self, url: str) -> bool:
82
+ """Check if robots.txt allows scraping the URL."""
83
+ if not self.respect_robots:
84
+ return True
85
+
86
+ # Extract base domain
87
+ from urllib.parse import urlparse
88
+ parsed = urlparse(url)
89
+ base_url = f"{parsed.scheme}://{parsed.netloc}"
90
+
91
+ if base_url not in self.robots_cache:
92
+ rp = RobotFileParser()
93
+ rp.set_url(base_url + "/robots.txt")
94
+ try:
95
+ rp.read()
96
+ self.robots_cache[base_url] = rp
97
+ except Exception as e:
98
+ print(f"Could not read robots.txt for {base_url}: {e}")
99
+ return False
100
+
101
+ rp = self.robots_cache[base_url]
102
+ return rp.can_fetch(self.user_agent, url)
103
+
104
+ def _rate_limit(self, source: str):
105
+ """Enforce rate limiting for a source."""
106
+ now = time.time()
107
+ last = self.last_request_time.get(source, 0)
108
+ delay = self.SOURCES[source]["rate_limit"]
109
+ if now - last < delay:
110
+ time.sleep(delay - (now - last))
111
+ self.last_request_time[source] = time.time()
112
+
113
+ def scrape_arxiv(
114
+ self,
115
+ query: str,
116
+ max_results: int = 100,
117
+ categories: Optional[List[str]] = None,
118
+ ) -> List[Dict]:
119
+ """
120
+ Scrape arXiv papers.
121
+
122
+ Args:
123
+ query: Search query
124
+ max_results: Maximum number of results
125
+ categories: Optional list of arXiv categories (e.g., ['physics', 'math'])
126
+
127
+ Returns:
128
+ List of paper metadata and abstracts
129
+ """
130
+ papers = []
131
+
132
+ params = {
133
+ "query": query,
134
+ "searchtype": "all",
135
+ "abstracts": "show",
136
+ "size": min(max_results, 200), # arXiv max per page
137
+ "order": "-announced_date_first",
138
+ }
139
+
140
+ if categories:
141
+ params["filter"] = "categories:" + "+OR+".join(categories)
142
+
143
+ url = self.SOURCES["arxiv"]["search_url"]
144
+
145
+ if not self._check_robots_allowed(url):
146
+ print(f"Robots.txt disallows scraping {url}")
147
+ return papers
148
+
149
+ try:
150
+ self._rate_limit("arxiv")
151
+ response = self.session.get(url, params=params)
152
+ response.raise_for_status()
153
+
154
+ # Parse HTML (simplified - would use BeautifulSoup in practice)
155
+ # For now, return placeholder
156
+ print(f"Scraped arXiv query '{query}' - got response status {response.status_code}")
157
+
158
+ # Placeholder: would extract paper titles, abstracts, PDF links
159
+ for i in range(min(10, max_results)):
160
+ papers.append({
161
+ "source": "arxiv",
162
+ "title": f"Paper {i}",
163
+ "abstract": "Abstract placeholder...",
164
+ "pdf_url": f"https://arxiv.org/pdf/{i}.pdf",
165
+ })
166
+
167
+ except Exception as e:
168
+ print(f"Error scraping arXiv: {e}")
169
+
170
+ return papers
171
+
172
+ def scrape_pubmed(
173
+ self,
174
+ query: str,
175
+ max_results: int = 100,
176
+ ) -> List[Dict]:
177
+ """Scrape PubMed Central articles."""
178
+ articles = []
179
+
180
+ # PubMed API endpoint
181
+ url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
182
+ params = {
183
+ "db": "pmc",
184
+ "term": query,
185
+ "retmax": max_results,
186
+ "retmode": "json",
187
+ }
188
+
189
+ if not self._check_robots_allowed(url):
190
+ print(f"Robots.txt disallows {url}")
191
+ return articles
192
+
193
+ try:
194
+ self._rate_limit("pubmed")
195
+ response = self.session.get(url, params=params)
196
+ response.raise_for_status()
197
+
198
+ data = response.json()
199
+ pmc_ids = data.get("esearchresult", {}).get("idlist", [])
200
+
201
+ for pmc_id in pmc_ids[:10]: # Limit for demo
202
+ articles.append({
203
+ "source": "pubmed",
204
+ "pmc_id": pmc_id,
205
+ "url": f"https://www.ncbi.nlm.nih.gov/pmc/articles/PMC{pmc_id}/",
206
+ })
207
+
208
+ print(f"Found {len(pmc_ids)} PubMed articles")
209
+
210
+ except Exception as e:
211
+ print(f"Error scraping PubMed: {e}")
212
+
213
+ return articles
214
+
215
+ def scrape_wikipedia(
216
+ self,
217
+ topic: str,
218
+ max_pages: int = 10,
219
+ ) -> List[Dict]:
220
+ """Scrape Wikipedia science articles."""
221
+ pages = []
222
+
223
+ # Wikipedia API
224
+ url = "https://en.wikipedia.org/w/api.php"
225
+ params = {
226
+ "action": "query",
227
+ "format": "json",
228
+ "prop": "extracts",
229
+ "exintro": True,
230
+ "titles": topic,
231
+ "redirects": True,
232
+ }
233
+
234
+ if not self._check_robots_allowed(url):
235
+ print(f"Robots.txt disallows {url}")
236
+ return pages
237
+
238
+ try:
239
+ self._rate_limit("wikipedia")
240
+ response = self.session.get(url, params=params)
241
+ response.raise_for_status()
242
+
243
+ data = response.json()
244
+ pages_data = data.get("query", {}).get("pages", {})
245
+
246
+ for page_id, page in pages_data.items():
247
+ if "extract" in page:
248
+ pages.append({
249
+ "source": "wikipedia",
250
+ "title": page.get("title", ""),
251
+ "text": page.get("extract", ""),
252
+ })
253
+
254
+ except Exception as e:
255
+ print(f"Error scraping Wikipedia: {e}")
256
+
257
+ return pages
258
+
259
+ def scrape_nist(
260
+ self,
261
+ element: str,
262
+ ) -> List[Dict]:
263
+ """Scrape NIST chemistry webbook for element data."""
264
+ data = []
265
+
266
+ url = "https://webbook.nist.gov/cgi/cbook.cgi"
267
+ params = {
268
+ "Formula": element,
269
+ "Units": "SI",
270
+ "Submit": "Submit",
271
+ }
272
+
273
+ if not self._check_robots_allowed(url):
274
+ print(f"Robots.txt disallows {url}")
275
+ return data
276
+
277
+ try:
278
+ self._rate_limit("nist")
279
+ response = self.session.get(url, params=params)
280
+ response.raise_for_status()
281
+
282
+ # Placeholder - would parse HTML tables
283
+ data.append({
284
+ "source": "nist",
285
+ "element": element,
286
+ "html": response.text[:1000],
287
+ })
288
+
289
+ except Exception as e:
290
+ print(f"Error scraping NIST: {e}")
291
+
292
+ return data
293
+
294
+ def scrape_nasa(
295
+ self,
296
+ query: str,
297
+ max_results: int = 50,
298
+ ) -> List[Dict]:
299
+ """Scrape NASA technical reports."""
300
+ reports = []
301
+
302
+ url = "https://ntrs.nasa.gov/api/citations/search"
303
+ params = {
304
+ "q": query,
305
+ "page[size]": max_results,
306
+ }
307
+
308
+ if not self._check_robots_allowed(url):
309
+ print(f"Robots.txt disallows {url}")
310
+ return reports
311
+
312
+ try:
313
+ self._rate_limit("nasa")
314
+ response = self.session.get(url, params=params)
315
+ response.raise_for_status()
316
+
317
+ data = response.json()
318
+ for item in data.get("data", [])[:10]:
319
+ reports.append({
320
+ "source": "nasa",
321
+ "title": item.get("attributes", {}).get("title", ""),
322
+ "abstract": item.get("attributes", {}).get("abstract", ""),
323
+ "download_url": item.get("attributes", {}).get("downloads", {}).get("pdf", ""),
324
+ })
325
+
326
+ except Exception as e:
327
+ print(f"Error scraping NASA: {e}")
328
+
329
+ return reports
330
+
331
+ def save_results(
332
+ self,
333
+ results: List[Dict],
334
+ filename: str,
335
+ ):
336
+ """Save scraped results to JSON."""
337
+ output_path = self.output_dir / filename
338
+ with open(output_path, "w", encoding="utf-8") as f:
339
+ json.dump(results, f, indent=2, ensure_ascii=False)
340
+ print(f"Saved {len(results)} results to {output_path}")
341
+
342
+ def scrape_all_sources(
343
+ self,
344
+ queries: Dict[str, str],
345
+ max_per_source: int = 50,
346
+ ) -> Dict[str, List[Dict]]:
347
+ """
348
+ Scrape all sources with given queries.
349
+
350
+ Args:
351
+ queries: Dict mapping source name to query string
352
+ max_per_source: Max results per source
353
+
354
+ Returns:
355
+ Dict mapping source to list of results
356
+ """
357
+ all_results = {}
358
+
359
+ for source, query in queries.items():
360
+ if source not in self.SOURCES:
361
+ print(f"Unknown source: {source}")
362
+ continue
363
+
364
+ print(f"Scraping {source} with query: {query}")
365
+
366
+ if source == "arxiv":
367
+ results = self.scrape_arxiv(query, max_results=max_per_source)
368
+ elif source == "pubmed":
369
+ results = self.scrape_pubmed(query, max_results=max_per_source)
370
+ elif source == "wikipedia":
371
+ results = self.scrape_wikipedia(query, max_pages=max_per_source)
372
+ elif source == "nist":
373
+ results = self.scrape_nist(query)
374
+ elif source == "nasa":
375
+ results = self.scrape_nasa(query, max_results=max_per_source)
376
+ else:
377
+ results = []
378
+
379
+ all_results[source] = results
380
+
381
+ # Save intermediate results
382
+ self.save_results(results, f"{source}_results.json")
383
+
384
+ return all_results
385
+
386
+
387
+ def test_scraper():
388
+ """Test the scraper (limited)."""
389
+ scraper = VortexScienceScraper()
390
+
391
+ # Test Wikipedia (lightweight)
392
+ print("Testing Wikipedia scrape...")
393
+ results = scraper.scrape_wikipedia("quantum mechanics", max_pages=2)
394
+ print(f"Got {len(results)} Wikipedia pages")
395
+
396
+ # Test arXiv (rate limited)
397
+ print("Testing arXiv scrape...")
398
+ results = scraper.scrape_arxiv("quantum", max_results=5)
399
+ print(f"Got {len(results)} arXiv papers")
400
+
401
+ print("Scraper test passed!")
402
+
403
+
404
+ if __name__ == "__main__":
405
+ test_scraper()
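The `_check_robots_allowed` method above caches one `RobotFileParser` per domain and defers to `can_fetch`. A minimal standalone sketch of that decision, using hypothetical robots.txt rules parsed directly (no network fetch):

```python
from urllib.robotparser import RobotFileParser

# Hypothetical robots.txt content for an example domain.
rules = [
    "User-agent: *",
    "Disallow: /private/",
    "Allow: /",
]

rp = RobotFileParser()
rp.parse(rules)  # parse() accepts the file's lines directly

# can_fetch(user_agent, url) applies the most specific matching rule.
print(rp.can_fetch("VortexScientificBot/1.0", "https://example.org/articles/1"))  # True
print(rp.can_fetch("VortexScientificBot/1.0", "https://example.org/private/x"))   # False
```

In the scraper itself, `set_url(...)` plus `read()` replace the manual `parse()` call, and the parsed object is cached per `scheme://netloc` so robots.txt is fetched at most once per domain.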
inference.py ADDED
@@ -0,0 +1,213 @@
+ #!/usr/bin/env python3
+ """
+ Inference script for Vortex models.
+ Supports both CUDA and MPS backends.
+ """
+
+ import argparse
+ from pathlib import Path
+
+ import torch
+
+ from configs.vortex_7b_config import VORTEX_7B_CONFIG
+ from configs.vortex_13b_config import VORTEX_13B_CONFIG
+
+ from models.vortex_model import VortexModel
+ from cuda_optimize import optimize_for_cuda, profile_model
+ from mps_optimize import optimize_for_mps, profile_model_mps
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Run inference with Vortex model")
+     parser.add_argument("--model_path", type=str, required=True,
+                         help="Path to trained model checkpoint")
+     parser.add_argument("--config", type=str, default=None,
+                         help="Path to model config (if not in checkpoint)")
+     parser.add_argument("--tokenizer_path", type=str, default=None,
+                         help="Path to tokenizer")
+     parser.add_argument("--model_size", type=str, choices=["7b", "13b"], default="7b",
+                         help="Model size for config")
+     parser.add_argument("--device", type=str, default="cuda",
+                         choices=["cuda", "mps", "cpu"],
+                         help="Device to run on")
+     parser.add_argument("--use_mps", action="store_true",
+                         help="Use MPS backend (Apple Silicon)")
+     parser.add_argument("--quantization", type=str, choices=[None, "int8", "int4"], default=None,
+                         help="Apply quantization (CUDA only)")
+     parser.add_argument("--flash_attention", action="store_true",
+                         help="Use Flash Attention 2 (CUDA only)")
+     parser.add_argument("--torch_compile", action="store_true",
+                         help="Use torch.compile")
+     parser.add_argument("--prompt", type=str, default=None,
+                         help="Input prompt for generation")
+     parser.add_argument("--interactive", action="store_true",
+                         help="Run in interactive mode")
+     parser.add_argument("--max_new_tokens", type=int, default=100,
+                         help="Maximum new tokens to generate")
+     parser.add_argument("--temperature", type=float, default=0.8,
+                         help="Sampling temperature")
+     parser.add_argument("--top_p", type=float, default=0.9,
+                         help="Top-p sampling")
+     parser.add_argument("--profile", action="store_true",
+                         help="Profile performance")
+     return parser.parse_args()
+
+
+ def load_model(args):
+     """Load model with appropriate optimizations."""
+     # Load config
+     if args.config:
+         from configuration_vortex import VortexConfig
+         config = VortexConfig.from_pretrained(args.config)
+     else:
+         # Use default config for size
+         if args.model_size == "7b":
+             config_dict = VORTEX_7B_CONFIG
+         else:
+             config_dict = VORTEX_13B_CONFIG
+         from configuration_vortex import VortexConfig
+         config = VortexConfig(**config_dict)
+
+     # Create model
+     print("Creating model...")
+     model = VortexModel(config.to_dict())
+
+     # Load checkpoint
+     print(f"Loading checkpoint from {args.model_path}")
+     checkpoint = torch.load(args.model_path, map_location="cpu", weights_only=False)
+     if "model_state_dict" in checkpoint:
+         model.load_state_dict(checkpoint["model_state_dict"])
+     else:
+         model.load_state_dict(checkpoint)
+     print("Model loaded")
+
+     # Apply optimizations
+     device = torch.device(args.device)
+     if args.use_mps or args.device == "mps":
+         print("Optimizing for MPS...")
+         model = optimize_for_mps(model, config.to_dict(), use_sdpa=True)
+     else:
+         print("Optimizing for CUDA...")
+         model = optimize_for_cuda(
+             model,
+             config.to_dict(),
+             use_flash_attention=args.flash_attention,
+             use_torch_compile=args.torch_compile,
+             quantization=args.quantization,
+         )
+
+     model = model.to(device)
+     model.eval()
+
+     return model, config
+
+
+ def load_tokenizer(args):
+     """Load tokenizer."""
+     model_dir = Path(args.model_path).parent
+     tokenizer_path = args.tokenizer_path
+     if not tokenizer_path:
+         # Try to find in model directory
+         tokenizer_path = model_dir / "vortex_tokenizer.json"
+
+     if tokenizer_path and Path(tokenizer_path).exists():
+         from tokenization_vortex import VortexTokenizer
+         tokenizer = VortexTokenizer.from_pretrained(str(Path(tokenizer_path).parent))
+     else:
+         print("Warning: No tokenizer found, using dummy tokenizer")
+         class DummyTokenizer:
+             pad_token_id = 0
+             eos_token_id = 0
+             def __call__(self, text, **kwargs):
+                 return {"input_ids": torch.tensor([[1, 2, 3]])}
+             def decode(self, ids, **kwargs):
+                 return "dummy"
+         tokenizer = DummyTokenizer()
+
+     return tokenizer
+
+
+ def generate_text(model, tokenizer, prompt, args):
+     """Generate text from prompt."""
+     # Tokenize
+     inputs = tokenizer(
+         prompt,
+         return_tensors="pt",
+         padding=False,
+         truncation=True,
+         max_length=model.config.max_seq_len - args.max_new_tokens,
+     )
+     input_ids = inputs["input_ids"].to(next(model.parameters()).device)
+
+     # Generate
+     with torch.no_grad():
+         if hasattr(model, 'generate'):
+             output_ids = model.generate(
+                 input_ids,
+                 max_new_tokens=args.max_new_tokens,
+                 temperature=args.temperature,
+                 top_p=args.top_p,
+                 do_sample=True,
+                 pad_token_id=tokenizer.pad_token_id,
+             )
+         else:
+             # Manual generation
+             for _ in range(args.max_new_tokens):
+                 outputs = model(input_ids)
+                 next_token_logits = outputs["logits"][:, -1, :]
+                 next_token = torch.multinomial(
+                     torch.softmax(next_token_logits / args.temperature, dim=-1),
+                     num_samples=1,
+                 )
+                 input_ids = torch.cat([input_ids, next_token], dim=-1)
+
+                 # Check for EOS
+                 if next_token.item() == tokenizer.eos_token_id:
+                     break
+             # Manual path accumulates new tokens in input_ids
+             output_ids = input_ids
+
+     # Decode
+     generated = tokenizer.decode(output_ids[0].tolist(), skip_special_tokens=True)
+     return generated
+
+
+ def main():
+     args = parse_args()
+
+     # Load model and tokenizer
+     model, config = load_model(args)
+     tokenizer = load_tokenizer(args)
+
+     print(f"Model loaded on {next(model.parameters()).device}")
+     print(f"Model parameters: {model.get_num_params():,}")
+
+     # Profile if requested
+     if args.profile:
+         print("Profiling...")
+         dummy_input = torch.randint(0, config.vocab_size, (1, 128)).to(next(model.parameters()).device)
+         if args.use_mps or args.device == "mps":
+             stats = profile_model_mps(model, dummy_input)
+         else:
+             stats = profile_model(model, dummy_input)
+         print("Profile results:")
+         for k, v in stats.items():
+             print(f" {k}: {v:.4f}")
+         return
+
+     # Interactive mode
+     if args.interactive:
+         print("Interactive mode. Type 'quit' to exit.")
+         while True:
+             prompt = input("\nPrompt: ")
+             if prompt.lower() == "quit":
+                 break
+             response = generate_text(model, tokenizer, prompt, args)
+             print(f"\nResponse: {response}")
+     elif args.prompt:
+         response = generate_text(model, tokenizer, args.prompt, args)
+         print(f"Response: {response}")
+     else:
+         print("No prompt provided. Use --prompt or --interactive.")
+
+
+ if __name__ == "__main__":
+     main()
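The manual-generation fallback in `generate_text` samples with temperature only, even though the script exposes a `--top_p` flag. A minimal sketch of the nucleus (top-p) filtering step that flag implies, on plain logits tensors (toy numbers, not model output):

```python
import torch

def top_p_filter(logits: torch.Tensor, top_p: float = 0.9) -> torch.Tensor:
    """Mask (set to -inf) every token outside the smallest set whose
    cumulative probability exceeds top_p; sample from what remains."""
    sorted_logits, sorted_idx = torch.sort(logits, descending=True)
    probs = torch.softmax(sorted_logits, dim=-1)
    cumprobs = torch.cumsum(probs, dim=-1)

    # Tokens past the top_p cutoff are removed; shift right so the
    # first token crossing the threshold is still kept.
    remove = cumprobs > top_p
    remove[..., 1:] = remove[..., :-1].clone()
    remove[..., 0] = False
    sorted_logits[remove] = float("-inf")

    # Scatter the filtered logits back to the original token order.
    filtered = torch.full_like(logits, float("-inf"))
    filtered.scatter_(-1, sorted_idx, sorted_logits)
    return filtered

logits = torch.tensor([2.0, 1.0, 0.1, -1.0])
filtered = top_p_filter(logits, top_p=0.6)
# With these logits the top token alone exceeds 0.6 probability mass,
# so only it survives; torch.multinomial on softmax(filtered) would
# then sample from the surviving set.
```

Applied in the loop above, this would run between the temperature division and `torch.multinomial`.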
modeling_vortex.py ADDED
@@ -0,0 +1,222 @@
+ """
+ Vortex model implementation for HuggingFace.
+ Integrates with transformers library.
+ """
+
+ from typing import Optional, Tuple, List, Dict, Any
+ import torch
+ import torch.nn as nn
+ from transformers import PreTrainedModel, PretrainedConfig, GenerationConfig
+ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
+
+ from configuration_vortex import VortexConfig
+ from models.vortex_model import VortexModel
+
+
+ class VortexPreTrainedModel(PreTrainedModel):
+     """
+     Base class for Vortex models.
+     Handles loading/saving in HF format.
+     """
+     config_class = VortexConfig
+     base_model_prefix = "vortex"
+     supports_gradient_checkpointing = True
+     _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
+
+     def _init_weights(self, module):
+         """Initialize weights."""
+         if isinstance(module, nn.Linear):
+             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+             if module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, nn.Embedding):
+             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+             if module.padding_idx is not None:
+                 module.weight.data[module.padding_idx].zero_()
+         elif isinstance(module, nn.LayerNorm):
+             module.bias.data.zero_()
+             module.weight.data.fill_(1.0)
+
+     def get_input_embeddings(self):
+         return self.vortex.embed_tokens
+
+     def set_input_embeddings(self, value):
+         self.vortex.embed_tokens = value
+
+     def get_output_embeddings(self):
+         return self.vortex.lm_head
+
+     def set_output_embeddings(self, new_embeddings):
+         self.vortex.lm_head = new_embeddings
+
+
+ class VortexForCausalLM(VortexPreTrainedModel):
+     """
+     Vortex model for causal language modeling.
+     """
+     _tied_weights_keys = ["vortex.lm_head.weight"]
+
+     def __init__(self, config: VortexConfig):
+         super().__init__(config)
+         self.config = config
+
+         # Build core model
+         self.vortex = VortexModel(config.to_dict())
+
+         # Initialize weights
+         self.apply(self._init_weights)
+
+         # Tie weights if configured
+         if self.config.tie_word_embeddings:
+             self.tie_weights()
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         domain_ids: Optional[torch.LongTensor] = None,
+         domain_tags: Optional[torch.Tensor] = None,
+         text: Optional[List[str]] = None,
+     ) -> CausalLMOutputWithCrossAttentions:
+         """
+         Forward pass.
+
+         Args:
+             input_ids: Token IDs (batch, seq_len)
+             attention_mask: Attention mask (batch, seq_len)
+             labels: Labels for LM loss (batch, seq_len)
+             domain_ids: Domain IDs (batch,)
+             domain_tags: Domain tag masks (batch, seq_len, num_domains)
+             text: Original text strings (for science modules)
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         # Pass through Vortex model
+         outputs = self.vortex(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             domain_ids=domain_ids,
+             domain_tags=domain_tags,
+             text=text,
+             return_dict=True,
+         )
+
+         logits = outputs["logits"]
+         last_hidden_state = outputs["last_hidden_state"]
+
+         loss = None
+         if labels is not None:
+             # Compute cross-entropy loss: position t predicts token t+1
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
+             loss = loss_fct(
+                 shift_logits.view(-1, shift_logits.size(-1)),
+                 shift_labels.view(-1),
+             )
+
+         if not return_dict:
+             output = (logits,) + (last_hidden_state,)
+             return (loss,) + output if loss is not None else output
+
+         return CausalLMOutputWithCrossAttentions(
+             loss=loss,
+             logits=logits,
+             hidden_states=last_hidden_state,
+             attentions=None,
+         )
+
+     def prepare_inputs_for_generation(
+         self,
+         input_ids,
+         past_key_values=None,
+         attention_mask=None,
+         **kwargs,
+     ):
+         """Prepare inputs for text generation."""
+         # Omit tokens that are already past
+         if past_key_values:
+             input_ids = input_ids[:, -1:]
+
+         return {
+             "input_ids": input_ids,
+             "attention_mask": attention_mask,
+             "past_key_values": past_key_values,
+             "use_cache": kwargs.get("use_cache", True),
+         }
+
+     def generate(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         **kwargs,
+     ):
+         """Generate text."""
+         generation_config = kwargs.pop("generation_config", None)
+         if generation_config is None:
+             generation_config = GenerationConfig.from_model_config(self.config)
+
+         return super().generate(
+             input_ids=input_ids,
+             inputs_embeds=inputs_embeds,
+             generation_config=generation_config,
+             **kwargs,
+         )
+
+
+ # Register model for AutoModel
+ from transformers import AutoConfig, AutoModelForCausalLM
+
+ AutoConfig.register("vortex", VortexConfig)
+ AutoModelForCausalLM.register(VortexConfig, VortexForCausalLM)
+
+
+ def test_hf_integration():
+     """Test HuggingFace integration."""
+     from transformers import AutoConfig, AutoModelForCausalLM
+
+     # Create config
+     config = VortexConfig(
+         d_model=512,
+         num_layers=2,
+         num_heads=8,
+         vocab_size=1000,
+     )
+
+     # Create model
+     model = VortexForCausalLM(config)
+     print(f"Model parameters: {model.num_parameters():,}")
+
+     # Test forward
+     batch_size = 2
+     seq_len = 32
+     input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
+     labels = torch.randint(0, config.vocab_size, (batch_size, seq_len))
+
+     outputs = model(input_ids=input_ids, labels=labels)
+     print(f"Loss: {outputs.loss.item():.4f}")
+     print(f"Logits shape: {outputs.logits.shape}")
+
+     # Test save/load
+     model.save_pretrained("./test_vortex_model")
+     config.save_pretrained("./test_vortex_model")
+
+     loaded_config = AutoConfig.from_pretrained("./test_vortex_model")
+     loaded_model = AutoModelForCausalLM.from_pretrained("./test_vortex_model")
+     print(f"Loaded model type: {type(loaded_model)}")
+
+     print("HF integration test passed!")
+
+
+ if __name__ == "__main__":
+     test_hf_integration()
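The loss in `VortexForCausalLM.forward` shifts logits and labels so that position t is scored against the token at position t+1, with the final position dropped. A tiny illustration of that alignment with plain tensors (toy shapes, not real model output):

```python
import torch
import torch.nn as nn

batch, seq_len, vocab = 1, 4, 5
logits = torch.randn(batch, seq_len, vocab)   # stand-in for model logits
labels = torch.tensor([[1, 2, 3, 4]])

# Drop the last logit position (it has no next token to predict)
# and the first label (nothing predicts it).
shift_logits = logits[..., :-1, :].contiguous()   # (1, 3, 5)
shift_labels = labels[..., 1:].contiguous()       # [[2, 3, 4]]

loss = nn.CrossEntropyLoss(ignore_index=-100)(
    shift_logits.view(-1, vocab), shift_labels.view(-1)
)
```

The `ignore_index=-100` matches the transformers convention of masking padding or prompt tokens out of the loss by setting their label to -100.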
models/__pycache__/attention_layer.cpython-313.pyc ADDED
Binary file (15.5 kB). View file
 
models/__pycache__/scigate_ffn.cpython-313.pyc ADDED
Binary file (7.94 kB). View file
 
models/__pycache__/ssm_layer.cpython-313.pyc ADDED
Binary file (10.2 kB). View file
 
models/__pycache__/vortex_model.cpython-313.pyc ADDED
Binary file (14.6 kB). View file
 
models/attention_layer.py ADDED
@@ -0,0 +1,370 @@
+ """
+ VortexLocalAttention: Local windowed attention with global token support.
+ Uses a sliding window of 512 tokens for efficiency, with special handling
+ for global tokens that can attend across the entire sequence.
+ """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from typing import Optional, Tuple
+
+
+ class VortexLocalAttention(nn.Module):
+     """
+     Local windowed attention with window_size=512.
+     Science documents have strong local coherence — equations reference
+     nearby text, not distant paragraphs.
+     Global tokens (special [SCIENCE] tokens) attend to everything.
+     """
+
+     def __init__(
+         self,
+         d_model: int,
+         num_heads: int,
+         window_size: int = 512,
+         use_flash_attention: bool = True,
+     ):
+         """
+         Initialize local windowed attention.
+
+         Args:
+             d_model: Model dimension
+             num_heads: Number of attention heads
+             window_size: Size of local attention window
+             use_flash_attention: Use Flash Attention 2 if available (CUDA only)
+         """
+         super().__init__()
+         self.d_model = d_model
+         self.num_heads = num_heads
+         self.head_dim = d_model // num_heads
+         self.window_size = window_size
+         self.use_flash_attention = use_flash_attention
+
+         assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
+
+         # QKV projection
+         self.qkv = nn.Linear(d_model, d_model * 3, bias=False)
+         self.out_proj = nn.Linear(d_model, d_model, bias=False)
+
+         # Global token projection (for tokens that attend globally)
+         self.global_qkv = nn.Linear(d_model, d_model * 3, bias=False)
+
+         # Initialize weights
+         self._initialize_weights()
+
+     def _initialize_weights(self):
+         """Initialize weights."""
+         for module in [self.qkv, self.global_qkv, self.out_proj]:
+             if hasattr(module, 'weight'):
+                 nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         global_mask: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         """
+         Forward pass with local windowed attention.
+
+         Args:
+             x: Input tensor (batch, seq_len, d_model)
+             global_mask: Boolean mask indicating which tokens are global (attend everywhere)
+                 Shape: (batch, seq_len) or None
+             attention_mask: Optional padding mask (batch, seq_len)
+
+         Returns:
+             Output tensor (batch, seq_len, d_model)
+         """
+         batch, seq_len, _ = x.shape
+         device = x.device
+         dtype = x.dtype
+
+         if global_mask is None:
+             global_mask = torch.zeros(batch, seq_len, dtype=torch.bool, device=device)
+
+         # Compute QKV for all tokens
+         qkv = self.qkv(x)
+         q, k, v = qkv.chunk(3, dim=-1)
+
+         # Reshape for multi-head attention
+         q = q.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+         k = k.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+         v = v.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+         # Compute global token QKV separately
+         if global_mask.any():
+             global_qkv = self.global_qkv(x)
+             gq, gk, gv = global_qkv.chunk(3, dim=-1)
+             gq = gq.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+             gk = gk.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+             gv = gv.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+         # Build output tensor
+         output = torch.zeros_like(x)
+
+         # Process each position
+         for t in range(seq_len):
+             # Determine window
+             window_start = max(0, t - self.window_size // 2)
+             window_end = min(seq_len, t + self.window_size // 2 + 1)
+             window_len = window_end - window_start
+
+             # Get window indices
+             window_indices = slice(window_start, window_end)
+
+             # Extract window queries (for position t)
+             q_t = q[:, :, t:t+1, :]  # (batch, heads, 1, head_dim)
+
+             # Determine which keys/values to use
+             # Local tokens: only those in window
+             # Global tokens: all positions (if they are global)
+             k_window = k[:, :, window_indices, :]
+             v_window = v[:, :, window_indices, :]
+
+             # Build full key/value set including global tokens
+             # Global tokens attend to all positions
+             if global_mask.any():
+                 # Find global positions
+                 global_positions = global_mask[0]  # (seq_len) - assume same across batch
+                 if global_positions.any():
+                     gk_all = gk[:, :, :, :]  # All global keys
+                     gv_all = gv[:, :, :, :]
+
+                     # Concatenate window keys with global keys
+                     k_full = torch.cat([k_window, gk_all], dim=2)
+                     v_full = torch.cat([v_window, gv_all], dim=2)
+                 else:
+                     k_full = k_window
+                     v_full = v_window
+             else:
+                 k_full = k_window
+                 v_full = v_window
+
+             # Compute attention scores
+             # q_t: (batch, heads, 1, head_dim)
+             # k_full: (batch, heads, window_len + num_global, head_dim)
+             attn_scores = torch.matmul(q_t, k_full.transpose(-2, -1)) / (self.head_dim ** 0.5)
+             # (batch, heads, 1, k_len)
+
+             # Apply attention mask if provided
+             if attention_mask is not None:
+                 mask_t = attention_mask[:, window_indices].unsqueeze(1).unsqueeze(2)
+                 attn_scores = attn_scores.masked_fill(mask_t == 0, -1e9)
+
+             # Softmax
+             attn_weights = F.softmax(attn_scores, dim=-1)
+
+             # Weighted sum
+             attn_output = torch.matmul(attn_weights, v_full)
+             # (batch, heads, 1, head_dim)
+
+             # Reshape and project
+             attn_output = attn_output.transpose(1, 2).contiguous()
+             attn_output = attn_output.view(batch, 1, self.d_model)
+             attn_output = self.out_proj(attn_output)
167
+
168
+ # Place in output
169
+ output[:, t:t+1, :] = attn_output
170
+
171
+ return output
172
+
+     def forward_optimized(
+         self,
+         x: torch.Tensor,
+         global_mask: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         """
+         Optimized forward pass: Flash Attention for short sequences,
+         efficient windowed attention otherwise.
+         """
+         batch, seq_len, _ = x.shape
+
+         if self.use_flash_attention and self.window_size >= seq_len:
+             # The window covers the whole sequence, so full attention is exact
+             return self._flash_attention_forward(x, attention_mask)
+         else:
+             # Use windowed attention
+             return self._windowed_attention_forward(x, global_mask, attention_mask)
+
+     def _flash_attention_forward(
+         self,
+         x: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         """
+         Use Flash Attention 2 if available.
+         Requires: pip install flash-attn
+         """
+         try:
+             from flash_attn import flash_attn_func
+
+             batch, seq_len, _ = x.shape
+             qkv = self.qkv(x)
+             q, k, v = qkv.chunk(3, dim=-1)
+
+             # Flash attention expects (batch, seq_len, num_heads, head_dim)
+             # and returns the same shape
+             q = q.view(batch, seq_len, self.num_heads, self.head_dim)
+             k = k.view(batch, seq_len, self.num_heads, self.head_dim)
+             v = v.view(batch, seq_len, self.num_heads, self.head_dim)
+
+             # Note: flash_attn_func takes no padding mask; padded batches
+             # would need flash_attn_varlen_func instead
+             output = flash_attn_func(
+                 q, k, v,
+                 causal=False,
+                 softmax_scale=1.0 / (self.head_dim ** 0.5),
+             )
+
+             output = output.view(batch, seq_len, self.d_model)
+             return self.out_proj(output)
+
+         except ImportError:
+             print("Flash Attention not available, falling back to standard attention")
+             return self._standard_attention(x, attention_mask)
+
+     def _standard_attention(
+         self,
+         x: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         """Standard full attention (quadratic in seq_len)."""
+         batch, seq_len, _ = x.shape
+         qkv = self.qkv(x)
+         q, k, v = qkv.chunk(3, dim=-1)
+
+         q = q.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+         k = k.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+         v = v.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+         # Compute attention scores
+         attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
+
+         if attention_mask is not None:
+             attn_scores = attn_scores.masked_fill(
+                 attention_mask.unsqueeze(1).unsqueeze(2) == 0,
+                 -1e9
+             )
+
+         attn_weights = F.softmax(attn_scores, dim=-1)
+         attn_output = torch.matmul(attn_weights, v)
+
+         attn_output = attn_output.transpose(1, 2).contiguous()
+         attn_output = attn_output.view(batch, seq_len, self.d_model)
+         return self.out_proj(attn_output)
+
+     def _windowed_attention_forward(
+         self,
+         x: torch.Tensor,
+         global_mask: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         """
+         Efficient windowed attention implementation.
+         Uses unfold to extract windows and batched matrix multiply.
+         """
+         batch, seq_len, _ = x.shape
+         device = x.device
+
+         if global_mask is None:
+             global_mask = torch.zeros(batch, seq_len, dtype=torch.bool, device=device)
+
+         # Compute QKV
+         qkv = self.qkv(x)
+         q, k, v = qkv.chunk(3, dim=-1)
+
+         # Reshape: (batch, seq_len, num_heads, head_dim)
+         q = q.view(batch, seq_len, self.num_heads, self.head_dim)
+         k = k.view(batch, seq_len, self.num_heads, self.head_dim)
+         v = v.view(batch, seq_len, self.num_heads, self.head_dim)
+
+         # Pad the sequence for windowing; asymmetric padding yields exactly
+         # seq_len windows regardless of window_size parity
+         pad_left = self.window_size // 2
+         pad_right = self.window_size - 1 - pad_left
+         k_padded = F.pad(k, (0, 0, 0, 0, pad_left, pad_right))
+         v_padded = F.pad(v, (0, 0, 0, 0, pad_left, pad_right))
+
+         # Extract windows using unfold; the window becomes the last dim:
+         # (batch, seq_len, num_heads, head_dim, window_size)
+         k_windows = k_padded.unfold(1, self.window_size, 1)
+         v_windows = v_padded.unfold(1, self.window_size, 1)
+
+         # Permute to (batch, seq_len, num_heads, window_size, head_dim)
+         k_windows = k_windows.permute(0, 1, 2, 4, 3)
+         v_windows = v_windows.permute(0, 1, 2, 4, 3)
+
+         # q: (batch, seq_len, num_heads, 1, head_dim)
+         q_expanded = q.unsqueeze(3)
+
+         # Scores: (batch, seq_len, num_heads, 1, window_size)
+         attn_scores = torch.matmul(q_expanded, k_windows.transpose(-2, -1)) / (self.head_dim ** 0.5)
+         attn_scores = attn_scores.squeeze(3)  # (batch, seq_len, num_heads, window_size)
+
+         # Apply softmax
+         # (note: padding inside the local windows is not masked here;
+         # padded batches should use forward() instead)
+         attn_weights = F.softmax(attn_scores, dim=-1)
+
+         # Weighted sum: (batch, seq_len, num_heads, head_dim)
+         attn_output = torch.matmul(attn_weights.unsqueeze(3), v_windows).squeeze(3)
+
+         # Concatenate heads and project
+         attn_output = attn_output.reshape(batch, seq_len, self.d_model)
+         local_out = self.out_proj(attn_output)
+
+         # Add global token contribution if any
+         if global_mask.any():
+             # Simplified: compute full attention once and substitute it at
+             # global positions; _standard_attention already applies out_proj.
+             # A production version would only attend from/to the global tokens.
+             full_attn = self._standard_attention(x, attention_mask)
+             # Blend: local for most positions, full for global positions
+             local_out = torch.where(
+                 global_mask.unsqueeze(-1),
+                 full_attn,
+                 local_out
+             )
+
+         return local_out
+
+
+ def test_vortex_local_attention():
+     """Test the VortexLocalAttention layer."""
+     batch_size = 2
+     seq_len = 256
+     d_model = 4096
+     num_heads = 32
+     window_size = 512
+
+     attn = VortexLocalAttention(d_model, num_heads, window_size, use_flash_attention=False)
+     x = torch.randn(batch_size, seq_len, d_model)
+
+     # Forward pass
+     output = attn(x)
+     print(f"Input shape: {x.shape}")
+     print(f"Output shape: {output.shape}")
+     assert output.shape == x.shape, f"Expected {x.shape}, got {output.shape}"
+
+     # With global mask
+     global_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
+     global_mask[0, 0] = True   # First token is global
+     global_mask[1, -1] = True  # Last token is global
+     output2 = attn(x, global_mask=global_mask)
+     assert output2.shape == x.shape
+
+     print("VortexLocalAttention test passed!")
+
+
+ if __name__ == "__main__":
+     test_vortex_local_attention()
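A toy sketch of the `unfold` trick used in `_windowed_attention_forward`: `unfold` appends the window as the last dimension, which is why a permute is needed before the batched matmul. Shapes and values here are illustrative only, not taken from the model.

```python
import torch
import torch.nn.functional as F

# Toy key tensor: (batch=1, seq_len=6, heads=1, head_dim=2),
# where key at position t holds the value t in both channels.
seq_len, window = 6, 3
k = torch.arange(seq_len, dtype=torch.float32).view(1, seq_len, 1, 1).expand(1, seq_len, 1, 2)

pad = window // 2
k_padded = F.pad(k, (0, 0, 0, 0, pad, pad))   # pad along the sequence dim
k_windows = k_padded.unfold(1, window, 1)     # window lands in the LAST dim
print(k_windows.shape)                        # torch.Size([1, 6, 1, 2, 3])

# Position 0 sees keys [pad, 0, 1]; position t sees [t-1, t, t+1].
print(k_windows[0, 0, 0, 0])                  # tensor([0., 0., 1.])
```

Because the window is the last dimension, `permute(0, 1, 2, 4, 3)` is what brings the tensor to `(batch, seq_len, heads, window, head_dim)` for attention.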
models/science_modules/__init__.py ADDED
@@ -0,0 +1,15 @@
+ """
+ Science modules package.
+ """
+
+ from .equation_module import EquationModule
+ from .numerical_module import NumericalReasoningModule
+ from .citation_module import CitationModule
+ from .molecular_module import MolecularModule
+
+ __all__ = [
+     "EquationModule",
+     "NumericalReasoningModule",
+     "CitationModule",
+     "MolecularModule",
+ ]
models/science_modules/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (490 Bytes).
 
models/science_modules/__pycache__/citation_module.cpython-313.pyc ADDED
Binary file (9.3 kB).
 
models/science_modules/__pycache__/equation_module.cpython-313.pyc ADDED
Binary file (10.4 kB).
 
models/science_modules/__pycache__/molecular_module.cpython-313.pyc ADDED
Binary file (13 kB).
 
models/science_modules/__pycache__/numerical_module.cpython-313.pyc ADDED
Binary file (10 kB).
 
models/science_modules/citation_module.py ADDED
@@ -0,0 +1,230 @@
+ """
+ CitationModule: Understands scientific citation structure.
+ Detects citation spans, tracks provenance, and estimates claim confidence.
+ """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import re
+ from typing import Optional, Tuple, List
+
+
+ class CitationModule(nn.Module):
+     """
+     Understands scientific citation structure.
+     - Detects citation spans in [Author, Year] or (1) style
+     - Learns that cited claims carry different epistemic weight
+     - Distinguishes established facts from recent/contested findings
+     - Tracks claim provenance through the context window
+     """
+
+     def __init__(self, d_model: int):
+         """
+         Initialize CitationModule.
+
+         Args:
+             d_model: Model dimension
+         """
+         super().__init__()
+         self.d_model = d_model
+
+         # Citation span detector (3 classes: none, inline, reference)
+         # Inline: (Author, Year) or [1]
+         # Reference: full citation at the end of a paper
+         self.citation_detector = nn.Linear(d_model, 3)
+
+         # Provenance gate: modulates information flow based on citation context
+         self.provenance_gate = nn.Linear(d_model, d_model)
+
+         # Claim confidence head: estimates how well-supported a claim is
+         self.confidence_head = nn.Linear(d_model, 1)
+
+         # Citation type embeddings
+         self.citation_type_embedding = nn.Embedding(3, d_model)
+
+         # Initialize weights
+         self._initialize_weights()
+
+     def _initialize_weights(self):
+         """Initialize weights."""
+         for module in [self.citation_detector, self.provenance_gate,
+                        self.confidence_head, self.citation_type_embedding]:
+             if hasattr(module, 'weight'):
+                 nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if hasattr(module, 'bias') and module.bias is not None:
+                 nn.init.zeros_(module.bias)
+
+     def detect_citation_spans(
+         self,
+         text: str,
+     ) -> List[Tuple[int, int, str]]:
+         """
+         Detect citation spans in text.
+         Supports: (Author, Year), [1], [Author, Year], et al.
+
+         Args:
+             text: Input text string
+
+         Returns:
+             List of (start_char, end_char, citation_type)
+             citation_type: "inline" or "reference"
+         """
+         spans = []
+
+         # Pattern 1: (Author, Year) or (Author Year)
+         for match in re.finditer(r'\([A-Za-z\s]+(?:et al\.)?,?\s*\d{4}\)', text):
+             spans.append((match.start(), match.end(), "inline"))
+
+         # Pattern 2: [1] or [1-3] or [1,2,3]
+         for match in re.finditer(r'\[\d+(?:[-,]\d+)*\]', text):
+             spans.append((match.start(), match.end(), "inline"))
+
+         # Pattern 3: [Author, Year]
+         for match in re.finditer(r'\[[A-Za-z\s]+,?\s*\d{4}\]', text):
+             spans.append((match.start(), match.end(), "inline"))
+
+         # Pattern 4: et al. (often indicates a citation). No trailing \b:
+         # a period followed by space/punctuation has no word boundary.
+         for match in re.finditer(r'\bet al\.', text):
+             spans.append((match.start(), match.end(), "inline"))
+
+         return spans
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         text: Optional[List[str]] = None,
+         citation_spans: Optional[List[List[Tuple[int, int, str]]]] = None,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         Forward pass through the citation module.
+
+         Args:
+             x: Input tensor (batch, seq_len, d_model)
+             text: Optional original text strings
+             citation_spans: Optional pre-computed citation spans per batch item
+
+         Returns:
+             Tuple of:
+             - Citation-enhanced representation (batch, seq_len, d_model)
+             - Claim confidence scores (batch, seq_len, 1)
+         """
+         batch, seq_len, d_model = x.shape
+
+         # Detect citation spans
+         if citation_spans is None and text is not None:
+             citation_spans = []
+             for b in range(batch):
+                 spans = self.detect_citation_spans(text[b])
+                 # Convert char spans to token spans (rough heuristic of ~4
+                 # chars per token; real code would align via the tokenizer)
+                 token_spans = []
+                 for start_char, end_char, ctype in spans:
+                     start_tok = max(0, start_char // 4)
+                     end_tok = min(seq_len, end_char // 4 + 1)
+                     token_spans.append((start_tok, end_tok, ctype))
+                 citation_spans.append(token_spans)
+
+         # Compute citation type logits (available for auxiliary supervision)
+         citation_logits = self.citation_detector(x)  # (batch, seq_len, 3)
+         citation_probs = F.softmax(citation_logits, dim=-1)
+
+         # Apply citation-specific transformations
+         output = x.clone()
+
+         if citation_spans:
+             for b in range(batch):
+                 spans_b = citation_spans[b] if b < len(citation_spans) else []
+
+                 for start_tok, end_tok, ctype in spans_b:
+                     if end_tok <= start_tok:
+                         continue
+
+                     # Get citation type embedding
+                     if ctype == "inline":
+                         type_id = 1
+                     elif ctype == "reference":
+                         type_id = 2
+                     else:
+                         type_id = 0
+
+                     type_emb = self.citation_type_embedding(
+                         torch.tensor(type_id, device=x.device)
+                     )
+
+                     # Apply provenance gate to the citation span
+                     span_slice = x[b, start_tok:end_tok, :]  # (span_len, d_model)
+                     gated = span_slice * torch.sigmoid(self.provenance_gate(span_slice))
+
+                     # Add citation type embedding, broadcast over the span
+                     gated = gated + type_emb.unsqueeze(0)
+
+                     output[b, start_tok:end_tok, :] = gated
+
+         # Compute confidence scores (for auxiliary loss)
+         confidence = torch.sigmoid(self.confidence_head(x))  # (batch, seq_len, 1)
+
+         return output, confidence
+
+     def compute_citation_loss(
+         self,
+         x: torch.Tensor,
+         citation_mask: torch.Tensor,
+         confidence: torch.Tensor,
+     ) -> torch.Tensor:
+         """
+         Compute auxiliary loss for citation detection and confidence.
+
+         Args:
+             x: Input tensor (batch, seq_len, d_model)
+             citation_mask: Ground-truth citation mask (batch, seq_len), 1 if token is in a citation
+             confidence: Predicted confidence scores (batch, seq_len, 1)
+
+         Returns:
+             Combined citation loss
+         """
+         # Citation detection loss (mask values 0/1 index the none/inline classes)
+         logits = self.citation_detector(x)  # (batch, seq_len, 3)
+         detection_loss = F.cross_entropy(
+             logits.view(-1, 3),
+             citation_mask.long().view(-1),
+         )
+
+         # Confidence calibration loss (encourage high confidence for true citations)
+         confidence_loss = F.mse_loss(
+             confidence.squeeze(-1),
+             citation_mask.float(),
+         )
+
+         return detection_loss + 0.1 * confidence_loss
+
+ def test_citation_module():
+     """Test CitationModule."""
+     d_model = 512
+     batch_size = 2
+     seq_len = 128
+
+     module = CitationModule(d_model)
+
+     x = torch.randn(batch_size, seq_len, d_model)
+     text = [
+         "The theory of relativity (Einstein, 1905) revolutionized physics. See also [1, 2].",
+         "According to Smith et al., the results are significant. Further reading: [Doe, 2020]."
+     ]
+
+     output, confidence = module(x, text=text)
+     print(f"Input shape: {x.shape}")
+     print(f"Output shape: {output.shape}")
+     print(f"Confidence shape: {confidence.shape}")
+     assert output.shape == x.shape
+     assert confidence.shape == (batch_size, seq_len, 1)
+
+     # Test loss
+     citation_mask = torch.zeros(batch_size, seq_len)
+     citation_mask[0, 20:25] = 1.0  # Simulate a citation span
+     citation_mask[1, 10:18] = 1.0
+     loss = module.compute_citation_loss(x, citation_mask, confidence)
+     print(f"Citation loss: {loss.item():.4f}")
+
+     print("CitationModule test passed!")
+
+
+ if __name__ == "__main__":
+     test_citation_module()
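The detection regexes are easy to exercise in isolation. A minimal sketch with a made-up sentence; the `et al.` pattern here drops the trailing `\b` of the module's original pattern, since a period followed by a space has no word boundary and that variant never matches:

```python
import re

# Hypothetical sample sentence for demonstration
text = "Relativity (Einstein, 1905) is well established [1,2]; see Smith et al. for a review."

spans = []
for pattern in [
    r'\([A-Za-z\s]+(?:et al\.)?,?\s*\d{4}\)',  # (Author, Year)
    r'\[\d+(?:[-,]\d+)*\]',                    # [1], [1-3], [1,2,3]
    r'\bet al\.',                              # "et al." marker
]:
    for m in re.finditer(pattern, text):
        spans.append(text[m.start():m.end()])

print(spans)  # ['(Einstein, 1905)', '[1,2]', 'et al.']
```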
models/science_modules/equation_module.py ADDED
@@ -0,0 +1,266 @@
+ """
+ EquationModule: Specialized processing for mathematical equations and LaTeX.
+ Detects equation spans, applies equation-specific attention, and learns
+ structural representations of mathematical expressions.
+ """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import re
+ from typing import Optional, Tuple, List
+
+
+ class EquationModule(nn.Module):
+     r"""
+     Specialized processing for mathematical equations and LaTeX.
+     - Detects equation spans in input (between $ $ or \[ \] delimiters)
+     - Applies equation-specific attention patterns within equation spans
+     - Learns structural representations of mathematical expressions
+     - Tree-aware: understands operator precedence and nesting
+     """
+
+     def __init__(self, d_model: int, num_heads: int = 8):
+         """
+         Initialize EquationModule.
+
+         Args:
+             d_model: Model dimension
+             num_heads: Number of heads for equation-specific attention
+         """
+         super().__init__()
+         self.d_model = d_model
+
+         # Equation span detector (lightweight linear classifier)
+         self.span_detector = nn.Linear(d_model, 1)
+
+         # Equation-specific transformer (shallow, 2 layers)
+         encoder_layer = nn.TransformerEncoderLayer(
+             d_model=d_model,
+             nhead=num_heads,
+             dim_feedforward=d_model * 4,
+             activation=F.silu,
+             batch_first=True,
+             dropout=0.1,
+         )
+         self.equation_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)
+
+         # Merge equation representations back into the main stream
+         self.merge = nn.Linear(d_model * 2, d_model)
+
+         # LaTeX structure awareness (simple positional encoding for tree depth)
+         self.depth_embedding = nn.Embedding(10, d_model)  # Max depth 10
+
+         # Initialize weights
+         self._initialize_weights()
+
+     def _initialize_weights(self):
+         """Initialize weights."""
+         for module in [self.span_detector, self.merge, self.depth_embedding]:
+             if hasattr(module, 'weight'):
+                 nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if hasattr(module, 'bias') and module.bias is not None:
+                 nn.init.zeros_(module.bias)
57
+ def _initialize_weights(self):
58
+ """Initialize weights."""
59
+ for module in [self.span_detector, self.merge, self.depth_embedding]:
60
+ if hasattr(module, 'weight'):
61
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
62
+ if hasattr(module, 'bias') and module.bias is not None:
63
+ nn.init.zeros_(module.bias)
64
+
65
+ def detect_equation_spans(
66
+ self,
67
+ text: str,
68
+ token_ids: Optional[torch.Tensor] = None,
69
+ ) -> List[Tuple[int, int]]:
70
+ """
71
+ Detect equation spans in text using delimiters.
72
+ Supports: $...$, $$...$$, \[...\], \(...\)
73
+
74
+ Args:
75
+ text: Input text string
76
+ token_ids: Optional token IDs for alignment
77
+
78
+ Returns:
79
+ List of (start_char, end_char) spans
80
+ """
81
+ spans = []
82
+
83
+ # Pattern 1: $...$ (inline math)
84
+ for match in re.finditer(r'\$(.+?)\$', text, re.DOTALL):
85
+ spans.append((match.start(), match.end()))
86
+
87
+ # Pattern 2: $$...$$ (display math)
88
+ for match in re.finditer(r'\$\$(.+?)\$\$', text, re.DOTALL):
89
+ spans.append((match.start(), match.end()))
90
+
91
+ # Pattern 3: \[...\] (LaTeX display math)
92
+ for match in re.finditer(r'\\\[(.+?)\\\]', text, re.DOTALL):
93
+ spans.append((match.start(), match.end()))
94
+
95
+ # Pattern 4: \(...\) (LaTeX inline math)
96
+ for match in re.finditer(r'\\\((.+?)\\\)', text, re.DOTALL):
97
+ spans.append((match.start(), match.end()))
98
+
99
+ return spans
100
+
101
+ def forward(
102
+ self,
103
+ x: torch.Tensor,
104
+ text: Optional[List[str]] = None,
105
+ token_spans: Optional[List[List[Tuple[int, int]]]] = None,
106
+ ) -> torch.Tensor:
107
+ """
108
+ Forward pass through the equation module.
109
+
110
+ Args:
111
+ x: Input tensor (batch, seq_len, d_model)
112
+ text: Optional original text strings (for delimiter-based detection)
113
+ token_spans: Optional pre-computed token-level equation spans
114
+ Each element: list of (start_token, end_token) for that batch item
115
+
116
+ Returns:
117
+ Equation-enhanced representation (batch, seq_len, d_model)
118
+ """
119
+ batch, seq_len, d_model = x.shape
120
+
121
+ # Detect equation spans
122
+ if token_spans is None and text is not None:
123
+ # Use delimiter-based detection (requires text)
124
+ token_spans = []
125
+ for b in range(batch):
126
+ char_spans = self.detect_equation_spans(text[b])
127
+ # Convert char spans to token spans (simplified - assumes 1 char ≈ 1 token)
128
+ # In practice, would need proper tokenization alignment
129
+ token_spans_b = []
130
+ for start_char, end_char in char_spans:
131
+ # Rough approximation: divide by average chars per token (~4)
132
+ start_token = max(0, start_char // 4)
133
+ end_token = min(seq_len, end_char // 4 + 1)
134
+ token_spans_b.append((start_token, end_token))
135
+ token_spans.append(token_spans_b)
136
+ elif token_spans is None:
137
+ # Fallback: use learned detector
138
+ token_spans = self._learned_span_detection(x)
139
+
140
+ # Process each batch item
141
+ output = x.clone()
142
+
143
+ for b in range(batch):
144
+ spans_b = token_spans[b] if b < len(token_spans) else []
145
+
146
+ for start_tok, end_tok in spans_b:
147
+ if end_tok <= start_tok:
148
+ continue
149
+
150
+ # Extract equation segment
151
+ eq_segment = x[b:b+1, start_tok:end_tok, :] # (1, seg_len, d_model)
152
+
153
+ # Apply equation-specific transformer
154
+ eq_encoded = self.equation_encoder(eq_segment)
155
+
156
+ # Merge with original
157
+ merged = torch.cat([eq_segment, eq_encoded], dim=-1)
158
+ merged = self.merge(merged)
159
+
160
+ # Place back in output
161
+ output[b:b+1, start_tok:end_tok, :] = merged
162
+
163
+ return output
164
+
165
+ def _learned_span_detection(
166
+ self,
167
+ x: torch.Tensor,
168
+ ) -> List[List[Tuple[int, int]]]:
169
+ """
170
+ Use learned detector to find equation spans when delimiters missing.
171
+ Simple thresholding on span_detector output.
172
+
173
+ Args:
174
+ x: Input tensor (batch, seq_len, d_model)
175
+
176
+ Returns:
177
+ List of token spans per batch item
178
+ """
179
+ batch, seq_len, _ = x.shape
180
+
181
+ # Compute equation probability per token
182
+ eq_probs = torch.sigmoid(self.span_detector(x)) # (batch, seq_len, 1)
183
+ eq_probs = eq_probs.squeeze(-1) # (batch, seq_len)
184
+
185
+ # Threshold
186
+ threshold = 0.5
187
+ spans = []
188
+
189
+ for b in range(batch):
190
+ probs = eq_probs[b]
191
+ is_equation = (probs > threshold).cpu().numpy()
192
+
193
+ # Find contiguous spans
194
+ span_list = []
195
+ in_span = False
196
+ start = 0
197
+
198
+ for t in range(seq_len):
199
+ if is_equation[t] and not in_span:
200
+ start = t
201
+ in_span = True
202
+ elif not is_equation[t] and in_span:
203
+ span_list.append((start, t))
204
+ in_span = False
205
+
206
+ if in_span:
207
+ span_list.append((start, seq_len))
208
+
209
+ spans.append(span_list)
210
+
211
+ return spans
212
+
+     def compute_equation_loss(
+         self,
+         x: torch.Tensor,
+         equation_mask: torch.Tensor,
+     ) -> torch.Tensor:
+         """
+         Compute auxiliary loss for equation detection training.
+
+         Args:
+             x: Input tensor (batch, seq_len, d_model)
+             equation_mask: Ground-truth equation mask (batch, seq_len), 1 if token is in an equation
+
+         Returns:
+             Binary cross-entropy loss for equation detection
+         """
+         logits = self.span_detector(x).squeeze(-1)  # (batch, seq_len)
+         loss = F.binary_cross_entropy_with_logits(
+             logits,
+             equation_mask.float(),
+         )
+         return loss
+
+
+ def test_equation_module():
+     """Test EquationModule."""
+     d_model = 512
+     batch_size = 2
+     seq_len = 128
+
+     module = EquationModule(d_model)
+
+     x = torch.randn(batch_size, seq_len, d_model)
+     text = [
+         "The energy is $E = mc^2$ and momentum is $p = mv$.",
+         r"Equation: \[ F = ma \] and also $a^2 + b^2 = c^2$."
+     ]
+
+     output = module(x, text=text)
+     print(f"Input shape: {x.shape}")
+     print(f"Output shape: {output.shape}")
+     assert output.shape == x.shape
+
+     # Test equation loss
+     equation_mask = torch.zeros(batch_size, seq_len)
+     equation_mask[0, 10:15] = 1.0  # Simulate an equation span
+     equation_mask[1, 5:12] = 1.0
+     loss = module.compute_equation_loss(x, equation_mask)
+     print(f"Equation loss: {loss.item():.4f}")
+
+     print("EquationModule test passed!")
+
+
+ if __name__ == "__main__":
+     test_equation_module()
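The delimiter patterns are simple enough to sanity-check on their own. A minimal sketch mirroring the `$...$`, `\[...\]`, and `\(...\)` regexes from `detect_equation_spans` (the sample string is made up, and the `$$` case is omitted here):

```python
import re

# Hypothetical sample text containing the three delimiter styles
text = r"Energy $E = mc^2$, display \[ F = ma \], and inline \( p = mv \)."

spans = []
for pattern in [r'\$(.+?)\$', r'\\\[(.+?)\\\]', r'\\\((.+?)\\\)']:
    for m in re.finditer(pattern, text, re.DOTALL):
        spans.append(m.group())

print(spans)  # ['$E = mc^2$', '\\[ F = ma \\]', '\\( p = mv \\)']
```

Note that the bare `\$(.+?)\$` pattern would also fire inside `$$...$$` blocks, which is why display math needs to be handled before inline math.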
models/science_modules/molecular_module.py ADDED
@@ -0,0 +1,333 @@
+ """
+ MolecularModule: Domain knowledge for chemistry and biology.
+ Element embeddings, SMILES understanding, bond types, amino acids.
+ """
+
+ import re
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from typing import Optional, Tuple, List
+
+
+ class MolecularModule(nn.Module):
+     """
+     Domain knowledge for chemistry and biology.
+     - All 118 elements as learned embeddings with properties
+       (atomic number, mass, electronegativity, valence electrons)
+     - SMILES string understanding for molecular structures
+     - Bond type awareness (covalent, ionic, hydrogen, van der Waals)
+     - Amino acid sequence understanding for biology/zoology
+     - Molecular formula to property reasoning
+     """
+
+     def __init__(self, d_model: int, num_elements: int = 118):
+         """
+         Initialize MolecularModule.
+
+         Args:
+             d_model: Model dimension
+             num_elements: Number of chemical elements (default 118)
+         """
+         super().__init__()
+         self.d_model = d_model
+         self.num_elements = num_elements
+
+         # Element embeddings for all 118 elements (+1 for unknown)
+         self.element_embed = nn.Embedding(num_elements + 1, d_model)
+
+         # Element property encoder (12 properties):
+         # [atomic_number, mass, electronegativity, valence_e, period, group,
+         #  atomic_radius, ionization_energy, electron_affinity, density,
+         #  melting_point, boiling_point]
+         self.property_proj = nn.Linear(12, d_model)
+
+         # Bond type embeddings (8 types):
+         # 0: none, 1: single, 2: double, 3: triple, 4: aromatic,
+         # 5: ionic, 6: hydrogen, 7: van der Waals
+         self.bond_embed = nn.Embedding(8, d_model)
+
+         # Amino acid embeddings (20 standard + special tokens)
+         self.amino_acid_vocab = 25  # 20 standard + stop + start + unknown + special
+         self.amino_embed = nn.Embedding(self.amino_acid_vocab, d_model)
+
+         # Molecular graph attention (treats molecules as graphs)
+         self.mol_attention = nn.MultiheadAttention(
+             d_model,
+             num_heads=8,
+             batch_first=True,
+             dropout=0.1,
+         )
+
+         # Property prediction head (for auxiliary tasks)
+         self.property_head = nn.Linear(d_model, 12)
+
+         # Initialize weights
+         self._initialize_weights()
+
+         # Pre-compute element properties (simplified)
+         self._init_element_properties()
+
+     def _initialize_weights(self):
+         """Initialize weights (modules without a .weight attribute,
+         e.g. nn.MultiheadAttention, keep their default init)."""
+         for module in [self.element_embed, self.property_proj, self.bond_embed,
+                        self.amino_embed, self.mol_attention, self.property_head]:
+             if hasattr(module, 'weight'):
+                 nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if hasattr(module, 'bias') and module.bias is not None:
+                 nn.init.zeros_(module.bias)
+
+     def _init_element_properties(self):
+         """Initialize the element property table with approximate values."""
+         # This is a simplified version; in practice the values would be
+         # loaded from a chemistry database.
+         # Properties: [atomic_number, mass, electronegativity, valence_e, period, group,
+         #              atomic_radius, ionization_energy, electron_affinity, density,
+         #              melting_point, boiling_point]
+         properties = torch.zeros(self.num_elements + 1, 12)
+
+         # Fill in known elements (simplified data for a few common ones)
+         element_data = {
+             1: [1, 1.008, 2.20, 1, 1, 1, 25, 1312, 72.8, 0.0000899, 14, 20],
+             6: [6, 12.011, 2.55, 4, 2, 14, 70, 1086, 153.9, 2.267, 3550, 4027],
+             7: [7, 14.007, 3.04, 5, 2, 15, 65, 1402, 7.0, 0.0012506, 63, 77],
+             8: [8, 15.999, 3.44, 6, 2, 16, 60, 1314, 141.0, 0.001429, 55, 90],
+             # ... would fill all 118 elements
+         }
+
+         for z, props in element_data.items():
+             properties[z] = torch.tensor(props)
+
+         self.register_buffer("element_properties", properties)
+
103
+ def detect_molecular_spans(
104
+ self,
105
+ text: str,
106
+ ) -> List[Tuple[int, int, str]]:
107
+ """
108
+ Detect molecular/chemical spans in text.
109
+
110
+ Args:
111
+ text: Input text string
112
+
113
+ Returns:
114
+ List of (start_char, end_char, span_type)
115
+ span_type: "formula", "smiles", "amino_acid"
116
+ """
117
+ spans = []
118
+
119
+ # Chemical formulas: H2O, CO2, C6H12O6, NaCl, HCl
120
+ formula_pattern = r'\b([A-Z][a-z]?\d*)+(?:[A-Z][a-z]?\d*)*\b'
121
+ for match in re.finditer(formula_pattern, text):
122
+ # Keep spans containing a digit (H2O) or mixed-case element symbols (NaCl);
+ # drop all-caps words (DNA, THE) that the loose pattern also matches
+ span = match.group()
+ if any(c.isdigit() for c in span) or not span.isupper():
125
+ spans.append((match.start(), match.end(), "formula"))
126
+
127
+ # SMILES patterns (simplified detection)
128
+ # Contains: =, #, @, [], (), numbers in sequence
129
+ smiles_hints = ['=', '#', '@', '[', ']', '(', ')']
130
+ words = re.findall(r'\S+', text)
131
+ for word in words:
132
+ if any(hint in word for hint in smiles_hints) and len(word) > 3:
133
+ # Find position in text
134
+ pos = text.find(word)
135
+ if pos >= 0:
136
+ spans.append((pos, pos + len(word), "smiles"))
137
+
138
+ # Amino acid sequences (single letters, length > 5)
139
+ aa_pattern = r'\b([ACDEFGHIKLMNPQRSTVWY]{6,})\b'
140
+ for match in re.finditer(aa_pattern, text.upper()):
141
+ spans.append((match.start(), match.end(), "amino_acid"))
142
+
143
+ return spans
144
+
145
+ def encode_molecule(
146
+ self,
147
+ formula: str,
148
+ ) -> torch.Tensor:
149
+ """
150
+ Encode a molecular formula into embedding.
151
+
152
+ Args:
153
+ formula: Chemical formula string (e.g., "C6H12O6")
154
+
155
+ Returns:
156
+ Molecule embedding (d_model,)
157
+ """
158
+ # Parse formula into elements and counts
159
+ # Simplified parser - real would handle nested parentheses
160
+ pattern = r'([A-Z][a-z]?)(\d*)'
161
+ matches = re.findall(pattern, formula)
162
+
163
+ device = self.element_embed.weight.device
164
+ embeddings = []
165
+ weights = []
166
+
167
+ for element, count_str in matches:
168
+ # Get element atomic number (simplified mapping)
169
+ element_map = {
170
+ 'H': 1, 'He': 2, 'Li': 3, 'Be': 4, 'B': 5, 'C': 6, 'N': 7, 'O': 8,
171
+ 'F': 9, 'Ne': 10, 'Na': 11, 'Mg': 12, 'Al': 13, 'Si': 14, 'P': 15,
172
+ 'S': 16, 'Cl': 17, 'Ar': 18, 'K': 19, 'Ca': 20,
173
+ # ... extend as needed
174
+ }
175
+ z = element_map.get(element, 0) # 0 = unknown
176
+
177
+ count = int(count_str) if count_str else 1
178
+
179
+ # Get element embedding
180
+ elem_emb = self.element_embed(torch.tensor(z, device=device))
181
+
182
+ # Get properties and project
183
+ props = self.element_properties[z].unsqueeze(0) # (1, 12)
184
+ props_emb = self.property_proj(props).squeeze(0)
185
+
186
+ # Combine
187
+ combined = elem_emb + props_emb
188
+ embeddings.append(combined)
189
+ weights.append(count)
190
+
191
+ if not embeddings:
192
+ # Return zero embedding
193
+ return torch.zeros(self.d_model, device=device)
194
+
195
+ # Weighted average
196
+ embeddings = torch.stack(embeddings)
197
+ weights = torch.tensor(weights, dtype=torch.float32, device=device)
198
+ weights = weights / weights.sum()
199
+
200
+ return (embeddings * weights.unsqueeze(-1)).sum(dim=0)
201
+
202
+ def forward(
203
+ self,
204
+ x: torch.Tensor,
205
+ text: Optional[List[str]] = None,
206
+ molecular_spans: Optional[List[List[Tuple[int, int, str]]]] = None,
207
+ ) -> torch.Tensor:
208
+ """
209
+ Forward pass through molecular module.
210
+
211
+ Args:
212
+ x: Input tensor (batch, seq_len, d_model)
213
+ text: Optional original text strings
214
+ molecular_spans: Optional pre-computed molecular spans per batch
215
+
216
+ Returns:
217
+ Molecular-enhanced representation (batch, seq_len, d_model)
218
+ """
219
+ batch, seq_len, d_model = x.shape
220
+ device = x.device
221
+
222
+ # Detect molecular spans
223
+ if molecular_spans is None and text is not None:
224
+ molecular_spans = []
225
+ for b in range(batch):
226
+ spans = self.detect_molecular_spans(text[b])
227
+ # Convert char spans to token spans (rough heuristic: ~4 chars per token)
228
+ token_spans = []
229
+ for start_char, end_char, span_type in spans:
230
+ start_tok = max(0, start_char // 4)
231
+ end_tok = min(seq_len, end_char // 4 + 1)
232
+ token_spans.append((start_tok, end_tok, span_type))
233
+ molecular_spans.append(token_spans)
234
+
235
+ # Enhance molecular spans
236
+ output = x.clone()
237
+
238
+ if molecular_spans:
239
+ for b in range(batch):
240
+ spans_b = molecular_spans[b] if b < len(molecular_spans) else []
241
+
242
+ for start_tok, end_tok, span_type in spans_b:
243
+ if end_tok <= start_tok:
244
+ continue
245
+
246
+ span_slice = x[b, start_tok:end_tok, :]
247
+
248
+ if span_type == "formula":
249
+ # Extract formula from text if available
250
+ if text:
251
+ formula = text[b][start_tok*4:end_tok*4] # rough extraction
252
+ mol_emb = self.encode_molecule(formula)
253
+ else:
254
+ mol_emb = torch.zeros(d_model, device=device, dtype=x.dtype)
255
+
256
+ # Add molecular embedding to first token
257
+ output[b, start_tok, :] += mol_emb
258
+
259
+ elif span_type == "amino_acid":
260
+ # Encode as amino acid sequence
261
+ # Placeholder: random residue IDs; a full version would map each letter to its amino-acid index
262
+ seq_len_span = end_tok - start_tok
263
+ aa_ids = torch.randint(0, 20, (seq_len_span,), device=device)
264
+ aa_emb = self.amino_embed(aa_ids) # (seq_len_span, d_model)
265
+ output[b, start_tok:end_tok, :] += aa_emb
266
+
267
+ elif span_type == "smiles":
268
+ # For SMILES, apply graph attention (simplified)
269
+ # Treat each character as a node
270
+ seq_len_span = end_tok - start_tok
271
+ if seq_len_span > 1:
272
+ # Self-attention over the span
273
+ attn_out, _ = self.mol_attention(
274
+ span_slice.unsqueeze(0),
275
+ span_slice.unsqueeze(0),
276
+ span_slice.unsqueeze(0),
277
+ )
278
+ output[b, start_tok:end_tok, :] += attn_out.squeeze(0)
279
+
280
+ return output
281
+
282
+ def compute_property_loss(
283
+ self,
284
+ x: torch.Tensor,
285
+ element_ids: torch.Tensor,
286
+ target_properties: torch.Tensor,
287
+ ) -> torch.Tensor:
288
+ """
289
+ Compute auxiliary loss for property prediction.
290
+
291
+ Args:
292
+ x: Input tensor (batch, seq_len, d_model)
293
+ element_ids: Element IDs (batch, seq_len)
294
+ target_properties: Target property values (batch, seq_len, 12)
295
+
296
+ Returns:
297
+ MSE loss for property prediction
298
+ """
299
+ # Get element embeddings
300
+ elem_emb = self.element_embed(element_ids)
301
+
302
+ # Predict properties
303
+ pred_props = self.property_head(elem_emb)
304
+
305
+ # Compute loss
306
+ loss = F.mse_loss(pred_props, target_properties)
307
+ return loss
308
+
309
+
310
+ def test_molecular_module():
311
+ """Test MolecularModule."""
312
+ d_model = 512
313
+ batch_size = 2
314
+ seq_len = 128
315
+
316
+ module = MolecularModule(d_model)
317
+
318
+ x = torch.randn(batch_size, seq_len, d_model)
319
+ text = [
320
+ "Water is H2O. The DNA sequence is ACGTACGTACGT.",
321
+ "Proteins are made of amino acids like ACDEFGH. Benzene is C6H6."
322
+ ]
323
+
324
+ output = module(x, text=text)
325
+ print(f"Input shape: {x.shape}")
326
+ print(f"Output shape: {output.shape}")
327
+ assert output.shape == x.shape
328
+
329
+ print("MolecularModule test passed!")
330
+
331
+
332
+ if __name__ == "__main__":
333
+ test_molecular_module()
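The flat-formula parsing that `encode_molecule` relies on can be checked in isolation. A minimal sketch under the same regex; `parse_formula` is an illustrative helper name, not part of the module:

```python
# Sketch of the formula parser inside MolecularModule.encode_molecule.
# Like the module's "simplified parser", it ignores nested parentheses
# (Ca(OH)2 is counted as Ca + O + H, not Ca + 2*(O + H)).
import re
from collections import Counter

def parse_formula(formula: str) -> Counter:
    counts = Counter()
    for element, count_str in re.findall(r'([A-Z][a-z]?)(\d*)', formula):
        # An absent count means one atom of that element
        counts[element] += int(count_str) if count_str else 1
    return counts
```

Each (element, count) pair is then looked up in the atomic-number map and contributes a count-weighted term to the molecule embedding.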
models/science_modules/numerical_module.py ADDED
@@ -0,0 +1,251 @@
1
+ """
2
+ NumericalReasoningModule: Handles scientific numerical reasoning.
3
+ Digit-level number encoding, scientific notation, unit awareness.
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import re
10
+ from typing import Optional, Tuple, List
11
+
12
+
13
+ class NumericalReasoningModule(nn.Module):
14
+ """
15
+ Handles scientific numerical reasoning.
16
+ - Digit-level number encoding (each digit gets position-aware embedding)
17
+ - Scientific notation understanding (6.02 × 10²³)
18
+ - Unit awareness (meters, joules, moles, kelvin)
19
+ - Order of magnitude reasoning
20
+ - Significant figures tracking
21
+ """
22
+
23
+ def __init__(
24
+ self,
25
+ d_model: int,
26
+ max_digits: int = 20,
27
+ num_units: int = 256,
28
+ ):
29
+ """
30
+ Initialize NumericalReasoningModule.
31
+
32
+ Args:
33
+ d_model: Model dimension
34
+ max_digits: Maximum number of digits to encode
35
+ num_units: Number of unit types to embed
36
+ """
37
+ super().__init__()
38
+ self.d_model = d_model
39
+ self.max_digits = max_digits
40
+
41
+ # Digit embeddings (0-9)
42
+ self.digit_embed = nn.Embedding(10, 64)
43
+
44
+ # Position embeddings (ones, tens, hundreds...)
45
+ self.position_embed = nn.Embedding(max_digits, 64)
46
+
47
+ # Project digit+position to model dimension
48
+ self.number_proj = nn.Linear(128, d_model)
49
+
50
+ # Unit embedding (SI units + common scientific units)
51
+ self.unit_embed = nn.Embedding(num_units, d_model)
52
+
53
+ # Scientific notation handler
54
+ self.sci_notation = nn.Linear(d_model * 2, d_model)
55
+
56
+ # Magnitude embedding (powers of 10: -10 to +10)
57
+ self.magnitude_embed = nn.Embedding(21, d_model) # -10 to +10
58
+
59
+ # Initialize weights
60
+ self._initialize_weights()
61
+
62
+ def _initialize_weights(self):
63
+ """Initialize weights."""
64
+ for module in [self.digit_embed, self.position_embed, self.number_proj,
65
+ self.unit_embed, self.sci_notation, self.magnitude_embed]:
66
+ if hasattr(module, 'weight'):
67
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
68
+ if hasattr(module, 'bias') and module.bias is not None:
69
+ nn.init.zeros_(module.bias)
70
+
71
+ def encode_number(
72
+ self,
73
+ number_str: str,
74
+ device: torch.device,
75
+ ) -> torch.Tensor:
76
+ """
77
+ Encode a number string using digit-level encoding.
78
+
79
+ Args:
80
+ number_str: String representation of number (e.g., "123.45e-6")
81
+ device: Torch device
82
+
83
+ Returns:
84
+ Number embedding (d_model,)
85
+ """
86
+ # Extract digits (ignore decimal point, sign, exponent)
87
+ digits = [int(d) for d in re.findall(r'\d', number_str)]
88
+ if not digits:
89
+ digits = [0]
90
+
91
+ # Pad/truncate to max_digits
92
+ if len(digits) > self.max_digits:
93
+ digits = digits[:self.max_digits]
94
+ else:
95
+ digits = digits + [0] * (self.max_digits - len(digits))
96
+
97
+ digits_tensor = torch.tensor(digits, device=device) # (max_digits,)
98
+ positions = torch.arange(self.max_digits, device=device) # (max_digits,)
99
+
100
+ # Embed digits and positions
101
+ digit_emb = self.digit_embed(digits_tensor) # (max_digits, 64)
102
+ pos_emb = self.position_embed(positions) # (max_digits, 64)
103
+
104
+ # Concatenate and project
105
+ combined = torch.cat([digit_emb, pos_emb], dim=-1) # (max_digits, 128)
106
+ number_emb = self.number_proj(combined) # (max_digits, d_model)
107
+
108
+ # Mean pool over positions
109
+ return number_emb.mean(dim=0) # (d_model,)
110
+
111
+ def detect_numbers(
112
+ self,
113
+ text: str,
114
+ ) -> List[Tuple[str, int, int, Optional[str]]]:
115
+ """
116
+ Detect numbers in text with optional units and scientific notation.
117
+
118
+ Returns:
119
+ List of (number_str, start_char, end_char, unit_str)
120
+ """
121
+ # Pattern: number with optional decimal, exponent, and unit
122
+ # Matches: 123, 123.45, 1.23e-4, 6.02×10²³, 100 m, 5.0 J/mol
123
+ pattern = r'(\d+(?:\.\d+)?(?:[eE][+-]?\d+)?(?:×10\^?[+-]?\d+)?)(?:\s*([a-zA-Z°%]+))?'
124
+
125
+ matches = []
126
+ for match in re.finditer(pattern, text):
127
+ number_str = match.group(1)
128
+ unit_str = match.group(2) if match.group(2) else None
129
+ matches.append((number_str, match.start(), match.end(), unit_str))
130
+
131
+ return matches
132
+
133
+ def forward(
134
+ self,
135
+ x: torch.Tensor,
136
+ text: Optional[List[str]] = None,
137
+ number_positions: Optional[List[List[Tuple[int, int, str, Optional[str]]]]] = None,
138
+ ) -> torch.Tensor:
139
+ """
140
+ Forward pass through numerical reasoning module.
141
+
142
+ Args:
143
+ x: Input tensor (batch, seq_len, d_model)
144
+ text: Optional original text strings
145
+ number_positions: Optional list of (start_token, end_token, number_str, unit_str) per batch
146
+
147
+ Returns:
148
+ Numerical-enhanced representation (batch, seq_len, d_model)
149
+ """
150
+ batch, seq_len, d_model = x.shape
151
+ device = x.device
152
+
153
+ # Detect numbers if text provided
154
+ if number_positions is None and text is not None:
155
+ number_positions = []
156
+ for b in range(batch):
157
+ numbers = self.detect_numbers(text[b])
158
+ # Convert char positions to token positions (approximate)
159
+ token_nums = []
160
+ for num_str, start_char, end_char, unit_str in numbers:
161
+ start_tok = max(0, start_char // 4)
162
+ end_tok = min(seq_len, end_char // 4 + 1)
163
+ token_nums.append((start_tok, end_tok, num_str, unit_str))
164
+ number_positions.append(token_nums)
165
+
166
+ # Enhance number spans
167
+ output = x.clone()
168
+
169
+ if number_positions:
170
+ for b in range(batch):
171
+ nums_b = number_positions[b] if b < len(number_positions) else []
172
+
173
+ for start_tok, end_tok, num_str, unit_str in nums_b:
174
+ if end_tok <= start_tok or start_tok >= seq_len:
175
+ continue
176
+
177
+ # Clamp to sequence bounds
178
+ start_tok = min(start_tok, seq_len - 1)
179
+ end_tok = min(end_tok, seq_len)
180
+
181
+ # Encode the number
182
+ number_emb = self.encode_number(num_str, device) # (d_model,)
183
+
184
+ # Add unit embedding if present
185
+ if unit_str:
186
+ # Deterministic character-sum unit ID (Python's hash() varies across
+ # runs; a real implementation would use a unit vocabulary)
+ unit_id = sum(ord(c) for c in unit_str) % self.unit_embed.num_embeddings
188
+ unit_emb = self.unit_embed(torch.tensor(unit_id, device=device))
189
+ number_emb = number_emb + unit_emb
190
+
191
+ # Add magnitude embedding for scientific notation
192
+ if 'e' in num_str.lower() or '×10' in num_str:
193
+ # Extract exponent
194
+ exp_match = re.search(r'[eE]([+-]?\d+)|×10\^?([+-]?\d+)', num_str)
195
+ if exp_match:
196
+ exp = int(exp_match.group(1) or exp_match.group(2))
197
+ exp = max(-10, min(10, exp)) # Clamp to embedding range
198
+ magnitude_emb = self.magnitude_embed(torch.tensor(exp + 10, device=device))
199
+ number_emb = number_emb + magnitude_emb
200
+
201
+ # Add to the first token of the number span
202
+ output[b, start_tok, :] += number_emb
203
+
204
+ return output
205
+
206
+ def compute_numerical_loss(
207
+ self,
208
+ x: torch.Tensor,
209
+ number_mask: torch.Tensor,
210
+ target_values: torch.Tensor,
211
+ ) -> torch.Tensor:
212
+ """
213
+ Compute auxiliary loss for numerical reasoning.
214
+
215
+ Args:
216
+ x: Input tensor (batch, seq_len, d_model)
217
+ number_mask: Mask for number tokens (batch, seq_len)
218
+ target_values: Target numeric values (batch, seq_len) or None
219
+
220
+ Returns:
221
+ MSE loss for value prediction (simplified)
222
+ """
223
+ # This is a simplified loss - in practice would have a value prediction head
224
+ # For now, return a zero tensor so callers can treat it like any other loss
+ return torch.zeros((), device=x.device)
226
+
227
+
228
+ def test_numerical_module():
229
+ """Test NumericalReasoningModule."""
230
+ d_model = 512
231
+ batch_size = 2
232
+ seq_len = 128
233
+
234
+ module = NumericalReasoningModule(d_model)
235
+
236
+ x = torch.randn(batch_size, seq_len, d_model)
237
+ text = [
238
+ "The speed of light is 2.998×10^8 m/s and Planck's constant is 6.626×10^-34 J·s.",
239
+ "Calculate: 123.45 + 67.89 = ? The answer is 191.34."
240
+ ]
241
+
242
+ output = module(x, text=text)
243
+ print(f"Input shape: {x.shape}")
244
+ print(f"Output shape: {output.shape}")
245
+ assert output.shape == x.shape
246
+
247
+ print("NumericalReasoningModule test passed!")
248
+
249
+
250
+ if __name__ == "__main__":
251
+ test_numerical_module()
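The string-level conventions behind `encode_number` and the magnitude branch need no tensors to verify. A sketch under the same rules (digits only, pad to `max_digits`, exponent clamped to ±10); the helper names are illustrative:

```python
# Digit extraction and exponent clamping as done in NumericalReasoningModule.
import re

def extract_digits(number_str: str, max_digits: int = 20) -> list:
    # Digits only; sign, decimal point, and exponent markers are dropped
    # (the module's regex keeps the exponent's digits too, matched here)
    digits = [int(d) for d in re.findall(r'\d', number_str)] or [0]
    digits = digits[:max_digits]
    return digits + [0] * (max_digits - len(digits))

def extract_exponent(number_str: str) -> int:
    # Matches both 1.23e-4 and 6.02×10^23 notations, clamped to [-10, 10]
    m = re.search(r'[eE]([+-]?\d+)|×10\^?([+-]?\d+)', number_str)
    if not m:
        return 0
    return max(-10, min(10, int(m.group(1) or m.group(2))))
```

The clamp is what keeps `magnitude_embed` (21 entries, exponents -10..+10) in range for extreme values like Avogadro's 10²³.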
models/scigate_ffn.py ADDED
@@ -0,0 +1,203 @@
1
+ """
2
+ SciGateFFN: Science-aware gated feed-forward network.
3
+ Learns to activate different FFN pathways based on science domain.
4
+ Uses hybrid routing: explicit domain tags preferred, fallback to learned classifier.
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from typing import Optional, Tuple
11
+
12
+
13
+ class SciGateFFN(nn.Module):
14
+ """
15
+ Gated FFN with science domain routing.
16
+ Learns to activate different FFN pathways for different science domains.
17
+ Gate is conditioned on the detected domain (math, chemistry, biology, etc.).
18
+ """
19
+
20
+ def __init__(
21
+ self,
22
+ d_model: int,
23
+ expansion: int = 4,
24
+ num_domains: int = 7,
25
+ use_domain_tags: bool = True,
26
+ ):
27
+ """
28
+ Initialize SciGateFFN.
29
+
30
+ Args:
31
+ d_model: Model dimension
32
+ expansion: FFN expansion factor (default 4)
33
+ num_domains: Number of science domains (7)
34
+ use_domain_tags: Whether to use explicit domain tags for routing
35
+ """
36
+ super().__init__()
37
+ self.d_model = d_model
38
+ self.expansion = expansion
39
+ self.num_domains = num_domains
40
+ self.use_domain_tags = use_domain_tags
41
+
42
+ hidden_dim = d_model * expansion
43
+
44
+ # Standard SwiGLU architecture: up_proj splits into two paths
45
+ self.up_proj = nn.Linear(d_model, hidden_dim * 2, bias=False)
46
+ self.down_proj = nn.Linear(hidden_dim, d_model, bias=False)
47
+
48
+ # Domain-specific scaling factors (learnable)
49
+ # Shape: (num_domains, hidden_dim)
50
+ self.domain_gate = nn.Linear(num_domains, hidden_dim, bias=True)
51
+
52
+ # Fallback domain classifier (when tags not present)
53
+ # Simple linear classifier based on sequence representation
54
+ self.fallback_classifier = nn.Sequential(
55
+ nn.Linear(d_model, d_model // 2),
56
+ nn.SiLU(),
57
+ nn.Linear(d_model // 2, num_domains),
58
+ )
59
+
60
+ # Initialize weights
61
+ self._initialize_weights()
62
+
63
+ def _initialize_weights(self):
64
+ """Initialize weights."""
65
+ for module in [self.up_proj, self.down_proj, self.domain_gate, *self.fallback_classifier]:
66
+ if hasattr(module, 'weight'):
67
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
68
+ if hasattr(module, 'bias') and module.bias is not None:
69
+ nn.init.zeros_(module.bias)
70
+
71
+ def get_domain_one_hot(
72
+ self,
73
+ domain_ids: Optional[torch.Tensor] = None,
74
+ domain_tags: Optional[torch.Tensor] = None,
75
+ hidden_states: Optional[torch.Tensor] = None,
76
+ ) -> torch.Tensor:
77
+ """
78
+ Get domain one-hot vector for routing.
79
+
80
+ Hybrid strategy:
81
+ 1. If domain_tags provided (explicit [MATH], [CHEM] etc), use those
82
+ 2. If domain_ids provided (from data loader), use those
83
+ 3. Fallback: classify from hidden_states
84
+
85
+ Args:
86
+ domain_ids: Tensor of domain IDs (batch, seq_len) or (batch,)
87
+ domain_tags: Boolean mask for domain tags (batch, seq_len, num_domains)
88
+ hidden_states: Hidden states for fallback classification (batch, seq_len, d_model)
89
+
90
+ Returns:
91
+ domain_one_hot: (batch, seq_len, num_domains)
92
+ """
93
+ batch, seq_len, _ = hidden_states.shape if hidden_states is not None else (0, 0, 0)
94
+
95
+ if domain_tags is not None and domain_tags.any():
96
+ # Use explicit domain tags (one-hot already)
97
+ return domain_tags.float()
98
+ elif domain_ids is not None:
99
+ # Convert domain IDs to one-hot
100
+ if domain_ids.dim() == 1:
101
+ # Same domain for entire sequence
102
+ domain_one_hot = F.one_hot(domain_ids, num_classes=self.num_domains)
103
+ # Expand to sequence length
104
+ domain_one_hot = domain_one_hot.unsqueeze(1).expand(-1, seq_len, -1)
105
+ else:
106
+ # Per-token domain IDs
107
+ domain_one_hot = F.one_hot(domain_ids, num_classes=self.num_domains)
108
+ return domain_one_hot.float()
109
+ elif hidden_states is not None:
110
+ # Fallback: classify domain from hidden states
111
+ # Use mean pooling over sequence
112
+ pooled = hidden_states.mean(dim=1) # (batch, d_model)
113
+ domain_logits = self.fallback_classifier(pooled) # (batch, num_domains)
114
+ domain_probs = F.softmax(domain_logits, dim=-1)
115
+ # Expand to sequence length
116
+ return domain_probs.unsqueeze(1).expand(-1, seq_len, -1)
117
+ else:
+ # No domain info at all: fall back to a uniform prior. Note this branch
+ # is unreachable from forward(), which always passes hidden_states.
+ uniform = torch.ones(batch, seq_len, self.num_domains)
+ return uniform / self.num_domains
121
+
122
+ def forward(
123
+ self,
124
+ x: torch.Tensor,
125
+ domain_ids: Optional[torch.Tensor] = None,
126
+ domain_tags: Optional[torch.Tensor] = None,
127
+ ) -> torch.Tensor:
128
+ """
129
+ Forward pass with domain-aware gating.
130
+
131
+ Args:
132
+ x: Input tensor (batch, seq_len, d_model)
133
+ domain_ids: Optional domain IDs (batch,) or (batch, seq_len)
134
+ domain_tags: Optional domain tag mask (batch, seq_len, num_domains)
135
+
136
+ Returns:
137
+ Output tensor (batch, seq_len, d_model)
138
+ """
139
+ batch, seq_len, d_model = x.shape
140
+
141
+ # Get domain routing weights
142
+ domain_weights = self.get_domain_one_hot(domain_ids, domain_tags, x)
143
+ # Shape: (batch, seq_len, num_domains)
144
+
145
+ # Project to hidden dimension
146
+ up = self.up_proj(x) # (batch, seq_len, hidden_dim * 2)
147
+ up1, up2 = up.chunk(2, dim=-1) # Each: (batch, seq_len, hidden_dim)
148
+
149
+ # Apply SwiGLU activation
150
+ hidden = up1 * F.silu(up2) # (batch, seq_len, hidden_dim)
151
+
152
+ # Apply domain-specific scaling
153
+ # domain_weights: (batch, seq_len, num_domains)
154
+ # domain_gate is Linear(num_domains -> hidden_dim); applying it directly
+ # computes the einsum "bsd,hd->bsh" over its weight and also uses the
+ # declared bias, which a bare einsum would silently drop
+ domain_scaling = self.domain_gate(domain_weights)
161
+ # domain_scaling: (batch, seq_len, hidden_dim)
162
+
163
+ # Apply domain scaling (multiplicative gating)
164
+ hidden = hidden * domain_scaling
165
+
166
+ # Project back to model dimension
167
+ output = self.down_proj(hidden)
168
+
169
+ return output
170
+
171
+
172
+ def test_scigate_ffn():
173
+ """Test SciGateFFN."""
174
+ batch_size = 2
175
+ seq_len = 128
176
+ d_model = 4096
177
+ num_domains = 7
178
+
179
+ ffn = SciGateFFN(d_model, expansion=4, num_domains=num_domains)
180
+
181
+ # Test with no domain info (fallback)
182
+ x = torch.randn(batch_size, seq_len, d_model)
183
+ output = ffn(x)
184
+ print(f"Input shape: {x.shape}")
185
+ print(f"Output shape: {output.shape}")
186
+ assert output.shape == x.shape
187
+
188
+ # Test with explicit domain IDs
189
+ domain_ids = torch.randint(0, num_domains, (batch_size,))
190
+ output2 = ffn(x, domain_ids=domain_ids)
191
+ assert output2.shape == x.shape
192
+
193
+ # Test with domain tags
194
+ domain_tags = torch.zeros(batch_size, seq_len, num_domains)
195
+ domain_tags[:, :, 0] = 1.0 # All math
196
+ output3 = ffn(x, domain_tags=domain_tags)
197
+ assert output3.shape == x.shape
198
+
199
+ print("SciGateFFN test passed!")
200
+
201
+
202
+ if __name__ == "__main__":
203
+ test_scigate_ffn()
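The routing algebra in `forward` has a simple reading: with a one-hot domain vector, contracting over `domain_gate.weight` just selects that domain's column as the per-channel scaling. A small check of that algebra (shapes chosen arbitrarily for illustration):

```python
import torch

num_domains, hidden_dim = 7, 16
# Same layout as nn.Linear(num_domains, hidden_dim).weight: (out, in)
W = torch.randn(hidden_dim, num_domains)

one_hot = torch.zeros(1, 1, num_domains)
one_hot[0, 0, 2] = 1.0  # pretend every token belongs to domain 2

scaling = torch.einsum("bsd,hd->bsh", one_hot, W)  # (1, 1, hidden_dim)
# One-hot routing reduces to reading one column of W:
assert torch.allclose(scaling[0, 0], W[:, 2])
```

With soft fallback probabilities instead of a one-hot, the same contraction yields a convex mixture of the domain columns.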
models/ssm_layer.py ADDED
@@ -0,0 +1,252 @@
1
+ """
2
+ VortexSSM: Selective State-Space Layer
3
+ Simplified Mamba-style SSM with input-dependent selection.
4
+ Provides O(n) complexity for long sequences, ideal for scientific documents.
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from typing import Optional, Tuple
11
+
12
+
13
+ class VortexSSM(nn.Module):
14
+ """
15
+ Selective state-space layer. Linear complexity O(n) vs attention's O(n²).
16
+ Handles long scientific documents efficiently with input-dependent selection.
17
+
18
+ Architecture based on Mamba but simplified for scientific reasoning tasks.
19
+ """
20
+
21
+ def __init__(
22
+ self,
23
+ d_model: int,
24
+ d_state: int = 16,
25
+ d_conv: int = 4,
26
+ expand: int = 2,
27
+ dt_rank: Optional[int] = None,
28
+ ):
29
+ """
30
+ Initialize VortexSSM.
31
+
32
+ Args:
33
+ d_model: Model dimension
34
+ d_state: State dimension (default 16 for 7B, 32 for 13B)
35
+ d_conv: Convolution kernel size for local context
36
+ expand: Expansion factor for inner dimension
37
+ dt_rank: Rank for delta projection (if None, uses ceil(d_model/16))
38
+ """
39
+ super().__init__()
40
+ self.d_model = d_model
41
+ self.d_state = d_state
42
+ self.d_conv = d_conv
43
+ self.expand = expand
44
+ self.d_inner = d_model * expand
45
+
46
+ if dt_rank is None:
47
+ self.dt_rank = max(1, d_model // 16)
48
+ else:
49
+ self.dt_rank = dt_rank
50
+
51
+ # Input projection: splits into x and z pathways
52
+ self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
53
+
54
+ # Convolution for local context before SSM
55
+ # Depthwise convolution for efficiency
56
+ self.conv1d = nn.Conv1d(
57
+ in_channels=self.d_inner,
58
+ out_channels=self.d_inner,
59
+ kernel_size=d_conv,
60
+ padding=d_conv - 1,
61
+ groups=self.d_inner,
62
+ bias=False,
63
+ )
64
+
65
+ # SSM parameter projections (input-dependent)
66
+ self.x_proj = nn.Linear(self.d_inner, self.dt_rank + 2 * self.d_state, bias=False)
67
+ self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)
68
+
69
+ # State matrices (A is log-scale for stability)
70
+ # A is (d_inner, d_state)
71
+ self.A_log = nn.Parameter(torch.randn(self.d_inner, self.d_state))
72
+ self.D = nn.Parameter(torch.randn(self.d_inner))
73
+
74
+ # Output projection
75
+ self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
76
+
77
+ # Initialize weights
78
+ self._initialize_weights()
79
+
80
+ def _initialize_weights(self):
81
+ """Initialize weights properly."""
82
+ # Initialize A_log with negative values for stable discretization
83
+ nn.init.normal_(self.A_log, mean=-4.0, std=0.5)
84
+ nn.init.normal_(self.D, mean=0.0, std=0.1)
85
+
86
+ # Initialize projections with small values
87
+ for module in [self.in_proj, self.x_proj, self.dt_proj, self.conv1d, self.out_proj]:
88
+ if hasattr(module, 'weight'):
89
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
90
+
91
+ def forward(
92
+ self,
93
+ x: torch.Tensor,
94
+ state: Optional[torch.Tensor] = None,
95
+ return_state: bool = False,
96
+ ) -> torch.Tensor:
97
+ """
98
+ Forward pass through the SSM.
99
+
100
+ Args:
101
+ x: Input tensor (batch, seq_len, d_model)
102
+ state: Previous hidden state (batch, d_inner, d_state)
103
+ return_state: If True, return (output, state)
104
+
105
+ Returns:
106
+ Output tensor (batch, seq_len, d_model) or tuple with state
107
+ """
108
+ batch, seq_len, _ = x.shape
109
+ device = x.device
110
+ dtype = x.dtype
111
+
112
+ # Double-check d_inner matches A_log shape
113
+ d_inner = self.d_inner
114
+
115
+ # Project input to inner dimension
116
+ xz = self.in_proj(x) # (batch, seq_len, 2 * d_inner)
117
+ x, z = xz.chunk(2, dim=-1)
118
+
119
+ # Apply 1D convolution for local context
120
+ # Need to transpose for conv1d: (batch, d_inner, seq_len)
121
+ x_conv = x.transpose(1, 2)
122
+ x_conv = self.conv1d(x_conv)[..., :seq_len] # Trim padding
123
+ x = x_conv.transpose(1, 2)
124
+
125
+ # Discretization: compute delta, A, B parameters
126
+ # x_proj produces: delta (dt_rank), B (d_state), C (d_state)
127
+ x_dbl = self.x_proj(x) # (batch, seq_len, dt_rank + 2*d_state)
128
+ (delta, B, C) = torch.split(
129
+ x_dbl,
130
+ [self.dt_rank, self.d_state, self.d_state],
131
+ dim=-1,
132
+ )
133
+
134
+ # Project delta
135
+ delta = self.dt_proj(delta) # (batch, seq_len, d_inner)
136
+ delta = F.softplus(delta)
137
+
138
+ # Compute discretized state recurrence
139
+ # Use scan operation for efficient sequential processing
140
+ if state is None:
141
+ state = torch.zeros(batch, d_inner, self.d_state, device=device, dtype=dtype)
142
+
143
+ # Sequential scan (can be optimized with CUDA kernel)
144
+ output = []
145
+ for t in range(seq_len):
146
+ x_t = x[:, t] # (batch, d_inner)
147
+ delta_t = delta[:, t] # (batch, d_inner)
148
+ B_t = B[:, t] # (batch, d_state)
149
+ C_t = C[:, t] # (batch, d_state)
150
+
151
+ # Discretize A
152
+ A_delta = torch.exp(self.A_log * delta_t.unsqueeze(-1)) # (batch, d_inner, d_state)
153
+
154
+ # State update: state = A_delta * state + B_t * x_t
155
+ # B_t needs to be (batch, d_state) -> (batch, d_inner, d_state) via broadcasting
156
+ state = A_delta * state + B_t.unsqueeze(1) * x_t.unsqueeze(-1)
157
+
158
+ # Output: y = C_t * state + D * x_t
159
+ y = (C_t.unsqueeze(1) * state).sum(dim=-1) + self.D * x_t
160
+ output.append(y)
161
+
162
+ output = torch.stack(output, dim=1) # (batch, seq_len, d_inner)
163
+
164
+ # Apply gating with z
165
+ output = output * F.silu(z)
166
+
167
+ # Project back to model dimension
168
+ output = self.out_proj(output)
169
+
170
+ if return_state:
171
+ return output, state
172
+ return output
173
+
174
+ def step(
175
+ self,
176
+ x: torch.Tensor,
177
+ state: torch.Tensor,
178
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
179
+ """
180
+ Single-step inference for autoregressive decoding.
181
+
182
+ Args:
183
+ x: Input at current step (batch, d_model)
184
+             state: Previous state (batch, d_inner, d_state)
+ 
+         Returns:
+             output: (batch, d_model)
+             new_state: updated state
+         """
+         batch, _ = x.shape
+ 
+         # Project input
+         xz = self.in_proj(x.unsqueeze(1))  # Add seq dim
+         x, z = xz.chunk(2, dim=-1)
+         x = x.squeeze(1)
+         z = z.squeeze(1)
+ 
+         # No convolution for single step (would need a conv cache)
+ 
+         # Compute parameters
+         x_dbl = self.x_proj(x.unsqueeze(1)).squeeze(1)
+         delta, B, C = torch.split(
+             x_dbl,
+             [self.dt_rank, self.d_state, self.d_state],
+             dim=-1,
+         )
+         delta = self.dt_proj(delta)
+         delta = F.softplus(delta)
+ 
+         # Single-step discretization. A = -exp(A_log) keeps the decay factor
+         # exp(delta * A) in (0, 1), so the recurrent state stays bounded.
+         A = -torch.exp(self.A_log)
+         A_delta = torch.exp(A * delta.unsqueeze(-1))
+         state = A_delta * state + B.unsqueeze(1) * x.unsqueeze(-1)
+         y = (C.unsqueeze(1) * state).sum(dim=-1) + self.D * x
+         y = y * F.silu(z)
+         output = self.out_proj(y)
+ 
+         return output, state
+ 
+ 
+ def test_vortex_ssm():
+     """Test the VortexSSM layer."""
+     batch_size = 2
+     seq_len = 128
+     d_model = 4096
+     d_state = 16
+ 
+     ssm = VortexSSM(d_model, d_state=d_state)
+     x = torch.randn(batch_size, seq_len, d_model)
+ 
+     # Forward pass
+     output = ssm(x)
+     print(f"Input shape: {x.shape}")
+     print(f"Output shape: {output.shape}")
+     assert output.shape == x.shape, f"Expected {x.shape}, got {output.shape}"
+ 
+     # Stateful forward
+     state = torch.zeros(batch_size, ssm.d_inner, d_state)
+     output2, new_state = ssm(x, state=state, return_state=True)
+     print(f"Stateful output shape: {output2.shape}")
+     print(f"State shape: {new_state.shape}")
+ 
+     # Single step
+     x_step = torch.randn(batch_size, d_model)
+     output_step, state_step = ssm.step(x_step, state)
+     print(f"Step output shape: {output_step.shape}")
+     print(f"Step state shape: {state_step.shape}")
+ 
+     print("VortexSSM test passed!")
+ 
+ 
+ if __name__ == "__main__":
+     test_vortex_ssm()
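The single-step update in `VortexSSM.step` reduces, per channel, to a simple discretized recurrence. A minimal scalar sketch (illustrative only, ignoring the input/output projections and SiLU gating; the function name and scalar parameters are my own):

```python
import math

def ssm_scalar_step(state, x, A, B, C, D, delta):
    # Discretized single-step SSM update for one scalar channel:
    #   state' = exp(delta * A) * state + B * x
    #   y      = C * state' + D * x
    # A < 0 keeps the decay factor in (0, 1), so the state stays bounded.
    a_bar = math.exp(delta * A)
    new_state = a_bar * state + B * x
    y = C * new_state + D * x
    return y, new_state

y, s = ssm_scalar_step(state=0.0, x=1.0, A=-1.0, B=1.0, C=1.0, D=0.5, delta=0.5)
# With a zero initial state: new_state = B * x = 1.0, y = C * 1.0 + D * 1.0 = 1.5
```

The same decay factor is what the tensorized code computes with `torch.exp`, broadcast over `(batch, d_inner, d_state)`.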
models/vortex_model.py ADDED
@@ -0,0 +1,377 @@
+ """
+ VortexModel: Main model class combining SSM, attention, science modules, and SciGate FFN.
+ Implements two block types: SSM-only and attention+science+SciGate FFN.
+ """
+ 
+ import torch
+ import torch.nn as nn
+ from typing import Optional, List, Dict, Union
+ 
+ from .ssm_layer import VortexSSM
+ from .attention_layer import VortexLocalAttention
+ from .scigate_ffn import SciGateFFN
+ from .science_modules import (
+     EquationModule,
+     NumericalReasoningModule,
+     CitationModule,
+     MolecularModule,
+ )
+ 
+ 
+ class VortexBlock(nn.Module):
+     """
+     Two types of blocks:
+     1. SSM block: only VortexSSM
+     2. Attention block: VortexLocalAttention + science modules + SciGateFFN
+     """
+ 
+     def __init__(
+         self,
+         config: Dict,
+         is_ssm_block: bool = True,
+     ):
+         """
+         Initialize a Vortex block.
+ 
+         Args:
+             config: Model configuration
+             is_ssm_block: If True, this is an SSM-only block; else attention+science+FFN
+         """
+         super().__init__()
+         self.config = config
+         self.is_ssm_block = is_ssm_block
+         self.d_model = config["d_model"]
+ 
+         if is_ssm_block:
+             # SSM-only block
+             self.ssm = VortexSSM(
+                 d_model=config["d_model"],
+                 d_state=config["d_state"],
+                 d_conv=config["d_conv"],
+             )
+             self.norm = nn.LayerNorm(config["d_model"])
+         else:
+             # Attention + science + FFN block
+             self.attn = VortexLocalAttention(
+                 d_model=config["d_model"],
+                 num_heads=config["num_heads"],
+                 window_size=config["window_size"],
+                 use_flash_attention=config.get("use_flash_attention", True),
+             )
+             self.attn_norm = nn.LayerNorm(config["d_model"])
+ 
+             # Science modules (enabled based on config flags)
+             self.equation_module = None
+             self.numerical_module = None
+             self.citation_module = None
+             self.molecular_module = None
+ 
+             if config.get("enable_equation_module", True):
+                 self.equation_module = EquationModule(config["d_model"])
+ 
+             if config.get("enable_numerical_module", True):
+                 self.numerical_module = NumericalReasoningModule(config["d_model"])
+ 
+             if config.get("enable_citation_module", True):
+                 self.citation_module = CitationModule(config["d_model"])
+ 
+             if config.get("enable_molecular_module", True):
+                 self.molecular_module = MolecularModule(config["d_model"])
+ 
+             # SciGate FFN
+             self.ffn = SciGateFFN(
+                 d_model=config["d_model"],
+                 expansion=config["ffn_expansion"],
+                 num_domains=config["num_domains"],
+             )
+             self.ffn_norm = nn.LayerNorm(config["d_model"])
+ 
+         # Final layer norm for both block types
+         self.final_norm = nn.LayerNorm(config["d_model"])
+ 
+     def forward(
+         self,
+         x: torch.Tensor,
+         domain_ids: Optional[torch.Tensor] = None,
+         domain_tags: Optional[torch.Tensor] = None,
+         text: Optional[List[str]] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         """
+         Forward pass through the block.
+ 
+         Args:
+             x: Input tensor (batch, seq_len, d_model)
+             domain_ids: Optional domain IDs for SciGate FFN
+             domain_tags: Optional domain tag masks
+             text: Optional original text for science module span detection
+             attention_mask: Optional attention mask
+ 
+         Returns:
+             Output tensor (batch, seq_len, d_model)
+         """
+         residual = x
+ 
+         if self.is_ssm_block:
+             # SSM-only pathway
+             x = self.norm(x)
+             x = self.ssm(x)
+             x = residual + x
+             x = self.final_norm(x)
+         else:
+             # Attention + science + FFN pathway
+             # Attention
+             residual_attn = x
+             x = self.attn_norm(x)
+             global_mask = self._detect_global_tokens(x)
+             x = self.attn(x, global_mask=global_mask, attention_mask=attention_mask)
+             x = residual_attn + x
+ 
+             # Science modules (applied sequentially as residual updates)
+             if self.equation_module is not None:
+                 x = x + self.equation_module(x, text=text)
+ 
+             if self.numerical_module is not None:
+                 x = x + self.numerical_module(x, text=text)
+ 
+             if self.citation_module is not None:
+                 x_cited, _ = self.citation_module(x, text=text)
+                 x = x + x_cited
+ 
+             if self.molecular_module is not None:
+                 x = x + self.molecular_module(x, text=text)
+ 
+             # SciGate FFN
+             residual_ffn = x
+             x = self.ffn_norm(x)
+             x = self.ffn(x, domain_ids=domain_ids, domain_tags=domain_tags)
+             x = residual_ffn + x
+ 
+             x = self.final_norm(x)
+ 
+         return x
+ 
+     def _detect_global_tokens(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Detect global tokens that should attend across the entire sequence.
+         Global tokens are those with special domain tags or high norm.
+         """
+         # Simple heuristic: tokens with large L2 norm are likely special
+         norms = torch.norm(x, dim=-1)  # (batch, seq_len)
+         threshold = torch.quantile(norms, 0.95, dim=-1, keepdim=True)
+         global_mask = norms > threshold
+ 
+         return global_mask
+ 
+ 
+ class VortexModel(nn.Module):
+     """
+     Main Vortex model combining SSM and attention blocks.
+     Supports both 7B and 13B configurations.
+     """
+ 
+     def __init__(
+         self,
+         config: Dict,
+     ):
+         """
+         Initialize VortexModel.
+ 
+         Args:
+             config: Model configuration (from vortex_7b_config.py or vortex_13b_config.py)
+         """
+         super().__init__()
+         self.config = config
+ 
+         # Token embedding
+         self.embed_tokens = nn.Embedding(config["vocab_size"], config["d_model"])
+ 
+         # Build blocks according to the layer ratio
+         self.blocks = nn.ModuleList()
+         self._build_blocks()
+ 
+         # Final layer norm
+         self.ln_f = nn.LayerNorm(config["d_model"])
+ 
+         # Output projection (weights will be tied by HuggingFace if config.tie_word_embeddings=True)
+         self.lm_head = nn.Linear(config["d_model"], config["vocab_size"], bias=False)
+ 
+         # Initialize weights
+         self._initialize_weights()
+ 
+     def _build_blocks(self):
+         """Build the sequence of SSM and attention blocks."""
+         num_layers = self.config["num_layers"]
+         ssm_ratio = self.config["ssm_ratio"]
+ 
+         # Calculate the number of each block type (for logging)
+         num_ssm_blocks = int(num_layers * ssm_ratio)
+         num_attn_blocks = num_layers - num_ssm_blocks
+ 
+         # Determine the block pattern (0 = SSM, 1 = attention)
+         if ssm_ratio == 0.6:  # 7B pattern: SSM, SSM, Attn, SSM, SSM, Attn, ...
+             pattern = [0, 0, 1]
+         else:  # 13B pattern: SSM, Attn, SSM, Attn, ...
+             pattern = [0, 1]
+ 
+         # Repeat the pattern and truncate to the exact layer count
+         blocks = []
+         while len(blocks) < num_layers:
+             blocks.extend(pattern[:num_layers - len(blocks)])
+         assert len(blocks) == num_layers
+ 
+         # Create blocks
+         for is_attn in blocks:
+             block = VortexBlock(
+                 config=self.config,
+                 is_ssm_block=not is_attn,
+             )
+             self.blocks.append(block)
+ 
+         print(f"Built {num_layers} layers: {num_ssm_blocks} SSM, {num_attn_blocks} Attention")
+ 
+     def _initialize_weights(self):
+         """Initialize weights."""
+         nn.init.normal_(self.embed_tokens.weight, mean=0.0, std=0.02)
+         for block in self.blocks:
+             if hasattr(block, 'ssm'):
+                 block.ssm._initialize_weights()
+             if hasattr(block, 'attn'):
+                 block.attn._initialize_weights()
+             if hasattr(block, 'ffn'):
+                 block.ffn._initialize_weights()
+ 
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         domain_ids: Optional[torch.Tensor] = None,
+         domain_tags: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         text: Optional[List[str]] = None,
+         return_dict: bool = True,
+     ) -> Union[Dict[str, torch.Tensor], torch.Tensor]:
+         """
+         Forward pass through the model.
+ 
+         Args:
+             input_ids: Token IDs (batch, seq_len)
+             domain_ids: Optional domain IDs
+             domain_tags: Optional domain tag masks
+             attention_mask: Optional attention mask (batch, seq_len)
+             text: Optional original text for science modules
+             return_dict: If True, return a dict with logits and last hidden state;
+                 otherwise return the logits tensor only
+ 
+         Returns:
+             Logits (batch, seq_len, vocab_size), wrapped in a dict if return_dict
+         """
+         # Embed tokens
+         x = self.embed_tokens(input_ids)
+ 
+         # Pass through blocks
+         for block in self.blocks:
+             x = block(
+                 x,
+                 domain_ids=domain_ids,
+                 domain_tags=domain_tags,
+                 text=text,
+                 attention_mask=attention_mask,
+             )
+ 
+         # Final norm
+         x = self.ln_f(x)
+ 
+         # Project to vocabulary
+         logits = self.lm_head(x)
+ 
+         if return_dict:
+             return {"logits": logits, "last_hidden_state": x}
+         return logits
+ 
+     def get_num_params(self) -> int:
+         """Get total number of parameters."""
+         return sum(p.numel() for p in self.parameters())
+ 
+     def get_trainable_params(self) -> int:
+         """Get number of trainable parameters."""
+         return sum(p.numel() for p in self.parameters() if p.requires_grad)
+ 
+     def estimate_memory_usage(
+         self,
+         batch_size: int,
+         seq_len: int,
+         use_gradient_checkpointing: bool = False,
+     ) -> Dict[str, float]:
+         """
+         Estimate memory usage for a given batch size and sequence length.
+         (use_gradient_checkpointing is accepted for API parity but not yet
+         factored into the estimate.)
+ 
+         Returns:
+             Dictionary with memory estimates in GB
+         """
+         params = self.get_num_params()
+         param_bytes = params * 2  # Assuming bfloat16 (2 bytes/param)
+ 
+         # Activation memory (rough estimate)
+         # Each layer: activations ~ batch * seq_len * d_model * 2 bytes
+         activations_per_layer = batch_size * seq_len * self.config["d_model"] * 2
+         total_activations = activations_per_layer * self.config["num_layers"]
+ 
+         # Gradients (same size as parameters)
+         gradients = param_bytes
+ 
+         # Optimizer states (AdamW: two extra states per parameter)
+         optimizer_states = params * 2 * 2
+ 
+         total_memory = (param_bytes + total_activations + gradients + optimizer_states) / 1e9
+ 
+         return {
+             "parameters_gb": param_bytes / 1e9,
+             "activations_gb": total_activations / 1e9,
+             "gradients_gb": gradients / 1e9,
+             "optimizer_states_gb": optimizer_states / 1e9,
+             "total_gb": total_memory,
+         }
+ 
+ 
+ def test_vortex_model():
+     """Test the VortexModel."""
+     from configs.vortex_7b_config import VORTEX_7B_CONFIG
+ 
+     config = VORTEX_7B_CONFIG.copy()
+     # Reduce size for testing
+     config["d_model"] = 512
+     config["num_layers"] = 4
+     config["num_heads"] = 8
+     config["vocab_size"] = 1000
+ 
+     model = VortexModel(config)
+ 
+     batch_size = 2
+     seq_len = 128
+     input_ids = torch.randint(0, config["vocab_size"], (batch_size, seq_len))
+ 
+     # Forward pass
+     output = model(input_ids)
+     logits = output["logits"]
+ 
+     print(f"Model parameters: {model.get_num_params():,}")
+     print(f"Input shape: {input_ids.shape}")
+     print(f"Logits shape: {logits.shape}")
+     assert logits.shape == (batch_size, seq_len, config["vocab_size"])
+ 
+     # Memory estimate
+     mem = model.estimate_memory_usage(batch_size, seq_len)
+     print(f"Memory estimate for batch={batch_size}, seq_len={seq_len}:")
+     for k, v in mem.items():
+         print(f"  {k}: {v:.2f} GB")
+ 
+     print("VortexModel test passed!")
+ 
+ 
+ if __name__ == "__main__":
+     test_vortex_model()
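The repeat-and-truncate interleaving in `_build_blocks` can be isolated as a small standalone helper (a sketch for illustration; the function name is mine, not part of the repo):

```python
def build_layer_pattern(num_layers, ssm_ratio):
    # 0 = SSM block, 1 = attention block.
    # Mirrors VortexModel._build_blocks: pick a base pattern, then repeat
    # it and truncate to exactly num_layers entries.
    pattern = [0, 0, 1] if ssm_ratio == 0.6 else [0, 1]
    blocks = []
    while len(blocks) < num_layers:
        blocks.extend(pattern[:num_layers - len(blocks)])
    return blocks

print(build_layer_pattern(8, 0.6))  # -> [0, 0, 1, 0, 0, 1, 0, 0]
print(build_layer_pattern(4, 0.5))  # -> [0, 1, 0, 1]
```

Note that truncation means the realized SSM/attention split can differ slightly from `int(num_layers * ssm_ratio)` when `num_layers` is not a multiple of the pattern length.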
mps_optimize.py ADDED
@@ -0,0 +1,172 @@
+ """
+ MPS optimizations for the Vortex model on Apple Silicon.
+ Uses the PyTorch MPS backend with MPS-compatible ops only.
+ """
+ 
+ import time
+ 
+ import torch
+ import torch.nn as nn
+ from typing import Dict
+ 
+ 
+ def optimize_for_mps(
+     model: nn.Module,
+     config: Dict,
+     use_sdpa: bool = True,
+ ) -> nn.Module:
+     """
+     Apply MPS optimizations to a model.
+ 
+     Args:
+         model: VortexModel
+         config: Model config
+         use_sdpa: Use PyTorch scaled dot product attention (MPS compatible)
+ 
+     Returns:
+         Optimized model
+     """
+     device = torch.device("mps")
+ 
+     # Move to MPS
+     model = model.to(device)
+ 
+     # Set dtype - MPS supports float32 and float16 (bfloat16 support is limited)
+     dtype_str = config.get("dtype", "bfloat16")
+     if dtype_str == "bfloat16":
+         # MPS has limited bfloat16 support, fall back to float16
+         dtype = torch.float16
+     else:
+         dtype = torch.float32
+ 
+     model = model.to(dtype)
+ 
+     # Replace Flash Attention with standard SDPA
+     if use_sdpa:
+         model = _apply_sdpa(model)
+         print("Applied PyTorch SDPA for MPS")
+ 
+     return model
+ 
+ 
+ def _apply_sdpa(model: nn.Module) -> nn.Module:
+     """
+     Replace custom attention with PyTorch SDPA.
+     SDPA is optimized for the MPS backend.
+     """
+     def sdpa_forward(self, x, *args, **kwargs):
+         return self._standard_attention(x, kwargs.get('attention_mask'))
+ 
+     for module in model.modules():
+         if hasattr(module, 'attn') and hasattr(module.attn, '_standard_attention'):
+             # Rebind forward on this attention instance to the SDPA path
+             module.attn.forward = sdpa_forward.__get__(module.attn, type(module.attn))
+ 
+     return model
+ 
+ 
+ def get_mps_memory_usage() -> Dict[str, float]:
+     """Get current MPS memory usage in GB."""
+     if not torch.backends.mps.is_available():
+         return {"error": "MPS not available"}
+ 
+     # MPS has no direct memory query; report process-level unified memory
+     import psutil
+     process = psutil.Process()
+     memory_info = process.memory_info()
+ 
+     return {
+         "rss_gb": memory_info.rss / 1e9,  # Resident set size
+         "vms_gb": memory_info.vms / 1e9,  # Virtual memory size
+     }
+ 
+ 
+ def profile_model_mps(
+     model: nn.Module,
+     input_ids: torch.Tensor,
+     num_warmup: int = 10,
+     num_runs: int = 50,
+ ) -> Dict[str, float]:
+     """
+     Profile model performance on MPS.
+ 
+     Args:
+         model: Model to profile
+         input_ids: Example input
+         num_warmup: Number of warmup runs
+         num_runs: Number of profiling runs
+ 
+     Returns:
+         Dictionary with timing statistics
+     """
+     model.eval()
+     device = next(model.parameters()).device
+     input_ids = input_ids.to(device)
+ 
+     # Warmup
+     with torch.no_grad():
+         for _ in range(num_warmup):
+             _ = model(input_ids)
+     # MPS is async, need to wait before starting the timer
+     if device.type == "mps":
+         torch.mps.synchronize()
+ 
+     # Profile
+     start = time.time()
+     with torch.no_grad():
+         for _ in range(num_runs):
+             _ = model(input_ids)
+     if device.type == "mps":
+         torch.mps.synchronize()
+     elapsed = time.time() - start
+ 
+     avg_time = elapsed / num_runs
+     # Tokens processed per forward = batch_size * seq_len
+     tokens_per_sec = input_ids.numel() / avg_time
+ 
+     return {
+         "avg_time_sec": avg_time,
+         "tokens_per_sec": tokens_per_sec,
+     }
+ 
+ 
+ def test_mps_optimize():
+     """Test MPS optimizations."""
+     if not torch.backends.mps.is_available():
+         print("MPS not available, skipping test")
+         return
+ 
+     from models.vortex_model import VortexModel
+     from configs.vortex_7b_config import VORTEX_7B_CONFIG
+ 
+     config = VORTEX_7B_CONFIG.copy()
+     config["d_model"] = 512
+     config["num_layers"] = 2
+     config["num_heads"] = 8
+     config["vocab_size"] = 1000
+ 
+     model = VortexModel(config)
+     print(f"Model parameters: {model.get_num_params():,}")
+ 
+     # Optimize for MPS
+     model = optimize_for_mps(model, config, use_sdpa=True)
+ 
+     # Test forward
+     batch_size = 2
+     seq_len = 128
+     input_ids = torch.randint(0, config["vocab_size"], (batch_size, seq_len)).to("mps")
+ 
+     with torch.no_grad():
+         output = model(input_ids)
+         logits = output["logits"]
+ 
+     print(f"Output shape: {logits.shape}")
+     print("MPS optimize test passed!")
+ 
+ 
+ if __name__ == "__main__":
+     test_mps_optimize()
push_to_hf.py ADDED
@@ -0,0 +1,39 @@
+ #!/usr/bin/env python3
+ import argparse
+ import os
+ 
+ from huggingface_hub import HfApi, create_repo
+ 
+ 
+ def push_to_hf(repo_id, token=None, private=False):
+     api = HfApi(token=token)
+ 
+     # Create repo if it doesn't exist
+     try:
+         create_repo(repo_id, repo_type="model", private=private, exist_ok=True)
+         print(f"Repository {repo_id} ready")
+     except Exception as e:
+         print(f"Repo creation note: {e}")
+ 
+     # Upload all files in the current directory
+     cwd = os.getcwd()
+     print(f"Uploading from {cwd} to {repo_id}...")
+ 
+     try:
+         api.upload_folder(
+             folder_path=cwd,
+             repo_id=repo_id,
+             repo_type="model",
+             commit_message="Upload Vortex model",
+         )
+         print(f"Successfully uploaded to {repo_id}")
+     except Exception as e:
+         print(f"Upload failed: {e}")
+         raise
+ 
+ 
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--repo_id", type=str, required=True, help="HuggingFace repo ID")
+     parser.add_argument("--token", type=str, help="HuggingFace token (optional if logged in)")
+     parser.add_argument("--private", action="store_true", help="Make repository private")
+     args = parser.parse_args()
+ 
+     push_to_hf(args.repo_id, args.token, args.private)
requirements.txt ADDED
@@ -0,0 +1,50 @@
+ # Core dependencies
+ torch>=2.2.0
+ transformers>=4.40.0
+ accelerate>=0.30.0
+ datasets>=2.18.0
+ tokenizers>=0.19.0
+ 
+ # Quantization
+ bitsandbytes>=0.43.0
+ 
+ # Flash Attention (CUDA only)
+ flash-attn>=2.5.0
+ 
+ # Scientific computing
+ numpy>=1.26.0
+ scipy>=1.12.0
+ scikit-learn>=1.4.0
+ 
+ # Chemistry/Biology
+ rdkit>=2023.9.0
+ pubchempy>=1.0.4
+ 
+ # Web scraping
+ arxiv>=2.1.0
+ beautifulsoup4>=4.12.0
+ requests>=2.31.0
+ 
+ # Data processing
+ pandas>=2.0.0
+ pyarrow>=14.0.0
+ 
+ # LaTeX parsing
+ pylatexenc>=2.10
+ 
+ # Deduplication
+ minhash>=0.1.0
+ 
+ # Utilities
+ tqdm>=4.65.0
+ psutil>=5.9.0
+ jsonlines>=3.1.0
+ 
+ # Optional: wandb for logging
+ # wandb>=0.16.0
+ 
+ # Development/testing
+ pytest>=7.0.0
+ black>=23.0.0
+ flake8>=6.0.0
+ mypy>=1.0.0
science_bench.py ADDED
@@ -0,0 +1,360 @@
+ """
+ Science benchmarks for the Vortex model.
+ Evaluates performance across 7 science domains.
+ """
+ 
+ import torch
+ from typing import Dict, List
+ from dataclasses import dataclass
+ 
+ 
+ @dataclass
+ class BenchmarkResult:
+     """Results from a benchmark."""
+     domain: str
+     accuracy: float
+     total_questions: int
+     correct_answers: int
+     details: List[Dict]
+ 
+ 
+ class ScienceBenchmark:
+     """
+     Base class for science benchmarks.
+     """
+ 
+     def __init__(self, name: str, domain: str):
+         self.name = name
+         self.domain = domain
+ 
+     def load_questions(self) -> List[Dict]:
+         """Load benchmark questions."""
+         raise NotImplementedError
+ 
+     def evaluate(
+         self,
+         model,
+         tokenizer,
+         device: torch.device,
+         max_samples: int = 100,
+     ) -> BenchmarkResult:
+         """
+         Evaluate the model on this benchmark.
+ 
+         Args:
+             model: Vortex model
+             tokenizer: Tokenizer
+             device: Torch device
+             max_samples: Maximum number of samples to evaluate
+ 
+         Returns:
+             BenchmarkResult
+         """
+         questions = self.load_questions()[:max_samples]
+         correct = 0
+ 
+         details = []
+         for q in questions:
+             # Format prompt
+             prompt = self.format_prompt(q)
+             # Tokenize
+             inputs = tokenizer(prompt, return_tensors="pt").to(device)
+ 
+             # Generate answer (greedy decoding)
+             with torch.no_grad():
+                 outputs = model.generate(
+                     **inputs,
+                     max_new_tokens=50,
+                     do_sample=False,
+                 )
+ 
+             # Decode
+             generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
+             answer = self.extract_answer(generated)
+ 
+             # Check correctness
+             is_correct = self.check_answer(answer, q["answer"])
+             if is_correct:
+                 correct += 1
+ 
+             details.append({
+                 "question": q["question"],
+                 "expected": q["answer"],
+                 "generated": answer,
+                 "correct": is_correct,
+             })
+ 
+         accuracy = correct / len(questions) if questions else 0.0
+         return BenchmarkResult(
+             domain=self.domain,
+             accuracy=accuracy,
+             total_questions=len(questions),
+             correct_answers=correct,
+             details=details,
+         )
+ 
+     def format_prompt(self, question: Dict) -> str:
+         """Format a question into a prompt."""
+         raise NotImplementedError
+ 
+     def extract_answer(self, text: str) -> str:
+         """Extract the answer from generated text."""
+         raise NotImplementedError
+ 
+     def check_answer(self, predicted: str, expected: str) -> bool:
+         """Check whether the predicted answer matches the expected one."""
+         raise NotImplementedError
+ 
+ 
+ class PhysicsBenchmark(ScienceBenchmark):
+     """Physics benchmark (Feynman Questions style)."""
+ 
+     def __init__(self):
+         super().__init__("physics_benchmark", "physics")
+ 
+     def load_questions(self) -> List[Dict]:
+         # Placeholder - would load from a dataset
+         return [
+             {
+                 "question": "What is the formula for kinetic energy?",
+                 "answer": "KE = 1/2 mv^2",
+                 "type": "formula",
+             },
+             {
+                 "question": "Explain Newton's first law of motion.",
+                 "answer": "An object at rest stays at rest unless acted upon by a force.",
+                 "type": "conceptual",
+             },
+         ]
+ 
+     def format_prompt(self, question: Dict) -> str:
+         return f"Question: {question['question']}\nAnswer:"
+ 
+     def extract_answer(self, text: str) -> str:
+         # Extract everything after "Answer:"
+         if "Answer:" in text:
+             return text.split("Answer:")[-1].strip()
+         return text.strip()
+ 
+     def check_answer(self, predicted: str, expected: str) -> bool:
+         # Simple containment match (a more robust matcher would be used in practice)
+         pred_lower = predicted.lower()
+         exp_lower = expected.lower()
+         return exp_lower in pred_lower or pred_lower in exp_lower
+ 
+ 
+ class MathBenchmark(ScienceBenchmark):
+     """Math benchmark (MATH dataset style)."""
+ 
+     def __init__(self):
+         super().__init__("math_benchmark", "math")
+ 
+     def load_questions(self) -> List[Dict]:
+         return [
+             {
+                 "question": "Solve for x: 2x + 5 = 15",
+                 "answer": "x = 5",
+                 "type": "algebra",
+             },
+             {
+                 "question": "What is the derivative of x^2?",
+                 "answer": "2x",
+                 "type": "calculus",
+             },
+         ]
+ 
+     def format_prompt(self, question: Dict) -> str:
+         return f"Problem: {question['question']}\nSolution:"
+ 
+     def extract_answer(self, text: str) -> str:
+         if "Solution:" in text:
+             return text.split("Solution:")[-1].strip()
+         return text.strip()
+ 
+     def check_answer(self, predicted: str, expected: str) -> bool:
+         # Normalize whitespace and case, then require an exact match
+         pred = " ".join(predicted.lower().split())
+         exp = " ".join(expected.lower().split())
+         return pred == exp
+ 
+ 
+ class ChemistryBenchmark(ScienceBenchmark):
+     """Chemistry benchmark."""
+ 
+     def __init__(self):
+         super().__init__("chemistry_benchmark", "chemistry")
+ 
+     def load_questions(self) -> List[Dict]:
+         return [
+             {
+                 "question": "What is the chemical formula for water?",
+                 "answer": "H2O",
+                 "type": "factual",
+             },
+             {
+                 "question": "How many protons does carbon have?",
+                 "answer": "6",
+                 "type": "factual",
+             },
+         ]
+ 
+     def format_prompt(self, question: Dict) -> str:
+         return f"Chemistry question: {question['question']}\nAnswer:"
+ 
+     def extract_answer(self, text: str) -> str:
+         if "Answer:" in text:
+             return text.split("Answer:")[-1].strip()
+         return text.strip()
+ 
+     def check_answer(self, predicted: str, expected: str) -> bool:
+         pred = predicted.strip().lower()
+         exp = expected.strip().lower()
+         return exp in pred
+ 
+ 
+ class BiologyBenchmark(ScienceBenchmark):
+     """Biology benchmark."""
+ 
+     def __init__(self):
+         super().__init__("biology_benchmark", "biology")
+ 
+     def load_questions(self) -> List[Dict]:
+         return [
+             {
+                 "question": "What is the powerhouse of the cell?",
+                 "answer": "mitochondria",
+                 "type": "factual",
+             },
+             {
+                 "question": "What molecule carries genetic information?",
+                 "answer": "DNA",
+                 "type": "factual",
+             },
+         ]
+ 
+     def format_prompt(self, question: Dict) -> str:
+         return f"Biology: {question['question']}\nAnswer:"
+ 
+     def extract_answer(self, text: str) -> str:
+         if "Answer:" in text:
+             return text.split("Answer:")[-1].strip()
+         return text.strip()
+ 
+     def check_answer(self, predicted: str, expected: str) -> bool:
+         pred = predicted.strip().lower()
+         exp = expected.strip().lower()
+         return exp in pred
+ 
+ 
+ # Placeholder benchmarks for the remaining domains
+ class EarthScienceBenchmark(ScienceBenchmark):
+     def __init__(self):
+         super().__init__("earth_science_benchmark", "earth")
+ 
+     def load_questions(self) -> List[Dict]:
+         return []
+ 
+     def format_prompt(self, question: Dict) -> str:
+         return f"Earth Science: {question['question']}\nAnswer:"
+ 
+     def extract_answer(self, text: str) -> str:
+         return text.strip()
+ 
+     def check_answer(self, predicted: str, expected: str) -> bool:
+         return predicted.strip().lower() == expected.strip().lower()
+ 
+ 
+ class SpaceScienceBenchmark(ScienceBenchmark):
+     def __init__(self):
+         super().__init__("space_science_benchmark", "space")
+ 
+     def load_questions(self) -> List[Dict]:
+         return []
+ 
+     def format_prompt(self, question: Dict) -> str:
+         return f"Space Science: {question['question']}\nAnswer:"
+ 
+     def extract_answer(self, text: str) -> str:
+         return text.strip()
+ 
+     def check_answer(self, predicted: str, expected: str) -> bool:
+         return predicted.strip().lower() == expected.strip().lower()
+ 
+ 
+ class ZoologyBenchmark(ScienceBenchmark):
+     def __init__(self):
+         super().__init__("zoology_benchmark", "zoology")
+ 
+     def load_questions(self) -> List[Dict]:
+         return []
+ 
+     def format_prompt(self, question: Dict) -> str:
+         return f"Zoology: {question['question']}\nAnswer:"
+ 
+     def extract_answer(self, text: str) -> str:
+         return text.strip()
+ 
+     def check_answer(self, predicted: str, expected: str) -> bool:
+         return predicted.strip().lower() == expected.strip().lower()
+ 
+ 
+ def run_all_benchmarks(
+     model,
+     tokenizer,
+     device: torch.device,
+     max_samples_per_domain: int = 50,
+ ) -> Dict[str, BenchmarkResult]:
+     """
+     Run all benchmarks and return results.
+ 
+     Args:
+         model: Vortex model
+         tokenizer: Tokenizer
+         device: Torch device
+         max_samples_per_domain: Max samples per domain
+ 
+     Returns:
+         Dictionary mapping domain to results
+     """
+     benchmarks = [
+         PhysicsBenchmark(),
+         MathBenchmark(),
+         ChemistryBenchmark(),
+         BiologyBenchmark(),
+         EarthScienceBenchmark(),
+         SpaceScienceBenchmark(),
+         ZoologyBenchmark(),
+     ]
+ 
+     results = {}
+     for bench in benchmarks:
+         print(f"Running {bench.name}...")
+         result = bench.evaluate(model, tokenizer, device, max_samples=max_samples_per_domain)
+         results[bench.domain] = result
+         print(f"  Accuracy: {result.accuracy:.2%} ({result.correct_answers}/{result.total_questions})")
+ 
+     return results
+ 
+ 
+ def print_summary(results: Dict[str, BenchmarkResult]):
+     """Print a summary of benchmark results."""
+     print("\n" + "=" * 60)
+     print("BENCHMARK RESULTS")
+     print("=" * 60)
+ 
+     for domain, result in results.items():
+         print(f"{domain:15} {result.accuracy:6.2%} ({result.correct_answers}/{result.total_questions})")
+ 
+     # Overall average
+     all_accuracies = [r.accuracy for r in results.values() if r.total_questions > 0]
+     if all_accuracies:
+         avg = sum(all_accuracies) / len(all_accuracies)
+         print(f"{'OVERALL':15} {avg:6.2%}")
+     print("=" * 60)
+ 
+ 
+ if __name__ == "__main__":
+     # Quick test
+     print("This script benchmarks the model across science domains.")
+     print("To run full benchmarks, integrate with a trained model.")
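The lenient grading several benchmarks use boils down to case-insensitive containment in either direction. A small standalone sketch of that check (the helper name is mine; the whitespace normalization mirrors MathBenchmark's):

```python
def containment_match(predicted, expected):
    # Normalize case and whitespace, then accept if either string
    # contains the other - the same lenient rule as
    # PhysicsBenchmark.check_answer.
    p = " ".join(predicted.lower().split())
    e = " ".join(expected.lower().split())
    return e in p or p in e

print(containment_match("The formula is  KE = 1/2 mv^2.", "ke = 1/2 mv^2"))  # -> True
print(containment_match("x = 4", "x = 5"))  # -> False
```

Containment grading is forgiving of verbose model output but can produce false positives (e.g. expected "6" matching any answer containing a 6), which is why the comments note a more robust matcher would be used in practice.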
test_model.py ADDED
@@ -0,0 +1,449 @@
+ #!/usr/bin/env python3
2
+ """
3
+ Comprehensive unit tests for Vortex model components.
4
+ Run with: python -m pytest test_model.py -v
5
+ """
6
+
7
+ import pytest
8
+ import torch
9
+ import sys
10
+ from pathlib import Path
11
+
12
+ # Add Vortex to path
13
+ sys.path.insert(0, str(Path(__file__).parent))
14
+
15
+
16
+ def test_tokenizer():
17
+ """Test VortexScienceTokenizer."""
18
+ from tokenizer.vortex_tokenizer import VortexScienceTokenizer
19
+ from configs.vortex_7b_config import VORTEX_7B_CONFIG
20
+
21
+ tokenizer = VortexScienceTokenizer(VORTEX_7B_CONFIG)
22
+
23
+ # Test encoding/decoding
24
+ text = "The equation is $E = mc^2$ and H2O is water."
25
+ encoded = tokenizer.encode(text, return_tensors="pt")
26
+ assert "input_ids" in encoded
27
+ assert encoded["input_ids"].shape[0] == 1 # batch dim
28
+
29
+ decoded = tokenizer.decode(encoded["input_ids"][0].tolist())
30
+ assert isinstance(decoded, str)
31
+ print("✓ Tokenizer test passed")
32
+
33
+
34
+ def test_ssm_layer():
35
+ """Test VortexSSM."""
36
+ from models.ssm_layer import VortexSSM
37
+
38
+ batch_size = 2
39
+ seq_len = 64
40
+ d_model = 512
41
+ d_state = 16
42
+
43
+ ssm = VortexSSM(d_model, d_state=d_state)
44
+ x = torch.randn(batch_size, seq_len, d_model)
45
+
46
+ # Forward pass
47
+ output = ssm(x)
48
+ assert output.shape == x.shape
49
+
50
+ # Stateful forward
51
+ state = torch.zeros(batch_size, ssm.d_inner, d_state)
52
+ output2, new_state = ssm(x, state=state, return_state=True)
53
+ assert output2.shape == x.shape
54
+ assert new_state.shape == (batch_size, ssm.d_inner, d_state)
55
+
56
+ # Single step
57
+ x_step = torch.randn(batch_size, d_model)
58
+ output_step, state_step = ssm.step(x_step, state)
59
+ assert output_step.shape == (batch_size, d_model)
60
+ assert state_step.shape == (batch_size, ssm.d_inner, d_state)
61
+
62
+ print("✓ SSM layer test passed")
63
+
64
+
65
+ def test_attention_layer():
66
+ """Test VortexLocalAttention."""
67
+ from models.attention_layer import VortexLocalAttention
68
+
69
+ batch_size = 2
70
+ seq_len = 128
71
+ d_model = 512
72
+ num_heads = 8
73
+
74
+ attn = VortexLocalAttention(d_model, num_heads, window_size=64, use_flash_attention=False)
75
+ x = torch.randn(batch_size, seq_len, d_model)
76
+
77
+ # Forward pass
78
+ output = attn(x)
79
+ assert output.shape == x.shape
80
+
81
+ # With global mask
82
+ global_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
83
+ global_mask[0, 0] = True
84
+ output2 = attn(x, global_mask=global_mask)
85
+ assert output2.shape == x.shape
86
+
87
+ print("✓ Local attention test passed")
88
+
89
+
90
+ def test_scigate_ffn():
91
+ """Test SciGateFFN."""
92
+ from models.scigate_ffn import SciGateFFN
93
+
94
+ batch_size = 2
95
+ seq_len = 64
96
+ d_model = 512
97
+ num_domains = 7
98
+
99
+ ffn = SciGateFFN(d_model, expansion=4, num_domains=num_domains)
100
+ x = torch.randn(batch_size, seq_len, d_model)
101
+
102
+ # Without domain info
103
+ output = ffn(x)
104
+ assert output.shape == x.shape
105
+
106
+ # With domain IDs
107
+ domain_ids = torch.randint(0, num_domains, (batch_size,))
108
+ output2 = ffn(x, domain_ids=domain_ids)
109
+ assert output2.shape == x.shape
110
+
111
+ # With domain tags
112
+ domain_tags = torch.zeros(batch_size, seq_len, num_domains)
113
+ domain_tags[:, :, 0] = 1.0
114
+ output3 = ffn(x, domain_tags=domain_tags)
115
+ assert output3.shape == x.shape
116
+
117
+ print("✓ SciGate FFN test passed")
118
+
119
+
120
+ def test_equation_module():
121
+ """Test EquationModule."""
122
+ from models.science_modules.equation_module import EquationModule
123
+
124
+ d_model = 512
125
+ batch_size = 2
126
+ seq_len = 64
127
+
128
+ module = EquationModule(d_model)
129
+ x = torch.randn(batch_size, seq_len, d_model)
130
+ text = ["E = mc^2 is famous.", "The integral $\\int x dx = x^2/2$."]
131
+
132
+ output = module(x, text=text)
133
+ assert output.shape == x.shape
134
+
135
+ # Test equation loss
136
+ equation_mask = torch.zeros(batch_size, seq_len)
137
+ equation_mask[0, 5:10] = 1.0
138
+ loss = module.compute_equation_loss(x, equation_mask)
139
+ assert loss.item() >= 0
140
+
141
+ print("✓ Equation module test passed")
142
+
143
+
144
+ def test_numerical_module():
145
+ """Test NumericalReasoningModule."""
146
+ from models.science_modules.numerical_module import NumericalReasoningModule
147
+
148
+ d_model = 512
149
+ batch_size = 2
150
+ seq_len = 64
151
+
152
+ module = NumericalReasoningModule(d_model)
153
+ x = torch.randn(batch_size, seq_len, d_model)
154
+ text = ["Speed of light: 2.998e8 m/s", "6.022e23 is Avogadro's number."]
155
+
156
+ output = module(x, text=text)
157
+ assert output.shape == x.shape
158
+
159
+ print("✓ Numerical reasoning module test passed")
160
+
161
+
162
+ def test_citation_module():
163
+ """Test CitationModule."""
164
+ from models.science_modules.citation_module import CitationModule
165
+
166
+ d_model = 512
167
+ batch_size = 2
168
+ seq_len = 64
169
+
170
+ module = CitationModule(d_model)
171
+ x = torch.randn(batch_size, seq_len, d_model)
172
+ text = ["(Einstein, 1905) changed physics.", "See also [1, 2] for details."]
173
+
174
+ output, confidence = module(x, text=text)
175
+ assert output.shape == x.shape
176
+ assert confidence.shape == (batch_size, seq_len, 1)
177
+
178
+ # Test loss
179
+ citation_mask = torch.zeros(batch_size, seq_len)
180
+ citation_mask[0, 0:5] = 1.0
181
+ loss = module.compute_citation_loss(x, citation_mask, confidence)
182
+ assert loss.item() >= 0
183
+
184
+ print("✓ Citation module test passed")
185
+
186
+
187
+ def test_molecular_module():
188
+ """Test MolecularModule."""
189
+ from models.science_modules.molecular_module import MolecularModule
190
+
191
+ d_model = 512
192
+ batch_size = 2
193
+ seq_len = 64
194
+
195
+ module = MolecularModule(d_model)
196
+ x = torch.randn(batch_size, seq_len, d_model)
197
+ text = ["H2O is water.", "DNA sequence: ACGTACGT"]
198
+
199
+ output = module(x, text=text)
200
+ assert output.shape == x.shape
201
+
202
+ print("✓ Molecular module test passed")
203
+
204
+
205
+ def test_vortex_model():
206
+ """Test full VortexModel."""
207
+ from models.vortex_model import VortexModel
208
+ from configs.vortex_7b_config import VORTEX_7B_CONFIG
209
+
210
+ # Small config for testing
211
+ config = VORTEX_7B_CONFIG.copy()
212
+ config["d_model"] = 256
213
+ config["num_layers"] = 4
214
+ config["num_heads"] = 4
215
+ config["vocab_size"] = 1000
216
+
217
+ model = VortexModel(config)
218
+
219
+ batch_size = 2
220
+ seq_len = 32
221
+ input_ids = torch.randint(0, config["vocab_size"], (batch_size, seq_len))
222
+
223
+ # Forward pass
224
+ output = model(input_ids)
225
+ logits = output["logits"]
226
+ assert logits.shape == (batch_size, seq_len, config["vocab_size"])
227
+
228
+ # Count parameters
229
+ num_params = model.get_num_params()
230
+ assert num_params > 0
231
+
232
+ print(f"✓ VortexModel test passed (params: {num_params:,})")
233
+
234
+
235
+ def test_quality_filter():
236
+ """Test ScienceQualityFilter."""
237
+ from data.quality_filter import ScienceQualityFilter
238
+
239
+ quality_filter = ScienceQualityFilter()
240
+
241
+ # Good text
242
+ good_text = """
243
+ The experiment collected data from 100 participants. Results show a
244
+ significant effect (p < 0.05). The equation E = mc^2 is fundamental.
245
+ According to Smith et al., this confirms the hypothesis.
246
+ """
247
+ assert quality_filter.filter(good_text)
248
+
249
+ # Bad: too short
250
+ assert not quality_filter.filter("Too short.")
251
+
252
+ # Bad: unmatched equations
253
+ bad_eq = "Equation $E = mc^2 and another $F = ma."
254
+ assert not quality_filter.filter(bad_eq)
255
+
256
+ print("✓ Quality filter test passed")
257
+
258
+
259
+ def test_domain_classifier():
260
+ """Test DomainClassifier."""
261
+ from data.domain_classifier import DomainClassifier
262
+
263
+ d_model = 256
264
+ classifier = DomainClassifier(d_model)
265
+
266
+ # Test with random hidden states
267
+ batch_size = 4
268
+ seq_len = 32
269
+ hidden = torch.randn(batch_size, seq_len, d_model)
270
+ logits = classifier(hidden)
271
+ assert logits.shape == (batch_size, 7)
272
+
273
+ # Test text classification
274
+ text = "Quantum mechanics describes particle behavior."
275
+ domain, conf = classifier.classify_text(text)
276
+ assert domain in range(7)
277
+ assert 0 <= conf <= 1
278
+
279
+ print("✓ Domain classifier test passed")
280
+
281
+
282
+ def test_deduplication():
283
+ """Test MinHashLSH."""
284
+ from data.deduplication import MinHashLSH
285
+
286
+ lsh = MinHashLSH(num_permutations=32, threshold=0.7, bands=4, rows_per_band=8)
287
+
288
+ docs = [
289
+ ("doc1", "The quick brown fox jumps over the lazy dog."),
290
+ ("doc2", "The quick brown fox jumps over the lazy dog!!!"),
291
+ ("doc3", "Completely different text about science."),
292
+ ]
293
+
294
+ for doc_id, text in docs:
295
+ lsh.add_document(doc_id, text)
296
+
297
+ # Query similar
298
+ results = lsh.query(docs[0][1])
299
+ # Should find doc2 as similar
300
+ assert len(results) >= 1
301
+ assert any(r[0] == "doc2" for r in results)
302
+
303
+ print("✓ Deduplication test passed")
304
+
305
+
306
+ def test_losses():
307
+ """Test VortexLoss."""
308
+ from training.losses import VortexLoss
309
+
310
+ config = {"loss_weights": {
311
+ "lm_loss": 1.0,
312
+ "equation_loss": 0.3,
313
+ "domain_loss": 0.1,
314
+ "citation_loss": 0.1,
315
+ "numerical_loss": 0.2,
316
+ }}
317
+
318
+ loss_fn = VortexLoss(config)
319
+
320
+ batch_size = 2
321
+ seq_len = 32
322
+ vocab_size = 1000
323
+
324
+ logits = torch.randn(batch_size, seq_len, vocab_size)
325
+ labels = torch.randint(0, vocab_size, (batch_size, seq_len))
326
+
327
+ losses = loss_fn(logits, labels)
328
+ assert "total_loss" in losses
329
+ assert "lm_loss" in losses
330
+ assert losses["total_loss"].item() > 0
331
+
332
+ print("✓ Losses test passed")
333
+
334
+
335
+ def test_curriculum():
336
+ """Test CurriculumScheduler."""
337
+ from training.curriculum import CurriculumScheduler
338
+
339
+ config = {
340
+ "curriculum_stages": [
341
+ {"name": "foundation", "start": 0.0, "end": 0.2},
342
+ {"name": "domain", "start": 0.2, "end": 0.5},
343
+ {"name": "reasoning", "start": 0.5, "end": 0.8},
344
+ {"name": "integration", "start": 0.8, "end": 1.0},
345
+ ]
346
+ }
347
+
348
+ total_steps = 1000
349
+ scheduler = CurriculumScheduler(config, total_steps)
350
+
351
+ # Test stage at different steps
352
+ assert scheduler.get_stage_name(0) == "foundation"
353
+ assert scheduler.get_stage_name(250) == "domain"
354
+ assert scheduler.get_stage_name(500) == "reasoning"
355
+ assert scheduler.get_stage_name(800) == "integration"
356
+
357
+ # Test sampler
358
+ weights = scheduler.get_dataset_sampler(100)
359
+ assert isinstance(weights, dict)
360
+ assert abs(sum(weights.values()) - 1.0) < 1e-6  # tolerate float rounding
361
+
362
+ print("✓ Curriculum test passed")
363
+
364
+
365
+ def test_hf_integration():
366
+ """Test HuggingFace integration."""
367
+ from configuration_vortex import VortexConfig
368
+ from modeling_vortex import VortexForCausalLM
369
+ from tokenization_vortex import VortexTokenizer
370
+
371
+ # Config
372
+ config = VortexConfig(
373
+ d_model=128,
374
+ num_layers=2,
375
+ num_heads=4,
376
+ vocab_size=100,
377
+ )
378
+
379
+ # Model
380
+ model = VortexForCausalLM(config)
381
+ batch_size = 2
382
+ seq_len = 16
383
+ input_ids = torch.randint(0, 100, (batch_size, seq_len))
384
+
385
+ outputs = model(input_ids)
386
+ assert outputs.logits.shape == (batch_size, seq_len, 100)
387
+
388
+ # Save and load
389
+ model.save_pretrained("./test_hf_model")
390
+ config.save_pretrained("./test_hf_model")
391
+
392
+ from transformers import AutoConfig, AutoModelForCausalLM
393
+ loaded_config = AutoConfig.from_pretrained("./test_hf_model")
394
+ loaded_model = AutoModelForCausalLM.from_pretrained("./test_hf_model")
395
+
396
+ assert loaded_config.model_type == "vortex"
397
+ assert isinstance(loaded_model, VortexForCausalLM)
398
+
399
+ # Cleanup
400
+ import shutil
401
+ shutil.rmtree("./test_hf_model")
402
+
403
+ print("✓ HuggingFace integration test passed")
404
+
405
+
406
+ def run_all_tests():
407
+ """Run all tests."""
408
+ tests = [
409
+ test_tokenizer,
410
+ test_ssm_layer,
411
+ test_attention_layer,
412
+ test_scigate_ffn,
413
+ test_equation_module,
414
+ test_numerical_module,
415
+ test_citation_module,
416
+ test_molecular_module,
417
+ test_vortex_model,
418
+ test_quality_filter,
419
+ test_domain_classifier,
420
+ test_deduplication,
421
+ test_losses,
422
+ test_curriculum,
423
+ test_hf_integration,
424
+ ]
425
+
426
+ print("Running Vortex unit tests...\n")
427
+ passed = 0
428
+ failed = 0
429
+
430
+ for test in tests:
431
+ try:
432
+ test()
433
+ passed += 1
434
+ except Exception as e:
435
+ print(f"✗ {test.__name__} failed: {e}")
436
+ failed += 1
437
+ import traceback
438
+ traceback.print_exc()
439
+
440
+ print(f"\n{'='*50}")
441
+ print(f"Tests: {passed + failed} total, {passed} passed, {failed} failed")
442
+ print(f"{'='*50}")
443
+
444
+ return failed == 0
445
+
446
+
447
+ if __name__ == "__main__":
448
+ success = run_all_tests()
449
+ sys.exit(0 if success else 1)
tokenization_vortex.py ADDED
@@ -0,0 +1,174 @@
1
+ """
2
+ Vortex tokenizer for HuggingFace.
3
+ Wraps VortexScienceTokenizer for HF compatibility.
4
+ """
5
+
6
+ from typing import List, Optional, Dict, Any
7
+ import json
8
+ import os
9
+
10
+
11
+ class VortexTokenizer:
12
+ """
13
+ HuggingFace-compatible tokenizer for Vortex.
14
+ Wraps VortexScienceTokenizer.
15
+ """
16
+
17
+ def __init__(
18
+ self,
19
+ tokenizer_file: Optional[str] = None,
20
+ config: Optional[Dict] = None,
21
+ **kwargs,
22
+ ):
23
+ """
24
+ Initialize tokenizer.
25
+
26
+ Args:
27
+ tokenizer_file: Path to tokenizer JSON
28
+ config: Tokenizer configuration
29
+ """
30
+ from .tokenizer.vortex_tokenizer import VortexScienceTokenizer
31
+
32
+ self.config = config or {}
33
+ self.special_tokens = self.config.get("special_tokens", {})
34
+
35
+ if tokenizer_file and os.path.exists(tokenizer_file):
36
+ self.tokenizer = VortexScienceTokenizer(
37
+ self.config,
38
+ tokenizer_path=tokenizer_file,
39
+ )
40
+ else:
41
+ # Initialize empty - needs training
42
+ self.tokenizer = VortexScienceTokenizer(self.config)
43
+
44
+ # HF compatibility attributes
45
+ self.pad_token = "[PAD]"
46
+ self.unk_token = "[UNK]"
47
+ self.bos_token = "[BOS]"
48
+ self.eos_token = "[EOS]"
49
+ self.pad_token_id = self.special_tokens.get("[PAD]", 0)
50
+ self.unk_token_id = self.special_tokens.get("[UNK]", 1)
51
+ self.bos_token_id = self.special_tokens.get("[BOS]", 2)
52
+ self.eos_token_id = self.special_tokens.get("[EOS]", 3)
53
+
54
+ @classmethod
55
+ def from_pretrained(
56
+ cls,
57
+ pretrained_model_name_or_path: str,
58
+ **kwargs,
59
+ ):
60
+ """Load tokenizer from pretrained model."""
61
+ tokenizer_path = os.path.join(pretrained_model_name_or_path, "vortex_tokenizer.json")
62
+ config_path = os.path.join(pretrained_model_name_or_path, "tokenizer_config.json")
63
+
64
+ config = {}
65
+ if os.path.exists(config_path):
66
+ with open(config_path, "r") as f:
67
+ config = json.load(f)
68
+
69
+ return cls(tokenizer_file=tokenizer_path, config=config, **kwargs)
70
+
71
+ def __call__(
72
+ self,
73
+ text: str | List[str],
74
+ padding: bool = False,
75
+ truncation: bool = False,
76
+ max_length: Optional[int] = None,
77
+ return_tensors: Optional[str] = "pt",
78
+ **kwargs,
79
+ ) -> Dict[str, Any]:
80
+ """
81
+ Tokenize text.
82
+
83
+ Args:
84
+ text: Input text or list of texts
85
+ padding: Pad to same length
86
+ truncation: Truncate to max_length
87
+ max_length: Maximum length
88
+ return_tensors: "pt" for PyTorch, "np" for numpy, None for list
89
+
90
+ Returns:
91
+ Dictionary with input_ids, attention_mask
92
+ """
93
+ if isinstance(text, str):
94
+ text = [text]
95
+
96
+ if max_length is None:
97
+ max_length = self.config.get("max_seq_len", 16384)
98
+
99
+ # Use batch_encode
100
+ result = self.tokenizer.batch_encode(
101
+ text,
102
+ padding=padding,
103
+ truncation=truncation,
104
+ max_length=max_length,
105
+ return_tensors=return_tensors,
106
+ )
107
+
108
+ return result
109
+
110
+ def encode(
111
+ self,
112
+ text: str,
113
+ add_special_tokens: bool = True,
114
+ **kwargs,
115
+ ) -> List[int]:
116
+ """Encode text to token IDs."""
117
+ result = self.tokenizer.encode(
118
+ text,
119
+ add_special_tokens=add_special_tokens,
120
+ return_tensors=None,
121
+ )
122
+ return result["input_ids"]
123
+
124
+ def decode(
125
+ self,
126
+ token_ids: List[int],
127
+ skip_special_tokens: bool = True,
128
+ **kwargs,
129
+ ) -> str:
130
+ """Decode token IDs to text."""
131
+ return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
132
+
133
+ def save_pretrained(self, save_directory: str):
134
+ """Save tokenizer to directory."""
135
+ os.makedirs(save_directory, exist_ok=True)
136
+ tokenizer_path = os.path.join(save_directory, "vortex_tokenizer.json")
137
+ self.tokenizer.save(tokenizer_path)
138
+
139
+ # Save tokenizer config
140
+ config_path = os.path.join(save_directory, "tokenizer_config.json")
141
+ with open(config_path, "w") as f:
142
+ json.dump({
143
+ "model_type": "vortex",
144
+ "special_tokens": self.special_tokens,
145
+ }, f, indent=2)
146
+
147
+ @property
148
+ def vocab_size(self) -> int:
149
+ """Get vocabulary size."""
150
+ return self.tokenizer.vocab_size
151
+
152
+ def get_vocab(self) -> Dict[str, int]:
153
+ """Get vocabulary dictionary."""
154
+ return self.tokenizer.get_vocab()
155
+
156
+
157
+ def test_vortex_tokenizer():
158
+ """Test VortexTokenizer."""
159
+ from configs.vortex_7b_config import VORTEX_7B_CONFIG
160
+
161
+ tokenizer = VortexTokenizer(config=VORTEX_7B_CONFIG)
162
+
163
+ text = "The equation is $E = mc^2$ and the reaction is H2O."
164
+ encoded = tokenizer(text, padding=False, truncation=True, max_length=128)
165
+ print(f"Encoded: {encoded['input_ids'][0][:10]}...")
166
+
167
+ decoded = tokenizer.decode(encoded["input_ids"][0])
168
+ print(f"Decoded: {decoded[:50]}...")
169
+
170
+ print("VortexTokenizer test passed!")
171
+
172
+
173
+ if __name__ == "__main__":
174
+ test_vortex_tokenizer()
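The `save_pretrained()`/`from_pretrained()` pair above round-trips a small `tokenizer_config.json` alongside the tokenizer file. A stdlib-only sketch of that round-trip, using an assumed special-token mapping and a temporary directory in place of a real model repo:

```python
# Sketch of the tokenizer_config.json round-trip in save_pretrained()/
# from_pretrained(). The special-token ids here are illustrative defaults.
import json
import os
import tempfile

special_tokens = {"[PAD]": 0, "[UNK]": 1, "[BOS]": 2, "[EOS]": 3}

with tempfile.TemporaryDirectory() as save_directory:
    # save_pretrained(): write the config next to the tokenizer file
    config_path = os.path.join(save_directory, "tokenizer_config.json")
    with open(config_path, "w") as f:
        json.dump({"model_type": "vortex", "special_tokens": special_tokens}, f, indent=2)

    # from_pretrained(): read the config back only if it exists
    loaded = {}
    if os.path.exists(config_path):
        with open(config_path) as f:
            loaded = json.load(f)

print(loaded["model_type"], loaded["special_tokens"]["[PAD]"])
```

Guarding the read with `os.path.exists` mirrors `from_pretrained()`, which falls back to an empty config when the file is absent.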
tokenizer/__pycache__/vortex_tokenizer.cpython-313.pyc ADDED
Binary file (17.4 kB).
 
tokenizer/vortex_tokenizer.py ADDED
@@ -0,0 +1,442 @@
1
+ """
2
+ VortexScienceTokenizer: A custom BPE tokenizer optimized for scientific text.
3
+ Trains on science corpus and extends vocabulary with domain-specific tokens.
4
+ """
5
+
6
+ import os
7
+ import json
8
+ import re
9
+ from pathlib import Path
10
+ from typing import List, Dict, Optional, Tuple, Union
11
+ import torch
12
+
13
+ try:
14
+ from tokenizers import Tokenizer, models, pre_tokenizers, processors, trainers
15
+ from tokenizers.normalizers import Lowercase, NFD, StripAccents
16
+ except ImportError:
17
+ print("Please install tokenizers: pip install tokenizers")
18
+ raise
19
+
20
+
21
+ class VortexScienceTokenizer:
22
+ """
23
+ Science-optimized BPE tokenizer with domain extensions.
24
+
25
+ Features:
26
+ - Base BPE vocabulary (40,000 tokens) trained on scientific corpus
27
+ - Extended science vocabulary (10,000 tokens) for LaTeX, chemistry, units, etc.
28
+ - Special tokens for equation/citation/molecule spans
29
+ - Domain tags for science areas
30
+ - Digit-level number handling (optional, can be toggled)
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ config: Dict,
36
+ tokenizer_path: Optional[str] = None,
37
+ vocab_size: int = 50000,
38
+ base_vocab_size: int = 40000,
39
+ extension_vocab_size: int = 10000,
40
+ ):
41
+ """
42
+ Initialize the tokenizer.
43
+
44
+ Args:
45
+ config: Model configuration with special tokens
46
+ tokenizer_path: Path to pre-trained tokenizer (if loading)
47
+ vocab_size: Total vocabulary size
48
+ base_vocab_size: Size of base BPE vocabulary
49
+ extension_vocab_size: Size of science extension vocabulary
50
+ """
51
+ self.config = config
52
+ self.base_vocab_size = base_vocab_size
53
+ self.extension_vocab_size = extension_vocab_size
54
+ self._vocab_size = vocab_size
55
+
56
+ self.special_tokens = config.get("special_tokens", {})
57
+ self.domain_tags = config.get("domain_tags", [])
58
+
59
+ if tokenizer_path and os.path.exists(tokenizer_path):
60
+ self.tokenizer = Tokenizer.from_file(tokenizer_path)
61
+ print(f"Loaded tokenizer from {tokenizer_path}")
62
+ else:
63
+ # Initialize empty BPE tokenizer
64
+ self.tokenizer = Tokenizer(models.BPE())
65
+ self._setup_pre_tokenizer()
66
+ print("Initialized empty BPE tokenizer")
67
+
68
+ def _setup_pre_tokenizer(self):
69
+ """Configure pre-tokenization rules."""
70
+ # Use byte-level pre-tokenization for robustness
71
+ self.tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
72
+ self.tokenizer.normalizer = None # Keep original casing for science terms
73
+
74
+ def train(
75
+ self,
76
+ file_paths: List[str],
77
+ min_frequency: int = 2,
78
+ special_tokens: Optional[List[str]] = None,
79
+ ):
80
+ """
81
+ Train the BPE tokenizer on scientific text files.
82
+
83
+ Args:
84
+ file_paths: List of text file paths for training
85
+ min_frequency: Minimum token frequency to keep
86
+ special_tokens: Additional special tokens to add
87
+ """
88
+ if special_tokens is None:
89
+ special_tokens = list(self.special_tokens.keys()) + self.domain_tags
90
+
91
+ print(f"Training tokenizer on {len(file_paths)} files...")
92
+ print(f"Base vocab size: {self.base_vocab_size}")
93
+ print(f"Special tokens: {special_tokens}")
94
+
95
+ trainer = trainers.BpeTrainer(
96
+ vocab_size=self.base_vocab_size,
97
+ min_frequency=min_frequency,
98
+ special_tokens=special_tokens,
99
+ show_progress=True,
100
+ )
101
+
102
+ self.tokenizer.train(file_paths, trainer=trainer)
103
+ print(f"Training complete. Vocabulary size: {self.tokenizer.get_vocab_size()}")
104
+
105
+ # Extend with science-specific tokens
106
+ self._extend_science_vocabulary()
107
+
108
+ def _extend_science_vocabulary(self):
109
+ """Add science-specific tokens to the vocabulary."""
110
+ current_vocab = self.tokenizer.get_vocab()
111
+ new_tokens = []
112
+
113
+ # LaTeX math symbols (common ones)
114
+ latex_symbols = [
115
+ "\\alpha", "\\beta", "\\gamma", "\\delta", "\\epsilon", "\\zeta",
116
+ "\\eta", "\\theta", "\\iota", "\\kappa", "\\lambda", "\\mu",
117
+ "\\nu", "\\xi", "\\pi", "\\rho", "\\sigma", "\\tau",
118
+ "\\upsilon", "\\phi", "\\chi", "\\psi", "\\omega",
119
+ "\\Gamma", "\\Delta", "\\Theta", "\\Lambda", "\\Xi", "\\Pi",
120
+ "\\Sigma", "\\Phi", "\\Psi", "\\Omega",
121
+ "\\sum", "\\prod", "\\int", "\\partial", "\\nabla", "\\infty",
122
+ "\\leq", "\\geq", "\\neq", "\\approx", "\\equiv", "\\sim",
123
+ "\\in", "\\notin", "\\subset", "\\supset", "\\cup", "\\cap",
124
+ "\\forall", "\\exists", "\\neg", "\\land", "\\lor", "\\rightarrow",
125
+ "\\leftarrow", "\\Rightarrow", "\\Leftarrow", "\\leftrightarrow",
126
+ "\\frac", "\\sqrt", "\\binom", "\\begin", "\\end", "\\mathbf",
127
+ "\\mathcal", "\\mathrm", "\\mathbb", "\\mathfrak",
128
+ ]
129
+ new_tokens.extend(latex_symbols)
130
+
131
+ # Greek letters (Unicode)
132
+ greek_letters = [
133
+ "α", "β", "γ", "δ", "ε", "ζ", "η", "θ", "ι", "κ", "λ", "μ",
134
+ "ν", "ξ", "ο", "π", "ρ", "σ", "τ", "υ", "φ", "χ", "ψ", "ω",
135
+ "Γ", "Δ", "Θ", "Λ", "Ξ", "Π", "Σ", "Φ", "Ψ", "Ω",
136
+ ]
137
+ new_tokens.extend(greek_letters)
138
+
139
+ # SI units and derived units
140
+ si_units = [
141
+ "m", "kg", "s", "mol", "K", "A", "cd", "mol",
142
+ "Hz", "N", "Pa", "J", "W", "C", "V", "F", "Ω", "S",
143
+ "Wb", "T", "H", "lm", "lx", "Bq", "Gy", "Sv", "kat",
144
+ "eV", "u", "Da", "Å", "°C", "%", "‰",
145
+ "M", "mM", "μM", "nM", "pM",
146
+ "g", "mg", "μg", "ng", "pg",
147
+ "km", "m", "cm", "mm", "μm", "nm", "pm",
148
+ "L", "mL", "μL", "nL",
149
+ "h", "min", "s", "ms", "μs", "ns",
150
+ ]
151
+ new_tokens.extend(si_units)
152
+
153
+ # Common scientific abbreviations
154
+ sci_abbrevs = [
155
+ "DNA", "RNA", "mRNA", "tRNA", "rRNA", "cDNA", "gDNA",
156
+ "ATP", "ADP", "AMP", "NAD", "NADP", "FAD", "CoA",
157
+ "pH", "pKa", "pKb", "pI",
158
+ "PCR", "RT", "qPCR", "NGS", "WGS",
159
+ "IC50", "EC50", "KD", "Ki",
160
+ "XRD", "NMR", "IR", "UV", "VIS", "MS", "GC", "HPLC",
161
+ "SEM", "TEM", "AFM", "STM",
162
+ "S/N", "SNR", "RMS", "Std", "Var", "Cov",
163
+ "et al.", "vs.", "cf.", "viz.",
164
+ "Fig", "Eq", "Ref", "Tab", "Suppl",
165
+ ]
166
+ new_tokens.extend(sci_abbrevs)
167
+
168
+ # Chemical element symbols
169
+ elements = [
170
+ "H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne",
171
+ "Na", "Mg", "Al", "Si", "P", "S", "Cl", "Ar",
172
+ "K", "Ca", "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn",
173
+ "Ga", "Ge", "As", "Se", "Br", "Kr",
174
+ "Rb", "Sr", "Y", "Zr", "Nb", "Mo", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd",
175
+ "In", "Sn", "Sb", "Te", "I", "Xe",
176
+ "Cs", "Ba", "La", "Ce", "Pr", "Nd", "Pm", "Sm", "Eu", "Gd", "Tb",
177
+ "Dy", "Ho", "Er", "Tm", "Yb", "Lu",
178
+ "Hf", "Ta", "W", "Re", "Os", "Ir", "Pt", "Au", "Hg", "Tl", "Pb",
179
+ "Bi", "Po", "At", "Rn",
180
+ "Fr", "Ra", "Ac", "Th", "Pa", "U", "Np", "Pu", "Am", "Cm", "Bk",
181
+ "Cf", "Es", "Fm", "Md", "No", "Lr",
182
+ "Rf", "Db", "Sg", "Bh", "Hs", "Mt", "Ds", "Rg", "Cn", "Nh",
183
+ "Fl", "Mc", "Lv", "Ts", "Og",
184
+ ]
185
+ new_tokens.extend(elements)
186
+
187
+ # Amino acid single-letter codes
188
+ amino_acids = ["A", "R", "N", "D", "C", "Q", "E", "G", "H", "I",
189
+ "L", "K", "M", "F", "P", "S", "T", "W", "Y", "V"]
190
+ new_tokens.extend(amino_acids)
191
+
192
+ # Mathematical operators (Unicode)
193
+ math_ops = [
194
+ "±", "∓", "×", "÷", "∈", "∉", "∋", "∏", "∑", "∧", "∨", "¬",
195
+ "≤", "≥", "≠", "≈", "≡", "≅", "≆", "≇", "≉", "≊", "≋",
196
+ "⊂", "⊃", "⊆", "⊇", "⊄", "⊅", "⊈", "⊉",
197
+ "∞", "∂", "∇", "√", "∛", "∜",
198
+ "∫", "∬", "∭", "∮", "∯", "∰",
199
+ "∴", "∵", "∶", "∷", "∼", "∽", "≈", "≋",
200
+ "⟨", "⟩", "|", "‖", "‵", "′", "″", "‴",
201
+ "•", "·", "‣", "⁂", "※", "‼", "⁇", "⁈",
202
+ ]
203
+ new_tokens.extend(math_ops)
204
+
205
+ # Add tokens that aren't already in vocabulary
206
+ for token in new_tokens:
207
+ if token not in current_vocab:
208
+ self.tokenizer.add_tokens([token])
209
+
210
+ print(f"Extended vocabulary with {len(new_tokens)} science tokens")
211
+ print(f"Final vocabulary size: {self.tokenizer.get_vocab_size()}")
212
+
213
+ def save(self, path: str):
214
+ """Save tokenizer to disk."""
215
+ self.tokenizer.save(path)
216
+ print(f"Tokenizer saved to {path}")
217
+
218
+ def encode(
219
+ self,
220
+ text: str,
221
+ add_special_tokens: bool = True,
222
+ return_tensors: str = "pt",
223
+ ) -> Union[Dict, torch.Tensor]:
224
+ """
225
+ Encode text to token IDs.
226
+
227
+ Args:
228
+ text: Input text
229
+ add_special_tokens: Add BOS/EOS tokens
230
+ return_tensors: "pt" for PyTorch tensors, "np" for numpy, None for list
231
+
232
+ Returns:
233
+ Dictionary with input_ids and attention_mask, or tensors/list
234
+ """
235
+ encoding = self.tokenizer.encode(text, add_special_tokens=add_special_tokens)
236
+
237
+ result = {
238
+ "input_ids": encoding.ids,
239
+ "attention_mask": encoding.attention_mask,
240
+ }
241
+
242
+ if return_tensors == "pt":
243
+ result = {k: torch.tensor(v).unsqueeze(0) for k, v in result.items()}
244
+ elif return_tensors == "np":
245
+ import numpy as np
246
+ result = {k: np.array(v) for k, v in result.items()}
247
+
248
+ return result
249
+
250
+ def decode(self, token_ids: List[int], skip_special_tokens: bool = True) -> str:
251
+ """Decode token IDs back to text."""
252
+ return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
253
+
254
+ def batch_encode(
255
+ self,
256
+ texts: List[str],
257
+ padding: bool = True,
258
+ truncation: bool = True,
259
+ max_length: Optional[int] = None,
260
+ return_tensors: str = "pt",
261
+ ) -> Dict:
262
+ """
263
+ Encode a batch of texts.
264
+
265
+ Args:
266
+ texts: List of input texts
267
+ padding: Pad to same length
268
+ truncation: Truncate to max_length
269
+ max_length: Maximum sequence length
270
+ return_tensors: Tensor format
271
+
272
+ Returns:
273
+ Batch encoded dictionary
274
+ """
275
+ if max_length is None:
276
+ max_length = self.config.get("max_seq_len", 16384)
277
+
278
+ encodings = self.tokenizer.encode_batch(
279
+ texts,
280
+ add_special_tokens=True,
281
+ )
282
+
283
+ # Manual padding/truncation
284
+ input_ids = []
285
+ attention_masks = []
286
+
287
+ for enc in encodings:
288
+ ids = enc.ids
289
+ mask = enc.attention_mask
290
+
291
+ if truncation and len(ids) > max_length:
292
+ ids = ids[:max_length]
293
+ mask = mask[:max_length]
294
+
295
+ input_ids.append(ids)
296
+ attention_masks.append(mask)
297
+
298
+ # Pad to same length if requested
299
+ if padding:
300
+ max_len = max(len(ids) for ids in input_ids)
301
+ padded_ids = []
302
+ padded_masks = []
303
+
304
+ for ids, mask in zip(input_ids, attention_masks):
305
+ pad_len = max_len - len(ids)
306
+ padded_ids.append(ids + [self.special_tokens["[PAD]"]] * pad_len)
307
+ padded_masks.append(mask + [0] * pad_len)
308
+
309
+ input_ids = padded_ids
310
+ attention_masks = padded_masks
311
+
312
+ result = {
313
+ "input_ids": input_ids,
314
+ "attention_mask": attention_masks,
315
+ }
316
+
317
+ if return_tensors == "pt":
318
+ result = {k: torch.tensor(v) for k, v in result.items()}
319
+
320
+ return result
321
+
322
+ @property
323
+ def vocab_size(self) -> int:
324
+ """Get vocabulary size."""
325
+ return self.tokenizer.get_vocab_size()
326
+
327
+ def get_vocab(self) -> Dict[str, int]:
328
+ """Get vocabulary dictionary."""
329
+ return self.tokenizer.get_vocab()
330
+
331
+ def token_to_id(self, token: str) -> int:
332
+ """Convert token to ID."""
333
+ return self.tokenizer.token_to_id(token)
334
+
335
+ def id_to_token(self, token_id: int) -> str:
336
+ """Convert ID to token."""
337
+ return self.tokenizer.id_to_token(token_id)
338
+
339
+
340
+ def build_science_vocabulary_file(output_path: str):
341
+ """
342
+ Build a science vocabulary text file for BPE training.
343
+ This file contains seed vocabulary terms to ensure science tokens are present.
344
+ """
345
+ science_terms = []
346
+
347
+ # LaTeX commands
348
+ latex_terms = [
349
+ "\\alpha", "\\beta", "\\gamma", "\\delta", "\\epsilon", "\\zeta",
350
+ "\\eta", "\\theta", "\\iota", "\\kappa", "\\lambda", "\\mu",
351
+ "\\nu", "\\xi", "\\pi", "\\rho", "\\sigma", "\\tau",
352
+ "\\upsilon", "\\phi", "\\chi", "\\psi", "\\omega",
353
+ "\\sum", "\\prod", "\\int", "\\partial", "\\nabla", "\\infty",
354
+ "\\frac", "\\sqrt", "\\binom", "\\begin", "\\end",
355
+ "\\mathbf", "\\mathcal", "\\mathrm", "\\mathbb",
356
+ "\\in", "\\subset", "\\cup", "\\cap", "\\forall", "\\exists",
357
+ "\\rightarrow", "\\leftarrow", "\\Rightarrow", "\\Leftarrow",
358
+ "\\leq", "\\geq", "\\neq", "\\approx", "\\equiv",
359
+ ]
360
+ science_terms.extend(latex_terms)
361
+
362
+ # Chemical formulas
363
+ chem_formulas = [
364
+ "H2O", "CO2", "O2", "N2", "H2", "CH4", "C2H6", "C3H8",
365
+ "C6H12O6", "C12H22O11", "HCl", "H2SO4", "HNO3", "H3PO4",
366
+ "NaOH", "KOH", "CaCO3", "NaCl", "KCl", "MgCl2",
367
+ "Fe2O3", "Fe3O4", "CuO", "Cu2O", "ZnO", "Al2O3",
368
+ "SiO2", "TiO2", "MnO2", "NH3", "NO", "NO2", "N2O",
369
+ "SO2", "SO3", "CO", "CH3COOH", "C2H5OH",
370
+ ]
371
+ science_terms.extend(chem_formulas)
372
+
373
+ # Mathematical expressions
374
+ math_exprs = [
375
+ "x^2", "x^3", "e^x", "ln(x)", "log(x)", "sin(x)", "cos(x)",
376
+ "tan(x)", "arcsin(x)", "arccos(x)", "arctan(x)",
377
+ "f(x)", "g(x)", "h(x)", "F(x)", "G(x)",
378
+ "dx", "dy", "dz", "dt", "∂x", "∂y", "∂z",
379
+ "∫", "∬", "∭", "∮", "∑_{i=1}^{n}", "∏_{i=1}^{n}",
380
+ ]
381
+ science_terms.extend(math_exprs)
382
+
383
+ # Units with numbers
384
+ unit_exprs = [
385
+ "10^6", "10^9", "10^12", "10^15", "10^18",
386
+ "10^-3", "10^-6", "10^-9", "10^-12",
387
+ "m/s", "km/h", "cm/s", "mm/s",
388
+ "J/mol", "kJ/mol", "cal", "kcal",
389
+ "eV", "MeV", "GeV", "TeV",
390
+ "Hz", "kHz", "MHz", "GHz",
391
+ "Pa", "kPa", "MPa", "GPa",
392
+ "°C", "K", "°F",
393
+ ]
394
+ science_terms.extend(unit_exprs)
395
+
396
+ # Write to file
397
+ with open(output_path, "w", encoding="utf-8") as f:
398
+ for term in science_terms:
399
+ f.write(term + "\n")
400
+
401
+ print(f"Science vocabulary seed file written to {output_path}")
402
+ print(f"Total seed terms: {len(science_terms)}")
403
+
404
+
405
+ if __name__ == "__main__":
406
+ # Example usage
407
+ import sys
408
+
409
+ if len(sys.argv) < 2:
410
+ print("Usage: python vortex_tokenizer.py <train_data.txt> [output_dir]")
411
+ sys.exit(1)
412
+
413
+ train_data = sys.argv[1]
414
+ output_dir = sys.argv[2] if len(sys.argv) > 2 else "."
415
+
416
+ # Load config (simplified for standalone)
417
+ config = {
418
+ "special_tokens": {
419
+ "[PAD]": 0, "[UNK]": 1, "[BOS]": 2, "[EOS]": 3,
420
+ "[EQUATION]": 4, "[/EQUATION]": 5,
421
+ "[CITATION]": 6, "[/CITATION]": 7,
422
+ "[MOLECULE]": 8, "[/MOLECULE]": 9,
423
+ "[FIGURE]": 10, "[TABLE]": 11,
424
+ "[MATH]": 12, "[CHEM]": 13, "[BIO]": 14,
425
+ "[PHYS]": 15, "[EARTH]": 16, "[SPACE]": 17, "[ZOO]": 18,
426
+ },
427
+ "domain_tags": ["[MATH]", "[CHEM]", "[BIO]", "[PHYS]", "[EARTH]", "[SPACE]", "[ZOO]"],
428
+ "max_seq_len": 16384,
429
+ }
430
+
431
+ # Build seed vocabulary
432
+ seed_vocab_path = os.path.join(output_dir, "science_seed_vocab.txt")
433
+ build_science_vocabulary_file(seed_vocab_path)
434
+
435
+ # Initialize and train tokenizer
436
+ tokenizer = VortexScienceTokenizer(config)
437
+ tokenizer.train([train_data])
438
+
439
+ # Save tokenizer
440
+ tokenizer_path = os.path.join(output_dir, "vortex_tokenizer.json")
441
+ tokenizer.save(tokenizer_path)
442
+ print(f"Tokenizer saved to {tokenizer_path}")
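The pad-to-longest batching implemented above can be sketched in isolation. This standalone helper is illustrative, not the tokenizer's actual method; it assumes pad id 0 (the `[PAD]` entry in the special-token table) and takes already-tokenized id lists:

```python
def pad_batch(sequences, pad_id=0):
    """Pad token-id lists to the longest sequence in the batch.

    The attention mask is 1 for real tokens and 0 for padding,
    mirroring the masks built in the tokenizer's batch encoder.
    """
    max_len = max(len(seq) for seq in sequences)
    input_ids, attention_masks = [], []
    for seq in sequences:
        pad_len = max_len - len(seq)
        input_ids.append(seq + [pad_id] * pad_len)
        attention_masks.append([1] * len(seq) + [0] * pad_len)
    return {"input_ids": input_ids, "attention_mask": attention_masks}

# Two sequences of different lengths; the shorter one is right-padded.
batch = pad_batch([[2, 17, 3], [2, 3]])
```

Right-padding keeps position 0 aligned across the batch, which is what the causal-LM label shift in the loss expects.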
train.py ADDED
@@ -0,0 +1,146 @@
+ #!/usr/bin/env python3
+ """
+ Main training entry point for Vortex models.
+ """
+
+ import argparse
+ import sys
+ from pathlib import Path
+
+ import torch
+
+ from configs.vortex_7b_config import VORTEX_7B_CONFIG
+ from configs.vortex_13b_config import VORTEX_13B_CONFIG
+ from configs.training_config import TRAINING_CONFIG, TRAINING_CONFIG_7B_CUDA, TRAINING_CONFIG_13B_CUDA, TRAINING_CONFIG_MPS
+
+ from models.vortex_model import VortexModel
+ from tokenizer.vortex_tokenizer import VortexScienceTokenizer
+ from training.trainer import VortexTrainer, VortexDataset
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Train Vortex scientific language model")
+     parser.add_argument("--model_size", type=str, choices=["7b", "13b"], default="7b",
+                         help="Model size to train")
+     parser.add_argument("--device", type=str, default="cuda",
+                         choices=["cuda", "mps", "cpu"],
+                         help="Device to train on")
+     parser.add_argument("--use_mps", action="store_true",
+                         help="Use MPS backend (Apple Silicon)")
+     parser.add_argument("--data_dir", type=str, default="./data/processed",
+                         help="Directory with processed data shards")
+     parser.add_argument("--tokenizer_path", type=str, default=None,
+                         help="Path to pretrained tokenizer")
+     parser.add_argument("--resume_from_checkpoint", type=str, default=None,
+                         help="Resume training from checkpoint")
+     parser.add_argument("--output_dir", type=str, default="./checkpoints",
+                         help="Output directory for checkpoints")
+     parser.add_argument("--max_steps", type=int, default=None,
+                         help="Override max training steps")
+     parser.add_argument("--micro_batch_size", type=int, default=None,
+                         help="Override micro batch size")
+     parser.add_argument("--quantization", type=str, choices=[None, "int8", "int4"], default=None,
+                         help="Quantization for 13B on 8GB")
+     return parser.parse_args()
+
+
+ def main():
+     args = parse_args()
+
+     # Load configs
+     if args.model_size == "7b":
+         model_config = VORTEX_7B_CONFIG.copy()
+         train_config = TRAINING_CONFIG_7B_CUDA.copy()
+     else:
+         model_config = VORTEX_13B_CONFIG.copy()
+         train_config = TRAINING_CONFIG_13B_CUDA.copy()
+
+     # Override with MPS config if needed
+     if args.use_mps or args.device == "mps":
+         train_config = TRAINING_CONFIG_MPS.copy()
+         train_config["use_mps"] = True
+
+     # Apply overrides
+     if args.max_steps:
+         train_config["max_steps"] = args.max_steps
+     if args.micro_batch_size:
+         train_config["micro_batch_size"] = args.micro_batch_size
+     if args.quantization:
+         train_config["quantization"] = args.quantization
+
+     # Set device
+     device = torch.device(args.device)
+     train_config["device"] = args.device
+
+     print(f"Training Vortex-{args.model_size.upper()}")
+     print(f"Device: {device}")
+     print(f"Max steps: {train_config['max_steps']}")
+     print(f"Micro batch size: {train_config['micro_batch_size']}")
+
+     # Create tokenizer
+     print("Loading tokenizer...")
+     tokenizer = VortexScienceTokenizer(
+         model_config,
+         tokenizer_path=args.tokenizer_path,
+     )
+     print(f"Tokenizer vocab size: {tokenizer.vocab_size}")
+
+     # Create model
+     print("Creating model...")
+     model = VortexModel(model_config)
+     print(f"Model parameters: {model.get_num_params():,}")
+
+     # Estimate memory
+     mem = model.estimate_memory_usage(
+         train_config["micro_batch_size"],
+         model_config["max_seq_len"],
+     )
+     print("Memory estimate:")
+     for k, v in mem.items():
+         print(f"  {k}: {v:.2f} GB")
+
+     # Load dataset
+     print("Loading dataset...")
+     data_dir = Path(args.data_dir)
+     shard_files = sorted(data_dir.glob("train_*.parquet"))
+     if not shard_files:
+         print(f"No training shards found in {data_dir}")
+         print("Please run data pipeline first.")
+         sys.exit(1)
+
+     train_dataset = VortexDataset(
+         shard_files,
+         tokenizer,
+         max_seq_len=model_config["max_seq_len"],
+     )
+     print(f"Training dataset size: {len(train_dataset)} samples")
+
+     # Create eval dataset (use first shard only)
+     eval_shard_files = shard_files[:1]
+     eval_dataset = VortexDataset(
+         eval_shard_files,
+         tokenizer,
+         max_seq_len=model_config["max_seq_len"],
+     )
+
+     # Create trainer
+     trainer = VortexTrainer(
+         model=model,
+         tokenizer=tokenizer,
+         train_dataset=train_dataset,
+         config=train_config,
+         eval_dataset=eval_dataset,
+     )
+
+     # Resume from checkpoint if specified
+     if args.resume_from_checkpoint:
+         trainer.load_checkpoint(args.resume_from_checkpoint)
+
+     # Train
+     trainer.train()
+
+     print("Training complete!")
+
+
+ if __name__ == "__main__":
+     main()
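The CLI-override pattern in `main()` (pick a preset config, then let each flag that was actually given win) boils down to a small merge. A minimal sketch with hypothetical keys, using an explicit `is not None` check so that legitimate falsy values like `0` are not silently dropped (the script above uses plain truthiness, which behaves the same for these particular flags):

```python
def apply_overrides(base, overrides):
    """Return a copy of base with only the non-None override values applied."""
    merged = dict(base)
    for key, value in overrides.items():
        if value is not None:  # flag was given on the command line
            merged[key] = value
    return merged

# max_steps is overridden; micro_batch_size keeps the preset value.
cfg = apply_overrides(
    {"max_steps": 100000, "micro_batch_size": 4},
    {"max_steps": 10, "micro_batch_size": None},
)
```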
training/__pycache__/curriculum.cpython-313.pyc ADDED
Binary file (5 kB).
 
training/__pycache__/losses.cpython-313.pyc ADDED
Binary file (5.84 kB).
 
training/curriculum.py ADDED
@@ -0,0 +1,175 @@
+ """
+ Curriculum learning for Vortex model.
+ Progresses through stages: Foundation → Domain → Reasoning → Integration.
+ """
+
+ from typing import Dict, Optional
+
+
+ class CurriculumScheduler:
+     """
+     Schedules curriculum stages during training.
+     Each stage has a start and end fraction of total training steps.
+     """
+
+     STAGES = ["foundation", "domain", "reasoning", "integration"]
+
+     def __init__(
+         self,
+         config: Dict,
+         total_steps: int,
+     ):
+         """
+         Initialize curriculum scheduler.
+
+         Args:
+             config: Training config with curriculum_stages
+             total_steps: Total number of training steps
+         """
+         self.config = config
+         self.total_steps = total_steps
+         self.stages = config.get("curriculum_stages", [
+             {"name": "foundation", "start": 0.0, "end": 0.2},
+             {"name": "domain", "start": 0.2, "end": 0.5},
+             {"name": "reasoning", "start": 0.5, "end": 0.8},
+             {"name": "integration", "start": 0.8, "end": 1.0},
+         ])
+
+         # Convert fractions to step numbers
+         for stage in self.stages:
+             stage["start_step"] = int(stage["start"] * total_steps)
+             stage["end_step"] = int(stage["end"] * total_steps)
+
+     def get_stage(
+         self,
+         current_step: int,
+     ) -> Optional[Dict]:
+         """
+         Get the current curriculum stage.
+
+         Args:
+             current_step: Current training step
+
+         Returns:
+             Stage dictionary, or None if training is complete
+         """
+         for stage in self.stages:
+             if stage["start_step"] <= current_step < stage["end_step"]:
+                 return stage
+         return None
+
+     def get_stage_name(self, current_step: int) -> str:
+         """Get current stage name."""
+         stage = self.get_stage(current_step)
+         return stage["name"] if stage else "complete"
+
+     def get_stage_weight(
+         self,
+         current_step: int,
+         base_weight: float,
+         component: str = "lm_loss",
+     ) -> float:
+         """
+         Get the weight for a loss component based on the current stage.
+
+         Args:
+             current_step: Current training step
+             base_weight: Base weight for the component
+             component: Loss component name (e.g. "equation_loss")
+
+         Returns:
+             base_weight if the component is active in the current stage, else 0.0
+         """
+         stage = self.get_stage(current_step)
+         if not stage:
+             return 0.0
+
+         stage_name = stage["name"]
+
+         # Components active in each stage
+         stage_components = {
+             "foundation": ["lm_loss"],  # Only language modeling
+             "domain": ["lm_loss", "equation_loss", "domain_loss"],
+             "reasoning": ["lm_loss", "equation_loss", "domain_loss", "citation_loss"],
+             "integration": ["lm_loss", "equation_loss", "domain_loss", "citation_loss", "numerical_loss"],
+         }
+
+         active_components = stage_components.get(stage_name, ["lm_loss"])
+
+         # Zero out the weight for components that are not yet active
+         return base_weight if component in active_components else 0.0
+
+     def get_dataset_sampler(
+         self,
+         current_step: int,
+     ):
+         """
+         Get dataset mixing weights for the current stage.
+         Different stages mix datasets differently.
+
+         Returns:
+             Sampler weights for the different datasets
+         """
+         stage = self.get_stage(current_step)
+         if not stage:
+             return None
+
+         stage_name = stage["name"]
+
+         # Dataset mixing proportions per stage
+         mixing_proportions = {
+             "foundation": {
+                 "pile_scientific": 0.3,
+                 "s2orc": 0.3,
+                 "automath": 0.2,
+                 "pubmed_qa": 0.2,
+             },
+             "domain": {
+                 "pile_scientific": 0.2,
+                 "s2orc": 0.2,
+                 "automath": 0.2,
+                 "pubmed_qa": 0.2,
+                 "deepmind_math": 0.2,
+             },
+             "reasoning": {
+                 "pile_scientific": 0.15,
+                 "s2orc": 0.15,
+                 "automath": 0.3,
+                 "deepmind_math": 0.3,
+                 "pubmed_qa": 0.1,
+             },
+             "integration": {
+                 "pile_scientific": 0.2,
+                 "s2orc": 0.2,
+                 "automath": 0.2,
+                 "deepmind_math": 0.2,
+                 "pubmed_qa": 0.2,
+             },
+         }
+
+         return mixing_proportions.get(stage_name, {"pile_scientific": 1.0})
+
+
+ def test_curriculum():
+     """Test curriculum scheduler."""
+     config = {
+         "curriculum_stages": [
+             {"name": "foundation", "start": 0.0, "end": 0.2},
+             {"name": "domain", "start": 0.2, "end": 0.5},
+             {"name": "reasoning", "start": 0.5, "end": 0.8},
+             {"name": "integration", "start": 0.8, "end": 1.0},
+         ]
+     }
+
+     total_steps = 1000
+     scheduler = CurriculumScheduler(config, total_steps)
+
+     for step in [0, 100, 250, 500, 750, 999]:
+         name = scheduler.get_stage_name(step)
+         print(f"Step {step}: {name}")
+
+     print("Curriculum test passed!")
+
+
+ if __name__ == "__main__":
+     test_curriculum()
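The fraction-to-step conversion makes each stage a half-open interval `[start_step, end_step)`, so stage boundaries never overlap and the final step of training still resolves to a stage. A minimal standalone version of that lookup (tuples instead of the scheduler's dicts):

```python
def stage_at(step, total_steps, stages):
    """Return the name of the stage whose half-open [start, end) interval contains step."""
    for name, start_frac, end_frac in stages:
        start_step = int(start_frac * total_steps)
        end_step = int(end_frac * total_steps)
        if start_step <= step < end_step:
            return name
    return "complete"  # past the last stage

# Same fractions as the default curriculum_stages.
STAGES = [
    ("foundation", 0.0, 0.2),
    ("domain", 0.2, 0.5),
    ("reasoning", 0.5, 0.8),
    ("integration", 0.8, 1.0),
]
```

With 1000 total steps, step 199 is still "foundation" while step 200 already belongs to "domain"; step 1000 falls past every interval and reports "complete".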
training/losses.py ADDED
@@ -0,0 +1,162 @@
+ """
+ Science-aware losses for Vortex model training.
+ Combines standard language modeling with auxiliary tasks.
+ """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from typing import Dict, Optional
+
+
+ class VortexLoss(nn.Module):
+     """
+     Combined loss for Vortex model with science-aware components.
+
+     total_loss = (
+         lm_loss * 1.0
+         + equation_loss * 0.3
+         + domain_loss * 0.1
+         + citation_loss * 0.1
+         + numerical_loss * 0.2
+     )
+     """
+
+     def __init__(self, config: Dict):
+         """
+         Initialize loss.
+
+         Args:
+             config: Training config with loss_weights
+         """
+         super().__init__()
+         self.loss_weights = config.get("loss_weights", {
+             "lm_loss": 1.0,
+             "equation_loss": 0.3,
+             "domain_loss": 0.1,
+             "citation_loss": 0.1,
+             "numerical_loss": 0.2,
+         })
+
+     def forward(
+         self,
+         logits: torch.Tensor,
+         labels: torch.Tensor,
+         equation_module: Optional[nn.Module] = None,
+         equation_mask: Optional[torch.Tensor] = None,
+         domain_logits: Optional[torch.Tensor] = None,
+         domain_labels: Optional[torch.Tensor] = None,
+         citation_module: Optional[nn.Module] = None,
+         citation_mask: Optional[torch.Tensor] = None,
+         citation_confidence: Optional[torch.Tensor] = None,
+         numerical_module: Optional[nn.Module] = None,
+         numerical_mask: Optional[torch.Tensor] = None,
+     ) -> Dict[str, torch.Tensor]:
+         """
+         Compute total loss.
+
+         Args:
+             logits: (batch, seq_len, vocab_size)
+             labels: (batch, seq_len) with token IDs
+             equation_module: EquationModule for equation loss
+             equation_mask: (batch, seq_len), 1 if token is inside an equation
+             domain_logits: (batch, num_domains)
+             domain_labels: (batch,)
+             citation_module: CitationModule for citation loss
+             citation_mask: (batch, seq_len)
+             citation_confidence: (batch, seq_len, 1)
+             numerical_module: NumericalReasoningModule
+             numerical_mask: (batch, seq_len)
+
+         Returns:
+             Dictionary with total loss and component losses
+         """
+         losses = {}
+
+         # 1. Language modeling loss (next-token prediction).
+         # Shift so position t predicts token t+1, since labels are a copy of input_ids.
+         shift_logits = logits[:, :-1, :].contiguous()
+         shift_labels = labels[:, 1:].contiguous()
+         lm_loss = F.cross_entropy(
+             shift_logits.view(-1, shift_logits.size(-1)),
+             shift_labels.view(-1),
+             ignore_index=-100,  # ignore padding
+         )
+         losses["lm_loss"] = lm_loss
+
+         # 2. Equation detection loss
+         if equation_module is not None and equation_mask is not None:
+             # Needs hidden states from the equation module - would require modifying
+             # the forward pass. Placeholder for now.
+             equation_loss = torch.tensor(0.0, device=logits.device)
+             losses["equation_loss"] = equation_loss
+         else:
+             losses["equation_loss"] = torch.tensor(0.0, device=logits.device)
+
+         # 3. Domain classification loss
+         if domain_logits is not None and domain_labels is not None:
+             domain_loss = F.cross_entropy(domain_logits, domain_labels)
+             losses["domain_loss"] = domain_loss
+         else:
+             losses["domain_loss"] = torch.tensor(0.0, device=logits.device)
+
+         # 4. Citation detection loss
+         if citation_module is not None and citation_mask is not None and citation_confidence is not None:
+             citation_loss = citation_module.compute_citation_loss(
+                 # Would need hidden states - placeholder
+                 torch.zeros_like(logits[:, :, :1]),  # dummy
+                 citation_mask,
+                 citation_confidence,
+             )
+             losses["citation_loss"] = citation_loss
+         else:
+             losses["citation_loss"] = torch.tensor(0.0, device=logits.device)
+
+         # 5. Numerical reasoning loss
+         if numerical_module is not None and numerical_mask is not None:
+             numerical_loss = numerical_module.compute_numerical_loss(
+                 torch.zeros_like(logits),  # dummy hidden states
+                 numerical_mask,
+                 None,  # target values
+             )
+             losses["numerical_loss"] = numerical_loss
+         else:
+             losses["numerical_loss"] = torch.tensor(0.0, device=logits.device)
+
+         # Weighted sum
+         total_loss = torch.tensor(0.0, device=logits.device)
+         for name, loss in losses.items():
+             weight = self.loss_weights.get(name, 1.0)
+             total_loss = total_loss + loss * weight
+
+         losses["total_loss"] = total_loss
+
+         return losses
+
+
+ def test_vortex_loss():
+     """Test the loss function."""
+     config = {"loss_weights": {
+         "lm_loss": 1.0,
+         "equation_loss": 0.3,
+         "domain_loss": 0.1,
+         "citation_loss": 0.1,
+         "numerical_loss": 0.2,
+     }}
+
+     loss_fn = VortexLoss(config)
+
+     batch_size = 2
+     seq_len = 128
+     vocab_size = 1000
+
+     logits = torch.randn(batch_size, seq_len, vocab_size)
+     labels = torch.randint(0, vocab_size, (batch_size, seq_len))
+
+     losses = loss_fn(logits, labels)
+     print("Losses:")
+     for name, value in losses.items():
+         print(f"  {name}: {value.item():.4f}")
+
+     assert "total_loss" in losses
+     print("VortexLoss test passed!")
+
+
+ if __name__ == "__main__":
+     test_vortex_loss()
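The weighted combination at the end of `VortexLoss.forward` is just a dot product between component losses and their weights, with a default weight of 1.0 for unlisted components. A tensor-free sketch with the default weight table, using plain floats in place of loss tensors:

```python
DEFAULT_WEIGHTS = {
    "lm_loss": 1.0,
    "equation_loss": 0.3,
    "domain_loss": 0.1,
    "citation_loss": 0.1,
    "numerical_loss": 0.2,
}

def combine_losses(losses, weights=DEFAULT_WEIGHTS):
    """Weighted sum of loss components; missing weights default to 1.0, as in VortexLoss."""
    return sum(loss * weights.get(name, 1.0) for name, loss in losses.items())

# 2.0 * 1.0 + 1.0 * 0.3 + 0.5 * 0.1 = 2.35
total = combine_losses({"lm_loss": 2.0, "equation_loss": 1.0, "domain_loss": 0.5})
```

Since inactive components contribute a 0.0 placeholder loss, they drop out of the sum regardless of their weight.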
training/trainer.py ADDED
@@ -0,0 +1,442 @@
1
+ """
2
+ Trainer: Main training loop for Vortex model.
3
+ Handles gradient accumulation, mixed precision, checkpointing.
4
+ """
5
+
6
+ import os
7
+ import json
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.utils.data import DataLoader, Dataset
11
+ from typing import Optional, Dict, List, Callable
12
+ from pathlib import Path
13
+ import logging
14
+
15
+ from ..training.losses import VortexLoss
16
+ from ..training.curriculum import CurriculumScheduler
17
+
18
+
19
+ class VortexDataset(Dataset):
20
+ """Simple dataset wrapper."""
21
+
22
+ def __init__(
23
+ self,
24
+ shard_files: List[str],
25
+ tokenizer,
26
+ max_seq_len: int = 16384,
27
+ ):
28
+ """
29
+ Initialize dataset.
30
+
31
+ Args:
32
+ shard_files: List of parquet shard files
33
+ tokenizer: Tokenizer for encoding text
34
+ max_seq_len: Maximum sequence length
35
+ """
36
+ self.shard_files = shard_files
37
+ self.tokenizer = tokenizer
38
+ self.max_seq_len = max_seq_len
39
+
40
+ # Load all shards into memory (for simplicity - would stream in practice)
41
+ self.samples = []
42
+ self._load_shards()
43
+
44
+ def _load_shards(self):
45
+ """Load all shards."""
46
+ import pandas as pd
47
+
48
+ for shard in self.shard_files:
49
+ df = pd.read_parquet(shard)
50
+ for _, row in df.iterrows():
51
+ self.samples.append({
52
+ "text": row["text"],
53
+ "dataset": row.get("dataset", ""),
54
+ "domain": row.get("domain", ""),
55
+ })
56
+
57
+ def __len__(self) -> int:
58
+ return len(self.samples)
59
+
60
+ def __getitem__(self, idx) -> Dict:
61
+ sample = self.samples[idx]
62
+ text = sample["text"]
63
+
64
+ # Tokenize
65
+ encoding = self.tokenizer.encode(
66
+ text,
67
+ add_special_tokens=True,
68
+ return_tensors="pt",
69
+ )
70
+
71
+ input_ids = encoding["input_ids"].squeeze(0)
72
+ attention_mask = encoding["attention_mask"].squeeze(0)
73
+
74
+ # Truncate if needed
75
+ if len(input_ids) > self.max_seq_len:
76
+ input_ids = input_ids[:self.max_seq_len]
77
+ attention_mask = attention_mask[:self.max_seq_len]
78
+
79
+ # Labels are same as input_ids (causal LM)
80
+ labels = input_ids.clone()
81
+
82
+ return {
83
+ "input_ids": input_ids,
84
+ "attention_mask": attention_mask,
85
+ "labels": labels,
86
+ "domain": sample["domain"],
87
+ }
88
+
89
+
90
+ class VortexTrainer:
91
+ """
92
+ Main trainer for Vortex model.
93
+ """
94
+
95
+ def __init__(
96
+ self,
97
+ model: nn.Module,
98
+ tokenizer,
99
+ train_dataset: Dataset,
100
+ config: Dict,
101
+ eval_dataset: Optional[Dataset] = None,
102
+ optimizer: Optional[torch.optim.Optimizer] = None,
103
+ scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
104
+ ):
105
+ """
106
+ Initialize trainer.
107
+
108
+ Args:
109
+ model: VortexModel
110
+ tokenizer: VortexScienceTokenizer
111
+ train_dataset: Training dataset
112
+ config: Training configuration
113
+ eval_dataset: Optional evaluation dataset
114
+ optimizer: Optional optimizer (created if None)
115
+ scheduler: Optional LR scheduler
116
+ """
117
+ self.model = model
118
+ self.tokenizer = tokenizer
119
+ self.train_dataset = train_dataset
120
+ self.eval_dataset = eval_dataset
121
+ self.config = config
122
+
123
+ self.device = torch.device(config["device"])
124
+ self.use_amp = config.get("use_amp", True)
125
+ self.amp_dtype = getattr(torch, config.get("amp_dtype", "bfloat16"))
126
+
127
+ # Move model to device
128
+ self.model.to(self.device)
129
+
130
+ # Setup optimizer
131
+ if optimizer is None:
132
+ self.optimizer = self._create_optimizer()
133
+ else:
134
+ self.optimizer = optimizer
135
+
136
+ # Setup scheduler
137
+ if scheduler is None:
138
+ self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
139
+ self.optimizer,
140
+ T_max=config["max_steps"],
141
+ )
142
+ else:
143
+ self.scheduler = scheduler
144
+
145
+ # Setup AMP scaler
146
+ self.scaler = torch.cuda.amp.GradScaler() if self.use_amp and self.device.type == "cuda" else None
147
+
148
+ # Loss function
149
+ self.loss_fn = VortexLoss(config)
150
+
151
+ # Curriculum scheduler
152
+ self.curriculum = CurriculumScheduler(config, config["max_steps"])
153
+
154
+ # Logging
155
+ self.log_dir = Path(config.get("log_dir", "logs"))
156
+ self.log_dir.mkdir(parents=True, exist_ok=True)
157
+ self.log_interval = config.get("log_interval", 100)
158
+
159
+ # Checkpointing
160
+ self.checkpoint_dir = Path(config.get("checkpoint_dir", "checkpoints"))
161
+ self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
162
+ self.save_interval = config.get("save_interval", 5000)
163
+
164
+ # Training state
165
+ self.global_step = 0
166
+ self.best_eval_loss = float('inf')
167
+
168
+ # Data loader
169
+ self.train_loader = DataLoader(
170
+ train_dataset,
171
+ batch_size=config["micro_batch_size"],
172
+ shuffle=True,
173
+ num_workers=config.get("num_workers", 4),
174
+ pin_memory=config.get("pin_memory", True),
175
+ prefetch_factor=config.get("prefetch_factor", 2),
176
+ )
177
+
178
+ if eval_dataset:
179
+ self.eval_loader = DataLoader(
180
+ eval_dataset,
181
+ batch_size=config["micro_batch_size"],
182
+ shuffle=False,
183
+ num_workers=config.get("num_workers", 4),
184
+ )
185
+
186
+ def _create_optimizer(self) -> torch.optim.Optimizer:
187
+ """Create AdamW optimizer."""
188
+ return torch.optim.AdamW(
189
+ self.model.parameters(),
190
+ lr=self.config["learning_rate"],
191
+ betas=(self.config["beta1"], self.config["beta2"]),
192
+ weight_decay=self.config["weight_decay"],
193
+ )
194
+
195
+ def train_step(
196
+ self,
197
+ batch: Dict,
198
+ current_step: int,
199
+ ) -> Dict[str, torch.Tensor]:
200
+ """
201
+ Single training step.
202
+
203
+ Args:
204
+ batch: Batch dictionary
205
+ current_step: Current step number
206
+
207
+ Returns:
208
+ Dictionary of losses
209
+ """
210
+ self.model.train()
211
+
212
+ # Move batch to device
213
+ input_ids = batch["input_ids"].to(self.device)
214
+ attention_mask = batch["attention_mask"].to(self.device)
215
+ labels = batch["labels"].to(self.device)
216
+
217
+ # Domain info (placeholder - would extract from batch)
218
+ domain_ids = None
219
+ domain_tags = None
220
+
221
+ # Forward pass with AMP
222
+ with torch.cuda.amp.autocast(enabled=self.use_amp and self.device.type == "cuda"):
223
+ outputs = self.model(
224
+ input_ids=input_ids,
225
+ attention_mask=attention_mask,
226
+ domain_ids=domain_ids,
227
+ domain_tags=domain_tags,
228
+ return_dict=True,
229
+ )
230
+ logits = outputs["logits"]
231
+
232
+ # Compute losses
233
+ losses = self.loss_fn(
234
+ logits=logits,
235
+ labels=labels,
236
+ # Pass modules and masks for auxiliary losses
237
+ )
238
+
239
+ # Backward pass
240
+ if self.scaler:
241
+ self.scaler.scale(losses["total_loss"]).backward()
242
+ else:
243
+ losses["total_loss"].backward()
244
+
245
+ return losses
246
+
247
+ def train_epoch(self):
248
+ """Train for one epoch."""
249
+ self.model.train()
250
+
251
+ for batch_idx, batch in enumerate(self.train_loader):
252
+ # Train step
253
+ losses = self.train_step(batch, self.global_step)
254
+
255
+ # Gradient accumulation
256
+ if (self.global_step + 1) % self.config["gradient_accumulation_steps"] == 0:
257
+ # Gradient clipping
258
+ if self.config.get("clip_grad_norm", 0) > 0:
259
+ if self.scaler:
260
+ self.scaler.unscale_(self.optimizer)
261
+ torch.nn.utils.clip_grad_norm_(
262
+ self.model.parameters(),
263
+ self.config["clip_grad_norm"],
264
+ )
265
+
266
+ # Optimizer step
267
+ if self.scaler:
268
+ self.scaler.step(self.optimizer)
269
+ self.scaler.update()
270
+ else:
271
+ self.optimizer.step()
272
+
273
+ self.optimizer.zero_grad()
274
+ self.scheduler.step()
275
+
276
+ # Logging
277
+ if self.global_step % self.log_interval == 0:
278
+ self._log_losses(losses, batch_idx)
279
+
280
+ # Evaluation
281
+ if self.eval_dataset and self.global_step % self.config.get("eval_interval", 1000) == 0:
282
+ self.evaluate()
283
+
284
+ # Checkpointing
285
+ if self.global_step % self.save_interval == 0:
286
+ self.save_checkpoint()
287
+
288
+ self.global_step += 1
289
+
290
+ if self.global_step >= self.config["max_steps"]:
291
+ print("Reached max steps")
292
+ return
293
+
294
+ def evaluate(self) -> Dict[str, float]:
295
+ """Run evaluation."""
296
+ self.model.eval()
297
+ total_loss = 0.0
298
+ num_batches = 0
299
+
300
+ with torch.no_grad():
301
+ for batch in self.eval_loader:
302
+ input_ids = batch["input_ids"].to(self.device)
303
+ attention_mask = batch["attention_mask"].to(self.device)
304
+ labels = batch["labels"].to(self.device)
305
+
306
+ with torch.cuda.amp.autocast(enabled=self.use_amp and self.device.type == "cuda"):
307
+ outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
308
+ logits = outputs["logits"]
309
+ loss = F.cross_entropy(
310
+ logits.view(-1, logits.size(-1)),
311
+ labels.view(-1),
312
+ ignore_index=-100,
313
+ )
314
+
315
+ total_loss += loss.item()
316
+ num_batches += 1
317
+
318
+ avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
319
+ print(f"Evaluation at step {self.global_step}: loss = {avg_loss:.4f}")
320
+
321
+ return {"eval_loss": avg_loss}
322
+
323
+ def save_checkpoint(self, is_best: bool = False):
324
+ """Save model checkpoint."""
325
+ checkpoint = {
326
+ "step": self.global_step,
327
+ "model_state_dict": self.model.state_dict(),
328
+ "optimizer_state_dict": self.optimizer.state_dict(),
329
+ "scheduler_state_dict": self.scheduler.state_dict(),
330
+ "config": self.config,
331
+ "best_eval_loss": self.best_eval_loss,
332
+ }
333
+
334
+ if self.scaler:
335
+ checkpoint["scaler_state_dict"] = self.scaler.state_dict()
336
+
337
+ # Save latest
338
+        checkpoint_path = self.checkpoint_dir / f"checkpoint_{self.global_step:06d}.pt"
+        torch.save(checkpoint, checkpoint_path)
+        print(f"Saved checkpoint to {checkpoint_path}")
+
+        # Save best model
+        if is_best:
+            best_path = self.checkpoint_dir / "best_model.pt"
+            torch.save(checkpoint, best_path)
+            print(f"Saved best model to {best_path}")
+
+        # Save a copy as the latest checkpoint
+        latest_path = self.checkpoint_dir / "latest.pt"
+        torch.save(checkpoint, latest_path)
+
+    def load_checkpoint(self, checkpoint_path: str):
+        """Load a checkpoint and restore model, optimizer, and scheduler state."""
+        checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
+        self.model.load_state_dict(checkpoint["model_state_dict"])
+        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+        self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
+        self.global_step = checkpoint["step"]
+        self.best_eval_loss = checkpoint.get("best_eval_loss", float('inf'))
+
+        if self.scaler and "scaler_state_dict" in checkpoint:
+            self.scaler.load_state_dict(checkpoint["scaler_state_dict"])
+
+        print(f"Loaded checkpoint from {checkpoint_path} at step {self.global_step}")
+
+    def _log_losses(self, losses: Dict[str, torch.Tensor], batch_idx: int):
+        """Log losses to the console."""
+        loss_str = " | ".join([f"{k}: {v.item():.4f}" for k, v in losses.items()])
+        print(f"Step {self.global_step} | {loss_str}")
+
+    def train(self):
+        """Main training loop."""
+        print("Starting training...")
+        print(f"Total steps: {self.config['max_steps']}")
+        print(f"Device: {self.device}")
+        print(f"Batch size: {self.config['micro_batch_size']}")
+        print(f"Gradient accumulation steps: {self.config['gradient_accumulation_steps']}")
+
+        try:
+            self.train_epoch()
+        except KeyboardInterrupt:
+            print("Training interrupted")
+        finally:
+            self.save_checkpoint()
+
+
+def test_trainer():
+    """Smoke-test the trainer with a tiny model and dummy data."""
+    from models.vortex_model import VortexModel
+    from configs.vortex_7b_config import VORTEX_7B_CONFIG
+
+    # Small config for testing
+    config = VORTEX_7B_CONFIG.copy()
+    config["d_model"] = 256
+    config["num_layers"] = 2
+    config["num_heads"] = 4
+    config["vocab_size"] = 1000
+    config["max_steps"] = 10
+    config["device"] = "cpu"
+
+    # Create model
+    model = VortexModel(config)
+
+    # Dummy tokenizer (random token IDs; the real tokenizer is not needed here)
+    class DummyTokenizer:
+        def encode(self, text, add_special_tokens=True, return_tensors="pt"):
+            return {"input_ids": torch.randint(0, 1000, (1, 10)), "attention_mask": torch.ones(1, 10)}
+
+    tokenizer = DummyTokenizer()
+
+    # Dummy dataset of random token sequences
+    class DummyDataset(torch.utils.data.Dataset):
+        def __len__(self):
+            return 10
+
+        def __getitem__(self, idx):
+            return {
+                "input_ids": torch.randint(0, 1000, (32,)),
+                "attention_mask": torch.ones(32),
+                "labels": torch.randint(0, 1000, (32,)),
+                "domain": "physics",
+            }
+
+    train_dataset = DummyDataset()
+    eval_dataset = DummyDataset()
+
+    # Create trainer
+    trainer = VortexTrainer(
+        model=model,
+        tokenizer=tokenizer,
+        train_dataset=train_dataset,
+        config=config,
+        eval_dataset=eval_dataset,
+    )
+
+    # Run a few steps
+    trainer.train()
+
+    print("Trainer test passed!")
+
+
+if __name__ == "__main__":
+    test_trainer()
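The zero-padded step in the checkpoint filename (`checkpoint_{step:06d}.pt`) keeps lexicographic and numeric ordering in agreement, so finding the newest checkpoint reduces to a glob and a sort. A minimal sketch of that idea; the `latest_checkpoint` helper is illustrative and not part of this repo:

```python
from pathlib import Path
from typing import Optional
import tempfile

def latest_checkpoint(checkpoint_dir: Path) -> Optional[Path]:
    """Return the newest checkpoint by step, or None if none exist.

    Relies on the zero-padded naming scheme (e.g. checkpoint_001200.pt):
    with a fixed width, lexicographic order matches numeric step order.
    """
    candidates = sorted(checkpoint_dir.glob("checkpoint_*.pt"))
    return candidates[-1] if candidates else None

# Demonstrate with empty placeholder files.
with tempfile.TemporaryDirectory() as tmp:
    ckpt_dir = Path(tmp)
    for step in (50, 1200, 7):
        (ckpt_dir / f"checkpoint_{step:06d}.pt").touch()
    newest = latest_checkpoint(ckpt_dir)
    print(newest.name)  # checkpoint_001200.pt
```

The `latest.pt` copy saved above serves the same purpose without any globbing, at the cost of writing each checkpoint twice.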
vortex_config.py ADDED
@@ -0,0 +1,71 @@
+"""
+Vortex-7B model configuration.
+Optimized for 8GB VRAM (4060 laptop) and MacBook Pro M2/M3.
+"""
+
+VORTEX_7B_CONFIG = {
+    # Model dimensions
+    "d_model": 4096,
+    "num_layers": 32,
+    "num_heads": 32,
+    "head_dim": 128,  # d_model // num_heads
+
+    # State-space layer parameters
+    "d_state": 16,  # SSM state dimension
+    "d_conv": 4,  # SSM convolution width
+
+    # Attention parameters
+    "window_size": 512,  # Local attention window
+    "use_flash_attention": True,  # CUDA only
+
+    # Feed-forward parameters
+    "ffn_expansion": 4,  # Hidden dim = d_model * expansion
+    "num_domains": 7,  # Physics, Math, Chemistry, Biology, Earth, Space, Zoology
+
+    # Tokenizer parameters
+    "vocab_size": 50000,
+    "max_seq_len": 16384,
+
+    # Layer ratio: 60% SSM, 40% attention
+    "ssm_ratio": 0.6,
+
+    # Data types
+    "dtype": "bfloat16",
+
+    # Special tokens
+    "special_tokens": {
+        "[PAD]": 0,
+        "[UNK]": 1,
+        "[BOS]": 2,
+        "[EOS]": 3,
+        "[EQUATION]": 4,
+        "[/EQUATION]": 5,
+        "[CITATION]": 6,
+        "[/CITATION]": 7,
+        "[MOLECULE]": 8,
+        "[/MOLECULE]": 9,
+        "[FIGURE]": 10,
+        "[TABLE]": 11,
+        "[MATH]": 12,
+        "[CHEM]": 13,
+        "[BIO]": 14,
+        "[PHYS]": 15,
+        "[EARTH]": 16,
+        "[SPACE]": 17,
+        "[ZOO]": 18,
+    },
+
+    # Domain tags
+    "domain_tags": ["[MATH]", "[CHEM]", "[BIO]", "[PHYS]", "[EARTH]", "[SPACE]", "[ZOO]"],
+
+    # Science module flags (enable/disable for ablation)
+    "enable_equation_module": True,
+    "enable_numerical_module": True,
+    "enable_citation_module": True,
+    "enable_molecular_module": True,
+}
+
+
+def get_config():
+    """Return the 7B configuration dictionary."""
+    return VORTEX_7B_CONFIG
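Some fields in this config are derived from others (`head_dim = d_model // num_heads`) or imply a layer partition (`ssm_ratio`). A quick consistency sketch using the values above; how the repo actually assigns SSM vs. attention layers is not shown in this diff, so the `round()` split is an assumption:

```python
# Values copied from VORTEX_7B_CONFIG above.
d_model, num_heads, head_dim = 4096, 32, 128
num_layers, ssm_ratio = 32, 0.6

# head_dim must agree with the model width and head count.
assert head_dim == d_model // num_heads

# Assumed layer split: round() is a guess at how ssm_ratio is applied.
num_ssm = round(num_layers * ssm_ratio)
num_attn = num_layers - num_ssm
print(num_ssm, num_attn)  # 19 13
```

With 32 layers, a 0.6 ratio cannot be hit exactly, so any implementation has to round to 19 SSM and 13 attention layers (or the reverse split at 20/12).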