Upload Vortex model
Browse files- README.md +289 -0
- configs/__pycache__/vortex_7b_config.cpython-313.pyc +0 -0
- configs/training_config.py +97 -0
- configs/vortex_13b_config.py +68 -0
- configs/vortex_7b_config.py +71 -0
- configuration_vortex.py +110 -0
- cuda_optimize.py +287 -0
- data/__pycache__/deduplication.cpython-313.pyc +0 -0
- data/__pycache__/domain_classifier.cpython-313.pyc +0 -0
- data/__pycache__/quality_filter.cpython-313.pyc +0 -0
- data/dataset_loader.py +263 -0
- data/deduplication.py +260 -0
- data/domain_classifier.py +163 -0
- data/quality_filter.py +279 -0
- data/scraper.py +405 -0
- inference.py +213 -0
- modeling_vortex.py +222 -0
- models/__pycache__/attention_layer.cpython-313.pyc +0 -0
- models/__pycache__/scigate_ffn.cpython-313.pyc +0 -0
- models/__pycache__/ssm_layer.cpython-313.pyc +0 -0
- models/__pycache__/vortex_model.cpython-313.pyc +0 -0
- models/attention_layer.py +370 -0
- models/science_modules/__init__.py +15 -0
- models/science_modules/__pycache__/__init__.cpython-313.pyc +0 -0
- models/science_modules/__pycache__/citation_module.cpython-313.pyc +0 -0
- models/science_modules/__pycache__/equation_module.cpython-313.pyc +0 -0
- models/science_modules/__pycache__/molecular_module.cpython-313.pyc +0 -0
- models/science_modules/__pycache__/numerical_module.cpython-313.pyc +0 -0
- models/science_modules/citation_module.py +230 -0
- models/science_modules/equation_module.py +266 -0
- models/science_modules/molecular_module.py +333 -0
- models/science_modules/numerical_module.py +251 -0
- models/scigate_ffn.py +203 -0
- models/ssm_layer.py +252 -0
- models/vortex_model.py +377 -0
- mps_optimize.py +172 -0
- push_to_hf.py +39 -0
- requirements.txt +50 -0
- science_bench.py +360 -0
- test_model.py +449 -0
- tokenization_vortex.py +174 -0
- tokenizer/__pycache__/vortex_tokenizer.cpython-313.pyc +0 -0
- tokenizer/vortex_tokenizer.py +442 -0
- train.py +146 -0
- training/__pycache__/curriculum.cpython-313.pyc +0 -0
- training/__pycache__/losses.cpython-313.pyc +0 -0
- training/curriculum.py +175 -0
- training/losses.py +162 -0
- training/trainer.py +442 -0
- vortex_config.py +71 -0
README.md
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Vortex Scientific
|
| 2 |
+
|
| 3 |
+
**Vortex Scientific** is a from-scratch AI model family designed for deep scientific reasoning. Built from the ground up with a novel hybrid state-space + attention architecture, optimized for consumer laptop hardware (Apple Silicon MacBooks and Nvidia 4060 laptop GPUs).
|
| 4 |
+
|
| 5 |
+
## 🌟 Features
|
| 6 |
+
|
| 7 |
+
- **Novel Architecture**: Hybrid State-Space Model (SSM) + Local Attention blocks
|
| 8 |
+
- **Science-Specialized**: Custom tokenizer, domain-aware gating, and specialized modules for equations, numerical reasoning, citations, and molecular structures
|
| 9 |
+
- **Hardware Optimized**: Runs smoothly on 8GB VRAM (4060 laptop) and 16GB unified memory (MacBook Pro M2/M3)
|
| 10 |
+
- **Two Model Sizes**:
|
| 11 |
+
- **Vortex-7B**: 7 billion parameters, fits in 8GB VRAM
|
| 12 |
+
- **Vortex-13B**: 13 billion parameters, fits in 16GB VRAM with quantization
|
| 13 |
+
- **HuggingFace Compatible**: Full integration with `transformers` library
|
| 14 |
+
- **From Scratch**: No base model — everything built bottom-up including tokenizer and weights
|
| 15 |
+
|
| 16 |
+
## 🏗️ Architecture
|
| 17 |
+
|
| 18 |
+
Vortex uses a two-block hybrid architecture:
|
| 19 |
+
|
| 20 |
+
1. **SSM-Only Blocks**: State-space layers for efficient long-context processing (O(n) complexity)
|
| 21 |
+
2. **Attention+Science Blocks**: Local windowed attention + science modules + SciGate FFN
|
| 22 |
+
|
| 23 |
+
Layer ratios:
|
| 24 |
+
- 7B: 60% SSM, 40% Attention (pattern: SSM, SSM, Attn, ...)
|
| 25 |
+
- 13B: 50% SSM, 50% Attention (pattern: SSM, Attn, SSM, Attn, ...)
|
| 26 |
+
|
| 27 |
+
### Science Modules
|
| 28 |
+
|
| 29 |
+
- **EquationModule**: LaTeX equation detection and structural understanding
|
| 30 |
+
- **NumericalReasoningModule**: Digit-level encoding, scientific notation, unit awareness
|
| 31 |
+
- **CitationModule**: Citation span detection, provenance tracking, confidence scoring
|
| 32 |
+
- **MolecularModule**: Element embeddings, SMILES understanding, amino acid sequences
|
| 33 |
+
|
| 34 |
+
## 📦 Project Structure
|
| 35 |
+
|
| 36 |
+
```
|
| 37 |
+
Vortex/
|
| 38 |
+
├── configs/
|
| 39 |
+
│ ├── vortex_7b_config.py # 7B model configuration
|
| 40 |
+
│ ├── vortex_13b_config.py # 13B model configuration
|
| 41 |
+
│ └── training_config.py # Training hyperparameters
|
| 42 |
+
├── models/
|
| 43 |
+
│ ├── ssm_layer.py # State-space layer
|
| 44 |
+
│ ├── attention_layer.py # Local windowed attention
|
| 45 |
+
│ ├── scigate_ffn.py # Science-gated feed-forward
|
| 46 |
+
│ ├── vortex_model.py # Main model class
|
| 47 |
+
│ └── science_modules/ # Specialized science modules
|
| 48 |
+
├── tokenizer/
|
| 49 |
+
│ └── vortex_tokenizer.py # Custom BPE tokenizer with science vocab
|
| 50 |
+
├── data/
|
| 51 |
+
│ ├── dataset_loader.py # Open dataset loading (Pile, S2ORC, etc.)
|
| 52 |
+
│ ├── quality_filter.py # Multi-stage quality filtering
|
| 53 |
+
│ ├── domain_classifier.py # 7-domain classifier
|
| 54 |
+
│ ├── deduplication.py # MinHash LSH deduplication
|
| 55 |
+
│ └── scraper.py # Web scraping (arXiv, PubMed, etc.)
|
| 56 |
+
├── training/
|
| 57 |
+
│ ├── trainer.py # Main training loop
|
| 58 |
+
│ ├── losses.py # Science-aware loss functions
|
| 59 |
+
│ └── curriculum.py # Curriculum learning scheduler
|
| 60 |
+
├── cuda_optimize.py             # CUDA optimizations (Flash Attention, INT8)
├── mps_optimize.py              # MPS optimizations for Apple Silicon
|
| 63 |
+
├── science_bench.py             # Science benchmarks across the 7 domains
|
| 64 |
+
├── configuration_vortex.py # HF config class
|
| 65 |
+
├── tokenization_vortex.py # HF tokenizer wrapper
|
| 66 |
+
├── modeling_vortex.py # HF model integration
|
| 67 |
+
├── train.py # Training entry point
|
| 68 |
+
├── inference.py                 # Inference entry point
|
| 69 |
+
└── requirements.txt
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
## 🚀 Quick Start
|
| 73 |
+
|
| 74 |
+
### Installation
|
| 75 |
+
|
| 76 |
+
```bash
|
| 77 |
+
# Clone and setup
|
| 78 |
+
cd Vortex
|
| 79 |
+
pip install -r requirements.txt
|
| 80 |
+
|
| 81 |
+
# For CUDA optimizations
|
| 82 |
+
pip install flash-attn
|
| 83 |
+
pip install bitsandbytes
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
### Training
|
| 87 |
+
|
| 88 |
+
```bash
|
| 89 |
+
# Train 7B model on CUDA
|
| 90 |
+
python train.py \
|
| 91 |
+
--model_size 7b \
|
| 92 |
+
--device cuda \
|
| 93 |
+
--data_dir ./data/processed \
|
| 94 |
+
--output_dir ./checkpoints \
|
| 95 |
+
--max_steps 100000
|
| 96 |
+
|
| 97 |
+
# Train 13B model with INT8 quantization (for 8GB VRAM)
|
| 98 |
+
python train.py \
|
| 99 |
+
--model_size 13b \
|
| 100 |
+
--device cuda \
|
| 101 |
+
--quantization int8 \
|
| 102 |
+
--data_dir ./data/processed \
|
| 103 |
+
--output_dir ./checkpoints_13b
|
| 104 |
+
```
|
| 105 |
+
|
| 106 |
+
### Inference
|
| 107 |
+
|
| 108 |
+
```bash
|
| 109 |
+
# Generate text with 7B model
|
| 110 |
+
python inference.py \
|
| 111 |
+
--model_path ./checkpoints/latest.pt \
|
| 112 |
+
--model_size 7b \
|
| 113 |
+
--device cuda \
|
| 114 |
+
--prompt "The equation E = mc^2 describes" \
|
| 115 |
+
--max_new_tokens 100
|
| 116 |
+
|
| 117 |
+
# Interactive mode
|
| 118 |
+
python inference.py \
|
| 119 |
+
--model_path ./checkpoints/latest.pt \
|
| 120 |
+
--model_size 7b \
|
| 121 |
+
--device cuda \
|
| 122 |
+
--interactive
|
| 123 |
+
|
| 124 |
+
# On Apple Silicon (MPS)
|
| 125 |
+
python inference.py \
|
| 126 |
+
--model_path ./checkpoints/latest.pt \
|
| 127 |
+
--model_size 7b \
|
| 128 |
+
--use_mps \
|
| 129 |
+
--prompt "Explain quantum mechanics"
|
| 130 |
+
```
|
| 131 |
+
|
| 132 |
+
### HuggingFace Integration
|
| 133 |
+
|
| 134 |
+
```python
|
| 135 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 136 |
+
|
| 137 |
+
# Load model and tokenizer
|
| 138 |
+
model = AutoModelForCausalLM.from_pretrained("./checkpoints")
|
| 139 |
+
tokenizer = AutoTokenizer.from_pretrained("./checkpoints")
|
| 140 |
+
|
| 141 |
+
# Generate
|
| 142 |
+
input_text = "The energy of a photon is given by"
|
| 143 |
+
inputs = tokenizer(input_text, return_tensors="pt")
|
| 144 |
+
outputs = model.generate(**inputs, max_new_tokens=50)
|
| 145 |
+
print(tokenizer.decode(outputs[0]))
|
| 146 |
+
```
|
| 147 |
+
|
| 148 |
+
## 📊 Data Pipeline
|
| 149 |
+
|
| 150 |
+
1. **Open Datasets**: Automatically download from HuggingFace (Pile, S2ORC, Math datasets, PubMed QA)
|
| 151 |
+
2. **Quality Filtering**: Multi-stage checks (length, language, equations, repetition, citations)
|
| 152 |
+
3. **Deduplication**: MinHash LSH for near-duplicate detection
|
| 153 |
+
4. **Domain Classification**: Classify into 7 science domains
|
| 154 |
+
5. **Tokenization**: Custom science-aware BPE tokenizer
|
| 155 |
+
6. **Sharding**: Write to Parquet with statistics
|
| 156 |
+
|
| 157 |
+
```python
|
| 158 |
+
from data.dataset_loader import VortexDatasetLoader
|
| 159 |
+
from data.quality_filter import ScienceQualityFilter
|
| 160 |
+
from data.deduplication import MinHashLSH
|
| 161 |
+
|
| 162 |
+
# Load and process data
|
| 163 |
+
loader = VortexDatasetLoader()
|
| 164 |
+
quality_filter = ScienceQualityFilter()
|
| 165 |
+
lsh = MinHashLSH()
|
| 166 |
+
|
| 167 |
+
# Stream datasets, filter, deduplicate, and shard
|
| 168 |
+
for sample in loader.load_multiple_datasets(["pile_scientific", "automath"]):
|
| 169 |
+
if quality_filter.filter(sample["text"]):
|
| 170 |
+
lsh.add_document(sample["id"], sample["text"])
|
| 171 |
+
# Tokenize and save
|
| 172 |
+
```
|
| 173 |
+
|
| 174 |
+
## 🎯 Training Strategy
|
| 175 |
+
|
| 176 |
+
### Curriculum Learning
|
| 177 |
+
|
| 178 |
+
Training progresses through 4 stages:
|
| 179 |
+
|
| 180 |
+
1. **Foundation** (0-20%): Basic science text, simple equations, definitions
|
| 181 |
+
2. **Domain** (20-50%): Domain-specific deep content per science area
|
| 182 |
+
3. **Reasoning** (50-80%): Scientific problem solving, multi-step derivations
|
| 183 |
+
4. **Integration** (80-100%): Cross-domain science, full dataset
|
| 184 |
+
|
| 185 |
+
### Science-Aware Loss
|
| 186 |
+
|
| 187 |
+
```python
|
| 188 |
+
total_loss = (
|
| 189 |
+
lm_loss * 1.0 # Standard next token prediction
|
| 190 |
+
+ equation_loss * 0.3 # Equation reconstruction accuracy
|
| 191 |
+
+ domain_loss * 0.1 # Domain classification head
|
| 192 |
+
+ citation_loss * 0.1 # Citation detection accuracy
|
| 193 |
+
+ numerical_loss * 0.2 # Numerical reasoning accuracy
|
| 194 |
+
)
|
| 195 |
+
```
|
| 196 |
+
|
| 197 |
+
## ⚙️ Configuration
|
| 198 |
+
|
| 199 |
+
### 7B Config (VORTEX_7B_CONFIG)
|
| 200 |
+
|
| 201 |
+
- `d_model`: 4096
|
| 202 |
+
- `num_layers`: 32
|
| 203 |
+
- `num_heads`: 32
|
| 204 |
+
- `d_state`: 16
|
| 205 |
+
- `ssm_ratio`: 0.6
|
| 206 |
+
- `vocab_size`: 50000
|
| 207 |
+
- `max_seq_len`: 16384
|
| 208 |
+
|
| 209 |
+
### 13B Config (VORTEX_13B_CONFIG)
|
| 210 |
+
|
| 211 |
+
- `d_model`: 5120
|
| 212 |
+
- `num_layers`: 40
|
| 213 |
+
- `num_heads`: 40
|
| 214 |
+
- `d_state`: 32
|
| 215 |
+
- `ssm_ratio`: 0.5
|
| 216 |
+
- `vocab_size`: 50000
|
| 217 |
+
- `max_seq_len`: 16384
|
| 218 |
+
|
| 219 |
+
## 🔧 Hardware Targets
|
| 220 |
+
|
| 221 |
+
### Nvidia 4060 Laptop (8GB VRAM)
|
| 222 |
+
|
| 223 |
+
- **7B**: BF16, no quantization, Flash Attention 2, torch.compile
|
| 224 |
+
- **13B**: INT8 quantization, Flash Attention 2, torch.compile
|
| 225 |
+
- Target TPS: 25-40 (7B), 15-25 (13B)
|
| 226 |
+
|
| 227 |
+
### Apple Silicon (M2/M3)
|
| 228 |
+
|
| 229 |
+
- **7B on M3**: BF16 (via float16), SDPA, no compile
|
| 230 |
+
- **13B on M3 Max**: BF16, unified memory, SDPA
|
| 231 |
+
- Target TPS: 20-35 (7B), 12-20 (13B)
|
| 232 |
+
|
| 233 |
+
## 🧪 Science Domains
|
| 234 |
+
|
| 235 |
+
1. **Physics** (`[PHYS]`)
|
| 236 |
+
2. **Mathematics** (`[MATH]`)
|
| 237 |
+
3. **Chemistry** (`[CHEM]`)
|
| 238 |
+
4. **Biology** (`[BIO]`)
|
| 239 |
+
5. **Earth Science** (`[EARTH]`)
|
| 240 |
+
6. **Space Science** (`[SPACE]`)
|
| 241 |
+
7. **Zoology** (`[ZOO]`)
|
| 242 |
+
|
| 243 |
+
Domain tags can be included in training data to guide the SciGate FFN routing.
|
| 244 |
+
|
| 245 |
+
## 📝 Tokenizer
|
| 246 |
+
|
| 247 |
+
Custom BPE tokenizer with:
|
| 248 |
+
|
| 249 |
+
- 40,000 base BPE tokens trained on scientific corpus
|
| 250 |
+
- 10,000 science-specific tokens:
|
| 251 |
+
- 500 LaTeX math symbols (`\alpha`, `\sum`, `\int`, etc.)
|
| 252 |
+
- 118 chemical element symbols
|
| 253 |
+
- 200 SI and derived units
|
| 254 |
+
- 300 scientific abbreviations (DNA, RNA, ATP, etc.)
|
| 255 |
+
- 500 mathematical operators
|
| 256 |
+
- Amino acid codes
|
| 257 |
+
- Greek alphabet (α, β, γ, etc.)
|
| 258 |
+
- Special tokens: `[EQUATION]`, `[CITATION]`, `[MOLECULE]`, `[FIGURE]`, `[TABLE]`, domain tags
|
| 259 |
+
|
| 260 |
+
## 🧪 Evaluation
|
| 261 |
+
|
| 262 |
+
Science benchmarks across all 7 domains will be added. Planned benchmarks:
|
| 263 |
+
|
| 264 |
+
- **Physics**: Feynman Questions, Physics GRE
|
| 265 |
+
- **Math**: MATH dataset, GSM8K
|
| 266 |
+
- **Chemistry**: Chemistry problem-solving, molecular property prediction
|
| 267 |
+
- **Biology**: PubMed QA, bioinformatics tasks
|
| 268 |
+
- **Earth Science**: Climate modeling questions
|
| 269 |
+
- **Space Science**: Astronomy problem sets
|
| 270 |
+
- **Zoology**: Species classification, ecological reasoning
|
| 271 |
+
|
| 272 |
+
## 📄 License
|
| 273 |
+
|
| 274 |
+
This is a school science project. Code is provided for educational purposes.
|
| 275 |
+
|
| 276 |
+
## 🙏 Acknowledgments
|
| 277 |
+
|
| 278 |
+
- **Mamba** (Gu et al.) for SSM architecture inspiration
|
| 279 |
+
- **Flash Attention** (Dao et al.) for efficient attention
|
| 280 |
+
- **HuggingFace** for transformers library
|
| 281 |
+
- All open scientific data sources: arXiv, PubMed, S2ORC, etc.
|
| 282 |
+
|
| 283 |
+
## 📧 Contact
|
| 284 |
+
|
| 285 |
+
For questions or issues, please open an issue on GitHub.
|
| 286 |
+
|
| 287 |
+
---
|
| 288 |
+
|
| 289 |
+
**Built with ❤️ for scientific AI research**
|
configs/__pycache__/vortex_7b_config.cpython-313.pyc
ADDED
|
Binary file (1.78 kB). View file
|
|
|
configs/training_config.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
Training configuration for Vortex models.
Covers both 7B and 13B variants with hardware-specific optimizations.
"""

import copy

import torch

# Base configuration shared by all hardware variants below.
TRAINING_CONFIG = {
    # Training hyperparameters
    "learning_rate": 3e-4,
    "weight_decay": 0.1,
    "beta1": 0.9,
    "beta2": 0.95,
    "clip_grad_norm": 1.0,

    # Batch sizing
    # NOTE(review): original comment said "tokens per batch"; given the
    # per-GPU micro batch and grad-accum below this looks like *sequences*
    # per global batch — confirm against trainer.py.
    "global_batch_size": 512,
    "micro_batch_size": 8,             # per GPU
    "gradient_accumulation_steps": 4,

    # Training schedule
    "max_steps": 100000,
    "warmup_steps": 2000,
    "save_interval": 5000,
    "eval_interval": 1000,
    "log_interval": 100,

    # Mixed precision
    "use_amp": True,
    "amp_dtype": torch.bfloat16,

    # Optimizer
    "optimizer": "AdamW",
    "use_fused": True,  # fused AdamW if available

    # Curriculum learning stages (as fractions of max_steps)
    "curriculum_stages": [
        {"name": "foundation", "start": 0.0, "end": 0.2},   # 0-20%
        {"name": "domain", "start": 0.2, "end": 0.5},       # 20-50%
        {"name": "reasoning", "start": 0.5, "end": 0.8},    # 50-80%
        {"name": "integration", "start": 0.8, "end": 1.0},  # 80-100%
    ],

    # Loss weights (science-aware loss)
    "loss_weights": {
        "lm_loss": 1.0,
        "equation_loss": 0.3,
        "domain_loss": 0.1,
        "citation_loss": 0.1,
        "numerical_loss": 0.2,
    },

    # Checkpointing
    "checkpoint_dir": "checkpoints",
    "save_optimizer_state": True,
    "save_scheduler_state": True,

    # Logging
    "log_dir": "logs",
    "use_wandb": False,
    "wandb_project": "vortex-scientific",

    # Data loading
    "num_workers": 8,
    "prefetch_factor": 2,
    "pin_memory": True,

    # Device configuration
    "device": "cuda",  # or "mps" for Apple Silicon
    "use_mps": False,

    # Quantization (for 13B on 8GB VRAM)
    "quantization": None,  # None, "int8", "int4"
}

# Hardware-specific overrides.
#
# BUG FIX: these previously used TRAINING_CONFIG.copy(), a *shallow* copy,
# so the nested mutables ("curriculum_stages", "loss_weights") were shared
# between the base config and every variant — mutating one variant's nested
# dict would silently change all of them. deepcopy gives each variant fully
# independent nested structures.
TRAINING_CONFIG_7B_CUDA = copy.deepcopy(TRAINING_CONFIG)
TRAINING_CONFIG_7B_CUDA.update({
    "device": "cuda",
    "quantization": None,
    "micro_batch_size": 8,
})

TRAINING_CONFIG_13B_CUDA = copy.deepcopy(TRAINING_CONFIG)
TRAINING_CONFIG_13B_CUDA.update({
    "device": "cuda",
    "quantization": "int8",  # 13B needs INT8 on 8GB
    "micro_batch_size": 4,
})

TRAINING_CONFIG_MPS = copy.deepcopy(TRAINING_CONFIG)
TRAINING_CONFIG_MPS.update({
    "device": "mps",
    "use_mps": True,
    "use_amp": False,  # MPS doesn't support bfloat16 AMP well
    "micro_batch_size": 4,
})
|
configs/vortex_13b_config.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
Vortex-13B model configuration.
Optimized for 16GB VRAM (4060 Ti laptop) and MacBook Pro M3 Max.
"""

# Special tokens in id order: the index of each entry is its token id (0..18).
# Ids match the 7B model so checkpoints share a vocabulary layout.
_SPECIAL_TOKEN_ORDER = [
    "[PAD]", "[UNK]", "[BOS]", "[EOS]",
    "[EQUATION]", "[/EQUATION]",
    "[CITATION]", "[/CITATION]",
    "[MOLECULE]", "[/MOLECULE]",
    "[FIGURE]", "[TABLE]",
    "[MATH]", "[CHEM]", "[BIO]", "[PHYS]", "[EARTH]", "[SPACE]", "[ZOO]",
]

VORTEX_13B_CONFIG = {
    # --- Model dimensions ---
    "d_model": 5120,
    "num_layers": 40,
    "num_heads": 40,
    "head_dim": 128,  # = d_model // num_heads

    # --- State-space (SSM) layer ---
    "d_state": 32,  # larger SSM state for the bigger model
    "d_conv": 4,    # SSM convolution width

    # --- Local attention ---
    "window_size": 512,  # local attention window
    "use_flash_attention": True,

    # --- Feed-forward / vocabulary ---
    "ffn_expansion": 4,
    "num_domains": 7,
    "vocab_size": 50000,
    "max_seq_len": 16384,

    # Half the layers are SSM, half attention (more memory for attention).
    "ssm_ratio": 0.5,

    # --- Data type ---
    "dtype": "bfloat16",

    # Special-token ids derived from the ordered list above (same as 7B).
    "special_tokens": {tok: idx for idx, tok in enumerate(_SPECIAL_TOKEN_ORDER)},

    "domain_tags": ["[MATH]", "[CHEM]", "[BIO]", "[PHYS]", "[EARTH]", "[SPACE]", "[ZOO]"],

    # Science-module toggles.
    "enable_equation_module": True,
    "enable_numerical_module": True,
    "enable_citation_module": True,
    "enable_molecular_module": True,
}


def get_config():
    """Return the shared Vortex-13B configuration dictionary."""
    return VORTEX_13B_CONFIG
|
configs/vortex_7b_config.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
Vortex-7B model configuration.
Optimized for 8GB VRAM (4060 laptop) and MacBook Pro M2/M3.
"""

# Token ids are positional: _TOKENS[i] gets id i (0..18).
_TOKENS = (
    "[PAD]", "[UNK]", "[BOS]", "[EOS]",
    "[EQUATION]", "[/EQUATION]",
    "[CITATION]", "[/CITATION]",
    "[MOLECULE]", "[/MOLECULE]",
    "[FIGURE]", "[TABLE]",
    "[MATH]", "[CHEM]", "[BIO]", "[PHYS]", "[EARTH]", "[SPACE]", "[ZOO]",
)

VORTEX_7B_CONFIG = {
    # Core model dimensions.
    "d_model": 4096,
    "num_layers": 32,
    "num_heads": 32,
    "head_dim": 128,  # = d_model // num_heads

    # State-space (SSM) layer settings.
    "d_state": 16,  # SSM state dimension
    "d_conv": 4,    # SSM convolution width

    # Local windowed attention.
    "window_size": 512,
    "use_flash_attention": True,  # CUDA only

    # Feed-forward network.
    "ffn_expansion": 4,  # hidden dim = d_model * expansion
    "num_domains": 7,    # Physics, Math, Chemistry, Biology, Earth, Space, Zoology

    # Tokenizer parameters.
    "vocab_size": 50000,
    "max_seq_len": 16384,

    # 60% of layers are SSM, the remaining 40% attention.
    "ssm_ratio": 0.6,

    # Data type.
    "dtype": "bfloat16",

    # Special-token ids derived positionally from _TOKENS.
    "special_tokens": dict(zip(_TOKENS, range(len(_TOKENS)))),

    # Domain tags.
    "domain_tags": ["[MATH]", "[CHEM]", "[BIO]", "[PHYS]", "[EARTH]", "[SPACE]", "[ZOO]"],

    # Science-module toggles (flip off for ablation studies).
    "enable_equation_module": True,
    "enable_numerical_module": True,
    "enable_citation_module": True,
    "enable_molecular_module": True,
}


def get_config():
    """Return the shared Vortex-7B configuration dictionary."""
    return VORTEX_7B_CONFIG
|
configuration_vortex.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Vortex configuration for HuggingFace.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import Optional, List, Dict, Any
|
| 6 |
+
from transformers import PretrainedConfig
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class VortexConfig(PretrainedConfig):
    """
    Configuration class for the Vortex model, compatible with HuggingFace
    ``transformers``.

    Holds the architectural hyperparameters (model dimensions, SSM and
    local-attention settings, science-module flags) together with the
    special-token and domain-tag vocabularies.
    """

    model_type = "vortex"
    tie_word_embeddings = True

    def __init__(
        self,
        d_model: int = 4096,
        num_layers: int = 32,
        num_heads: int = 32,
        d_state: int = 16,
        d_conv: int = 4,
        window_size: int = 512,
        ffn_expansion: int = 4,
        num_domains: int = 7,
        vocab_size: int = 50000,
        max_seq_len: int = 16384,
        ssm_ratio: float = 0.6,
        enable_equation_module: bool = True,
        enable_numerical_module: bool = True,
        enable_citation_module: bool = True,
        enable_molecular_module: bool = True,
        special_tokens: Optional[Dict[str, int]] = None,
        domain_tags: Optional[List[str]] = None,
        initializer_range: float = 0.02,
        tie_word_embeddings: bool = True,
        **kwargs
    ):
        """
        Args:
            d_model: Hidden size of the model.
            num_layers: Total number of layers (SSM + attention mixed).
            num_heads: Attention heads; ``head_dim`` is derived as
                ``d_model // num_heads``.
            d_state: SSM state dimension.
            d_conv: SSM convolution width.
            window_size: Local attention window.
            ffn_expansion: FFN hidden dim multiplier.
            num_domains: Number of science domains for the gated FFN.
            vocab_size: Tokenizer vocabulary size.
            max_seq_len: Maximum sequence length.
            ssm_ratio: Fraction of layers that are SSM layers.
            enable_*_module: Toggles for the science modules.
            special_tokens: Token -> id map; defaults to the shared Vortex map.
            domain_tags: Domain tag strings; defaults to the 7 science domains.
            initializer_range: Std-dev for weight initialization.
            tie_word_embeddings: Share input/output embedding weights.
        """
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
        self.d_model = d_model
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.d_state = d_state
        self.d_conv = d_conv
        self.window_size = window_size
        self.ffn_expansion = ffn_expansion
        self.num_domains = num_domains
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.ssm_ratio = ssm_ratio
        self.enable_equation_module = enable_equation_module
        self.enable_numerical_module = enable_numerical_module
        self.enable_citation_module = enable_citation_module
        self.enable_molecular_module = enable_molecular_module
        self.special_tokens = special_tokens or {
            "[PAD]": 0, "[UNK]": 1, "[BOS]": 2, "[EOS]": 3,
            "[EQUATION]": 4, "[/EQUATION]": 5,
            "[CITATION]": 6, "[/CITATION]": 7,
            "[MOLECULE]": 8, "[/MOLECULE]": 9,
            "[FIGURE]": 10, "[TABLE]": 11,
            "[MATH]": 12, "[CHEM]": 13, "[BIO]": 14,
            "[PHYS]": 15, "[EARTH]": 16, "[SPACE]": 17, "[ZOO]": 18,
        }
        self.domain_tags = domain_tags or ["[MATH]", "[CHEM]", "[BIO]", "[PHYS]", "[EARTH]", "[SPACE]", "[ZOO]"]
        self.initializer_range = initializer_range
        # head_dim is derived, not a constructor argument. to_dict() still
        # emits it; on reload it is absorbed by **kwargs and then recomputed
        # (and thus overwritten) here.
        self.head_dim = self.d_model // self.num_heads

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """
        Load a config from a local directory containing ``config.json``.

        NOTE(review): unlike the base ``PretrainedConfig.from_pretrained``,
        this override only handles local paths, and silently falls back to a
        default config when ``config.json`` is absent — a mistyped path will
        not raise. Behavior kept as-is; confirm this is intentional.
        """
        import json
        import os

        config_path = os.path.join(pretrained_model_name_or_path, "config.json")
        if os.path.exists(config_path):
            with open(config_path, "r") as f:
                config_dict = json.load(f)
            # Explicit kwargs win over values read from disk.
            config_dict.update(kwargs)
            return cls(**config_dict)
        else:
            # Return default config
            return cls(**kwargs)

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize the config to a plain dictionary (used for config.json).

        BUG FIX: ``tie_word_embeddings`` was previously omitted, so a config
        saved with ``tie_word_embeddings=False`` would silently reload with
        the default ``True``.
        """
        return {
            "model_type": self.model_type,
            "d_model": self.d_model,
            "num_layers": self.num_layers,
            "num_heads": self.num_heads,
            "head_dim": self.head_dim,
            "d_state": self.d_state,
            "d_conv": self.d_conv,
            "window_size": self.window_size,
            "ffn_expansion": self.ffn_expansion,
            "num_domains": self.num_domains,
            "vocab_size": self.vocab_size,
            "max_seq_len": self.max_seq_len,
            "ssm_ratio": self.ssm_ratio,
            "enable_equation_module": self.enable_equation_module,
            "enable_numerical_module": self.enable_numerical_module,
            "enable_citation_module": self.enable_citation_module,
            "enable_molecular_module": self.enable_molecular_module,
            "special_tokens": self.special_tokens,
            "domain_tags": self.domain_tags,
            "initializer_range": self.initializer_range,
            # getattr with a default in case the base class did not record it.
            "tie_word_embeddings": getattr(self, "tie_word_embeddings", True),
        }
|
cuda_optimize.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
CUDA optimizations for Vortex model on Nvidia 4060 laptop.
|
| 3 |
+
Flash Attention 2, torch.compile, INT8 quantization.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from typing import Optional, Dict, Any
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def optimize_for_cuda(
    model: nn.Module,
    config: Dict,
    use_flash_attention: bool = True,
    use_torch_compile: bool = True,
    compile_mode: str = "reduce-overhead",
    quantization: Optional[str] = None,
) -> nn.Module:
    """
    Apply CUDA optimizations to model.

    Args:
        model: VortexModel
        config: Model config; reads the "dtype" key ("bfloat16", "float16",
            anything else falls back to float32)
        use_flash_attention: Enable Flash Attention 2
        use_torch_compile: Use torch.compile
        compile_mode: Compile mode ("reduce-overhead", "max-autotune")
        quantization: None, "int8", or "int4"

    Returns:
        Optimized model, moved to CUDA and cast to the configured dtype.

    Raises:
        RuntimeError: If no CUDA device is available.
    """
    # Fail fast with a clear message instead of an obscure error from .to().
    if not torch.cuda.is_available():
        raise RuntimeError("optimize_for_cuda requires a CUDA-capable device")

    device = torch.device("cuda")

    # Resolve dtype first so device move and cast happen in one transfer.
    dtype_str = config.get("dtype", "bfloat16")
    if dtype_str == "bfloat16":
        dtype = torch.bfloat16
    elif dtype_str == "float16":
        dtype = torch.float16
    else:
        dtype = torch.float32

    model = model.to(device=device, dtype=dtype)

    # Apply Flash Attention 2 to attention layers
    if use_flash_attention:
        model = _apply_flash_attention(model)
        print("Applied Flash Attention 2")

    # Quantize BEFORE torch.compile: swapping submodules after compilation
    # would not be reflected in the already-captured graph.
    if quantization == "int8":
        model = _apply_int8_quantization(model)
        print("Applied INT8 quantization")
    elif quantization == "int4":
        model = _apply_int4_quantization(model)
        print("Applied INT4 quantization")

    # Apply torch.compile last, over the final module structure.
    if use_torch_compile:
        model = torch.compile(
            model,
            mode=compile_mode,
            fullgraph=True,
            dynamic=True,
        )
        print(f"Applied torch.compile with mode={compile_mode}")

    return model
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def _apply_flash_attention(model: nn.Module) -> nn.Module:
|
| 76 |
+
"""
|
| 77 |
+
Replace standard attention with Flash Attention 2.
|
| 78 |
+
Requires: pip install flash-attn
|
| 79 |
+
"""
|
| 80 |
+
try:
|
| 81 |
+
from flash_attn import flash_attn_func
|
| 82 |
+
|
| 83 |
+
# Monkey-patch attention layers to use flash attention
|
| 84 |
+
for name, module in model.named_modules():
|
| 85 |
+
if hasattr(module, 'use_flash_attention'):
|
| 86 |
+
module.use_flash_attention = True
|
| 87 |
+
# Replace forward with flash attention version
|
| 88 |
+
original_forward = module.forward
|
| 89 |
+
|
| 90 |
+
def flash_forward(self, x, *args, **kwargs):
|
| 91 |
+
return self._flash_attention_forward(x, *args, **kwargs)
|
| 92 |
+
|
| 93 |
+
module.forward = flash_forward.__get__(module, type(module))
|
| 94 |
+
|
| 95 |
+
return model
|
| 96 |
+
|
| 97 |
+
except ImportError:
|
| 98 |
+
print("Flash Attention not available. Install with: pip install flash-attn")
|
| 99 |
+
return model
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def _apply_int8_quantization(model: nn.Module) -> nn.Module:
|
| 103 |
+
"""
|
| 104 |
+
Apply INT8 quantization using bitsandbytes.
|
| 105 |
+
"""
|
| 106 |
+
try:
|
| 107 |
+
import bitsandbytes as bnb
|
| 108 |
+
|
| 109 |
+
# Replace linear layers with 8-bit variants
|
| 110 |
+
for name, module in model.named_modules():
|
| 111 |
+
if isinstance(module, nn.Linear):
|
| 112 |
+
# Create 8-bit linear replacement
|
| 113 |
+
parent_name = name.rsplit('.', 1)[0] if '.' in name else ''
|
| 114 |
+
child_name = name.rsplit('.', 1)[1] if '.' in name else name
|
| 115 |
+
|
| 116 |
+
# Get parent module
|
| 117 |
+
parent = model
|
| 118 |
+
if parent_name:
|
| 119 |
+
for part in parent_name.split('.'):
|
| 120 |
+
parent = getattr(parent, part)
|
| 121 |
+
|
| 122 |
+
# Replace with 8-bit linear
|
| 123 |
+
replacement = bnb.nn.Linear8bitLt(
|
| 124 |
+
module.in_features,
|
| 125 |
+
module.out_features,
|
| 126 |
+
bias=module.bias is not None,
|
| 127 |
+
has_fp16_weights=False,
|
| 128 |
+
)
|
| 129 |
+
# Copy weights (will be quantized)
|
| 130 |
+
replacement.weight.data = module.weight.data
|
| 131 |
+
if module.bias is not None:
|
| 132 |
+
replacement.bias.data = module.bias.data
|
| 133 |
+
|
| 134 |
+
setattr(parent, child_name, replacement)
|
| 135 |
+
|
| 136 |
+
return model
|
| 137 |
+
|
| 138 |
+
except ImportError:
|
| 139 |
+
print("bitsandbytes not available. Install with: pip install bitsandbytes")
|
| 140 |
+
return model
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def _apply_int4_quantization(model: nn.Module) -> nn.Module:
|
| 144 |
+
"""
|
| 145 |
+
Apply INT4 quantization using bitsandbytes.
|
| 146 |
+
More aggressive, for 13B on 8GB VRAM.
|
| 147 |
+
"""
|
| 148 |
+
try:
|
| 149 |
+
import bitsandbytes as bnb
|
| 150 |
+
|
| 151 |
+
for name, module in model.named_modules():
|
| 152 |
+
if isinstance(module, nn.Linear):
|
| 153 |
+
parent_name = name.rsplit('.', 1)[0] if '.' in name else ''
|
| 154 |
+
child_name = name.rsplit('.', 1)[1] if '.' in name else name
|
| 155 |
+
|
| 156 |
+
parent = model
|
| 157 |
+
if parent_name:
|
| 158 |
+
for part in parent_name.split('.'):
|
| 159 |
+
parent = getattr(parent, part)
|
| 160 |
+
|
| 161 |
+
# 4-bit linear
|
| 162 |
+
replacement = bnb.nn.Linear4bit(
|
| 163 |
+
module.in_features,
|
| 164 |
+
module.out_features,
|
| 165 |
+
bias=module.bias is not None,
|
| 166 |
+
compute_dtype=torch.float16,
|
| 167 |
+
compress_statistics=True,
|
| 168 |
+
)
|
| 169 |
+
replacement.weight.data = module.weight.data
|
| 170 |
+
if module.bias is not None:
|
| 171 |
+
replacement.bias.data = module.bias.data
|
| 172 |
+
|
| 173 |
+
setattr(parent, child_name, replacement)
|
| 174 |
+
|
| 175 |
+
return model
|
| 176 |
+
|
| 177 |
+
except ImportError:
|
| 178 |
+
print("bitsandbytes not available.")
|
| 179 |
+
return model
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def get_cuda_memory_usage() -> Dict[str, float]:
    """Report current CUDA memory usage in gigabytes.

    Returns a dict with "allocated_gb", "reserved_gb", and
    "max_allocated_gb", or {"error": ...} when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return {"error": "CUDA not available"}

    bytes_per_gb = 1e9
    return {
        "allocated_gb": torch.cuda.memory_allocated() / bytes_per_gb,
        "reserved_gb": torch.cuda.memory_reserved() / bytes_per_gb,
        "max_allocated_gb": torch.cuda.max_memory_allocated() / bytes_per_gb,
    }
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def profile_model(
    model: nn.Module,
    input_ids: torch.Tensor,
    num_warmup: int = 10,
    num_runs: int = 100,
) -> Dict[str, float]:
    """
    Profile model forward-pass performance.

    Args:
        model: Model to profile
        input_ids: Example input of shape (batch, seq_len)
        num_warmup: Number of warmup runs (excluded from timing)
        num_runs: Number of timed runs

    Returns:
        Dictionary with "avg_time_sec" and "tokens_per_sec"
    """
    import time

    model.eval()
    device = next(model.parameters()).device
    input_ids = input_ids.to(device)
    on_cuda = device.type == "cuda"

    # Warmup: triggers lazy compilation / autotuning outside the timed loop.
    with torch.no_grad():
        for _ in range(num_warmup):
            _ = model(input_ids)

    # Only synchronize for CUDA models: torch.cuda.synchronize() raises on
    # CPU-only builds, and is meaningless for CPU models anyway.
    if on_cuda:
        torch.cuda.synchronize()
    start = time.time()

    with torch.no_grad():
        for _ in range(num_runs):
            _ = model(input_ids)

    if on_cuda:
        torch.cuda.synchronize()
    elapsed = time.time() - start

    avg_time = elapsed / num_runs
    # Throughput counts input tokens per forward pass.
    tokens_per_sec = input_ids.shape[1] / avg_time

    return {
        "avg_time_sec": avg_time,
        "tokens_per_sec": tokens_per_sec,
    }
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def test_cuda_optimize():
    """Smoke-test the CUDA optimization pipeline on a tiny model."""
    if not torch.cuda.is_available():
        print("CUDA not available, skipping test")
        return

    from models.vortex_model import VortexModel
    from configs.vortex_7b_config import VORTEX_7B_CONFIG

    # Shrink the 7B config down to a toy model so the test is fast.
    config = VORTEX_7B_CONFIG.copy()
    config.update(d_model=512, num_layers=2, num_heads=8, vocab_size=1000)

    model = VortexModel(config)
    print(f"Model parameters: {model.get_num_params():,}")

    # Optimize with the optional pieces disabled.
    model = optimize_for_cuda(
        model,
        config,
        use_flash_attention=False,  # May not be available
        use_torch_compile=False,  # Skip compile for test
        quantization=None,
    )

    # Run a single forward pass to verify the optimized model still works.
    input_ids = torch.randint(0, config["vocab_size"], (2, 128)).cuda()
    with torch.no_grad():
        logits = model(input_ids)["logits"]

    print(f"Output shape: {logits.shape}")
    print("CUDA optimize test passed!")
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
# Run the CUDA optimization smoke test when this module is executed directly.
if __name__ == "__main__":
    test_cuda_optimize()
|
data/__pycache__/deduplication.cpython-313.pyc
ADDED
|
Binary file (10.1 kB). View file
|
|
|
data/__pycache__/domain_classifier.cpython-313.pyc
ADDED
|
Binary file (6.38 kB). View file
|
|
|
data/__pycache__/quality_filter.cpython-313.pyc
ADDED
|
Binary file (11.3 kB). View file
|
|
|
data/dataset_loader.py
ADDED
|
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
DatasetLoader: Loads and processes open scientific datasets.
|
| 3 |
+
Supports streaming from HuggingFace datasets with sharding.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import json
|
| 8 |
+
from typing import List, Dict, Optional, Iterator
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
from datasets import load_dataset, Dataset, IterableDataset
|
| 13 |
+
import pyarrow.parquet as pq
|
| 14 |
+
except ImportError:
|
| 15 |
+
print("Please install datasets and pyarrow: pip install datasets pyarrow")
|
| 16 |
+
raise
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class VortexDatasetLoader:
    """
    Loads and processes open scientific datasets.

    Streams records from the HuggingFace hub and can shard them to local
    Parquet files.
    """

    # Open datasets configuration. Each entry: HF hub path, optional
    # config/subset name, split to read, the record field holding the
    # document text, and a coarse science-domain tag.
    DATASETS = {
        "pile_scientific": {
            "path": "EleutherAI/pile",
            "subset": "pubmed_central",
            "split": "train",
            "text_field": "text",
            "domain": "biology",  # approximate
        },
        "s2orc": {
            "path": "allenai/s2orc",
            "subset": None,
            "split": "train",
            "text_field": "text",
            "domain": "multidisciplinary",
        },
        "pes2o": {
            "path": "allenai/peS2o",
            "subset": None,
            "split": "train",
            "text_field": "text",
            "domain": "multidisciplinary",
        },
        "automath": {
            "path": "math-ai/AutoMathText",
            "subset": None,
            "split": "train",
            "text_field": "text",
            "domain": "math",
        },
        "deepmind_math": {
            "path": "deepmind/math_dataset",
            "subset": "algebra__linear_1d",
            "split": "train",
            "text_field": "text",
            "domain": "math",
        },
        "pubmed_qa": {
            "path": "bigbio/pubmed_qa",
            "subset": "pubmed_qa_labeled_fold0_source",
            "split": "train",
            "text_field": "question",
            "domain": "biology",
        },
    }

    def __init__(
        self,
        cache_dir: str = "./data/cache",
        output_dir: str = "./data/processed",
        num_proc: int = 4,
    ):
        """
        Initialize dataset loader.

        Args:
            cache_dir: Directory for caching downloaded datasets
            output_dir: Directory for processed shards
            num_proc: Number of processes for data processing
        """
        self.cache_dir = Path(cache_dir)
        self.output_dir = Path(output_dir)
        self.num_proc = num_proc

        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def load_dataset(
        self,
        dataset_name: str,
        streaming: bool = True,
        max_samples: Optional[int] = None,
    ) -> Iterator[Dict]:
        """
        Load a dataset as an iterator.

        This is a generator: validation and download errors surface only
        once iteration begins.

        Args:
            dataset_name: Name from DATASETS config
            streaming: Use streaming mode for large datasets
            max_samples: Maximum number of samples to yield

        Yields:
            Dictionary with text and metadata

        Raises:
            ValueError: If dataset_name is not a key of DATASETS.
        """
        if dataset_name not in self.DATASETS:
            raise ValueError(f"Unknown dataset: {dataset_name}. Available: {list(self.DATASETS.keys())}")

        config = self.DATASETS[dataset_name]

        print(f"Loading dataset: {dataset_name}")
        print(f"  Path: {config['path']}")
        print(f"  Streaming: {streaming}")

        try:
            dataset = load_dataset(
                config["path"],
                name=config["subset"],
                split=config["split"],
                streaming=streaming,
                cache_dir=str(self.cache_dir),
            )

            count = 0
            for sample in dataset:
                text = sample.get(config["text_field"], "")
                # Skip records with missing or non-string text.
                if not text or not isinstance(text, str):
                    continue

                yield {
                    "text": text,
                    "dataset": dataset_name,
                    "domain": config["domain"],
                    "source": config["path"],
                }

                count += 1
                if max_samples and count >= max_samples:
                    break

            print(f"Loaded {count} samples from {dataset_name}")

        except Exception as e:
            # Best-effort: report and yield nothing rather than abort a
            # multi-dataset pipeline on one failing source.
            print(f"Error loading dataset {dataset_name}: {e}")
            return

    def load_multiple_datasets(
        self,
        dataset_names: List[str],
        streaming: bool = True,
        max_per_dataset: Optional[int] = None,
    ) -> Iterator[Dict]:
        """
        Load multiple datasets and yield samples interleaved round-robin.

        Args:
            dataset_names: List of dataset names
            streaming: Use streaming mode
            max_per_dataset: Max samples per dataset

        Yields:
            Dictionary with text and metadata
        """
        iterators = [
            self.load_dataset(name, streaming=streaming, max_samples=max_per_dataset)
            for name in dataset_names
        ]

        # Round-robin over the live iterators; an exhausted iterator is
        # replaced with None and the sweep restarts from the front.
        active = len(iterators)
        while active > 0:
            for i, it in enumerate(iterators):
                if it is None:
                    continue
                try:
                    yield next(it)
                except StopIteration:
                    iterators[i] = None
                    active -= 1
                    break

    def shard_to_parquet(
        self,
        samples: Iterator[Dict],
        output_prefix: str,
        samples_per_shard: int = 10000,
    ):
        """
        Write samples to sharded Parquet files.

        Args:
            samples: Iterator of sample dictionaries
            output_prefix: Prefix for output files (e.g., "train")
            samples_per_shard: Number of samples per shard
        """
        shard_index = 0
        buffer = []

        for sample in samples:
            buffer.append(sample)

            if len(buffer) >= samples_per_shard:
                self._write_shard(buffer, output_prefix, shard_index)
                shard_index += 1
                buffer = []

        # Flush the final partial shard, if any.
        if buffer:
            self._write_shard(buffer, output_prefix, shard_index)
            shard_index += 1

        # shard_index now equals the number of shards written. (The previous
        # version printed shard_index + 1, over-counting by one whenever the
        # sample count was an exact multiple of samples_per_shard.)
        print(f"Wrote {shard_index} shards to {self.output_dir}")

    def _write_shard(
        self,
        buffer: List[Dict],
        output_prefix: str,
        shard_index: int,
    ):
        """Write a single shard to Parquet as {prefix}_{index:05d}.parquet."""
        import pandas as pd

        df = pd.DataFrame(buffer)
        output_path = self.output_dir / f"{output_prefix}_{shard_index:05d}.parquet"
        df.to_parquet(output_path, index=False)

    def get_shard_list(
        self,
        prefix: str,
    ) -> List[Path]:
        """Get the sorted list of shard files matching prefix."""
        return sorted(self.output_dir.glob(f"{prefix}_*.parquet"))

    def read_shard(
        self,
        shard_path: Path,
    ) -> List[Dict]:
        """Read a single shard back into a list of record dicts."""
        import pandas as pd
        df = pd.read_parquet(shard_path)
        return df.to_dict('records')
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def test_dataset_loader():
    """Smoke-test the dataset loader by streaming a few samples."""
    loader = VortexDatasetLoader()

    print("Testing dataset loader...")
    total = 0
    # Stream a handful of records from a small public dataset.
    for index, sample in enumerate(loader.load_dataset("pubmed_qa", streaming=True, max_samples=10)):
        preview = sample['text'][:100]
        print(f"Sample {index}: {preview}...")
        total = index + 1

    print(f"Loaded {total} samples")
    print("DatasetLoader test passed!")
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
# Run the dataset loader smoke test when this module is executed directly.
if __name__ == "__main__":
    test_dataset_loader()
|
data/deduplication.py
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Deduplication: MinHash LSH for near-duplicate detection.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import hashlib
|
| 6 |
+
import random
|
| 7 |
+
from typing import List, Set, Tuple, Optional
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
class MinHashSignature:
    """MinHash signature for a document."""
    # Per-permutation minimum hash values; length equals num_permutations.
    hash_values: List[int]
    # ID of the source document ("" when computed detached from the index).
    doc_id: str
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class MinHashLSH:
    """
    MinHash LSH for near-duplicate detection.

    Uses character shingling and MinHash to estimate Jaccard similarity,
    with banded LSH buckets for sub-linear candidate retrieval.
    """

    def __init__(
        self,
        num_permutations: int = 128,
        threshold: float = 0.8,
        bands: int = 16,
        rows_per_band: int = 8,
    ):
        """
        Initialize MinHash LSH.

        Args:
            num_permutations: Number of hash permutations for MinHash
            threshold: Similarity threshold for considering duplicates
            bands: Number of bands for LSH
            rows_per_band: Rows per band (bands * rows_per_band = num_permutations)
        """
        self.num_permutations = num_permutations
        self.threshold = threshold
        self.bands = bands
        self.rows_per_band = rows_per_band

        # Banding only works if the bands exactly tile the signature.
        assert bands * rows_per_band == num_permutations

        # Generate random hash functions (NOTE: seeded by the global `random`
        # state, so signatures are not comparable across instances).
        self.hash_functions = self._generate_hash_functions(num_permutations)

        # LSH index: band_id -> {bucket_hash -> set of doc_ids}
        self.index = [dict() for _ in range(bands)]

        # Store signatures for similarity computation: doc_id -> MinHashSignature
        self.signatures = {}

    def _generate_hash_functions(self, n: int) -> List:
        """Generate n random (a, b) coefficient pairs for universal hashing."""
        functions = []
        for _ in range(n):
            a = random.randint(1, 2**32 - 1)
            b = random.randint(0, 2**32 - 1)
            functions.append((a, b))
        return functions

    def _hash(self, x: int, a: int, b: int) -> int:
        """Universal hash: (a*x + b) mod Mersenne prime, folded to 32 bits."""
        prime = 2**61 - 1
        return ((a * x + b) % prime) & 0xFFFFFFFF

    def _compute_minhash(self, shingles: Set[int]) -> List[int]:
        """
        Compute MinHash signature for a set of shingles.

        Args:
            shingles: Non-empty set of shingle hash values

        Returns:
            List of minhash values (one per permutation)
        """
        signature = []
        for a, b in self.hash_functions:
            min_hash = min(self._hash(shingle, a, b) for shingle in shingles)
            signature.append(min_hash)
        return signature

    def _shingle_text(
        self,
        text: str,
        k: int = 5,
    ) -> Set[int]:
        """
        Extract k-gram character shingles from text.

        Args:
            text: Input text
            k: Shingle size (characters)

        Returns:
            Non-empty set of shingle hashes.
        """
        text = text.lower()
        shingles = set()
        for i in range(len(text) - k + 1):
            shingle = text[i:i+k]
            # Fold the md5 digest down to a 32-bit shingle hash.
            shingle_hash = int(hashlib.md5(shingle.encode()).hexdigest()[:8], 16)
            shingles.add(shingle_hash)

        if not shingles:
            # Text shorter than k yields no k-grams; hash the whole text so
            # _compute_minhash never calls min() on an empty set (which
            # previously raised ValueError for short documents).
            shingles.add(int(hashlib.md5(text.encode()).hexdigest()[:8], 16))

        return shingles

    def add_document(
        self,
        doc_id: str,
        text: str,
        compute_signature: bool = True,
    ) -> MinHashSignature:
        """
        Add a document to the LSH index.

        Args:
            doc_id: Unique document ID
            text: Document text
            compute_signature: Whether to compute signature (must be True;
                precomputed signatures are not supported yet)

        Returns:
            MinHash signature

        Raises:
            ValueError: If compute_signature is False.
        """
        if compute_signature:
            shingles = self._shingle_text(text)
            signature = self._compute_minhash(shingles)
        else:
            raise ValueError("Must compute signature")

        # Store signature for later similarity estimation.
        signature_obj = MinHashSignature(hash_values=signature, doc_id=doc_id)
        self.signatures[doc_id] = signature_obj

        # Index each band's sub-signature into its bucket.
        for band_idx in range(self.bands):
            start = band_idx * self.rows_per_band
            end = start + self.rows_per_band
            band_signature = tuple(signature[start:end])
            bucket_hash = hash(band_signature)

            if bucket_hash not in self.index[band_idx]:
                self.index[band_idx][bucket_hash] = set()
            self.index[band_idx][bucket_hash].add(doc_id)

        return signature_obj

    def query(
        self,
        text: str,
        candidate_doc_ids: Optional[Set[str]] = None,
    ) -> List[Tuple[str, float]]:
        """
        Query for near-duplicate documents.

        Args:
            text: Query text
            candidate_doc_ids: Optional set of candidate doc IDs to check

        Returns:
            List of (doc_id, similarity) above threshold, most similar first.
        """
        shingles = self._shingle_text(text)
        query_signature = self._compute_minhash(shingles)

        # Gather LSH candidates: any document sharing at least one band bucket.
        candidates = set()
        for band_idx in range(self.bands):
            start = band_idx * self.rows_per_band
            end = start + self.rows_per_band
            band_signature = tuple(query_signature[start:end])
            bucket_hash = hash(band_signature)

            if bucket_hash in self.index[band_idx]:
                candidates.update(self.index[band_idx][bucket_hash])

        if candidate_doc_ids is not None:
            candidates = candidates.intersection(candidate_doc_ids)

        # Score candidates via MinHash similarity (exact Jaccard would need
        # stored shingles; signatures give an unbiased estimate).
        results = []
        for doc_id in candidates:
            similarity = self._estimate_similarity(query_signature, doc_id)
            if similarity >= self.threshold:
                results.append((doc_id, similarity))

        return sorted(results, key=lambda x: x[1], reverse=True)

    def _estimate_similarity(
        self,
        signature1: List[int],
        doc_id2: str,
    ) -> float:
        """
        Estimate Jaccard similarity between a signature and a stored document.

        Uses MinHash similarity: proportion of matching hash values.

        Args:
            signature1: First MinHash signature
            doc_id2: Second document ID (signature retrieved from storage)

        Returns:
            Estimated Jaccard similarity (0.0 if doc_id2 is unknown).
        """
        if doc_id2 not in self.signatures:
            return 0.0

        signature2 = self.signatures[doc_id2].hash_values

        # Proportion of positions where the two signatures agree.
        matches = sum(1 for h1, h2 in zip(signature1, signature2) if h1 == h2)
        similarity = matches / len(signature1)

        return similarity

    def compute_signature(self, text: str) -> MinHashSignature:
        """Compute a detached MinHash signature for text (not indexed)."""
        shingles = self._shingle_text(text)
        signature = self._compute_minhash(shingles)
        return MinHashSignature(hash_values=signature, doc_id="")
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def test_deduplication():
    """Test MinHash LSH."""
    lsh = MinHashLSH(num_permutations=64, threshold=0.7, bands=8, rows_per_band=8)

    # A small corpus: two near-duplicates of doc1 plus one unrelated text.
    corpus = [
        ("doc1", "The quick brown fox jumps over the lazy dog."),
        ("doc2", "The quick brown fox jumps over the lazy dog!!!"),  # near duplicate
        ("doc3", "The quick brown fox leaps over the lazy dog."),  # near duplicate
        ("doc4", "Completely unrelated text about science and experiments."),
    ]

    signatures = {doc_id: lsh.add_document(doc_id, text) for doc_id, text in corpus}

    # Query with doc1's text; doc2 and doc3 should surface as similar.
    matches = lsh.query(corpus[0][1])
    print(f"Query results for doc1: {matches}")

    print("Deduplication test passed!")
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
# Run the deduplication smoke test when this module is executed directly.
if __name__ == "__main__":
    test_deduplication()
|
data/domain_classifier.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
DomainClassifier: Classifies documents into 7 science domains.
|
| 3 |
+
Uses a simple linear classifier on top of text features.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import re
|
| 7 |
+
from typing import List, Tuple, Optional
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class DomainClassifier(nn.Module):
    """
    Classifies documents into 7 science domains:
        0: Physics
        1: Mathematics
        2: Chemistry
        3: Biology
        4: Earth Science
        5: Space Science
        6: Zoology

    Two entry points: a learned linear head over pooled hidden states
    (``forward``) and a keyword-based fallback for raw text
    (``classify_text``).
    """

    # Domain keywords for the rule-based fallback.
    DOMAIN_KEYWORDS = {
        0: ['physics', 'quantum', 'relativity', 'mechanics', 'thermodynamics', 'electromagnetism'],
        1: ['mathematics', 'algebra', 'calculus', 'geometry', 'topology', 'proof', 'theorem'],
        2: ['chemistry', 'molecular', 'reaction', 'compound', 'element', 'organic'],
        3: ['biology', 'cell', 'gene', 'protein', 'organism', 'evolution'],
        4: ['earth', 'geology', 'climate', 'ocean', 'atmosphere', 'meteorology'],
        5: ['space', 'astronomy', 'planet', 'star', 'galaxy', 'cosmology'],
        6: ['zoology', 'animal', 'species', 'vertebrate', 'invertebrate', 'ecology'],
    }

    def __init__(self, d_model: int, num_domains: int = 7):
        """
        Initialize domain classifier.

        Args:
            d_model: Input embedding dimension
            num_domains: Number of domains (7)
        """
        super().__init__()
        self.d_model = d_model
        self.num_domains = num_domains

        # Simple linear classifier over the pooled document embedding.
        self.classifier = nn.Linear(d_model, num_domains)

        # Small-std init keeps initial logits near zero (uniform prediction).
        nn.init.normal_(self.classifier.weight, mean=0.0, std=0.02)
        nn.init.zeros_(self.classifier.bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Classify domain from hidden states.

        Args:
            hidden_states: (batch, seq_len, d_model)
            attention_mask: (batch, seq_len); 1 for real tokens, 0 for padding
                -- assumed convention, TODO confirm against callers

        Returns:
            Domain logits (batch, num_domains)
        """
        # Masked mean pooling over the sequence dimension.
        if attention_mask is not None:
            mask = attention_mask.unsqueeze(-1)  # (batch, seq_len, 1)
            summed = (hidden_states * mask).sum(dim=1)
            counts = mask.sum(dim=1)
            # clamp(min=1) guards against division by zero for fully-masked rows.
            pooled = summed / counts.clamp(min=1)
        else:
            pooled = hidden_states.mean(dim=1)

        return self.classifier(pooled)

    def classify_text(
        self,
        text: str,
    ) -> Tuple[int, float]:
        """
        Rule-based fallback classification from raw text.

        Args:
            text: Input text string

        Returns:
            (domain_id, confidence); (0, 0.0) when no keyword matches.
        """
        text_lower = text.lower()

        # Count keyword matches per domain. Match on word boundaries (with a
        # simple plural suffix allowed) instead of raw substring search, which
        # produced false positives such as 'cell' in 'excellent', 'gene' in
        # 'general', or 'star' in 'started'.
        scores = []
        for domain_id, keywords in self.DOMAIN_KEYWORDS.items():
            score = sum(
                1 for kw in keywords
                if re.search(r'\b' + re.escape(kw) + r'(?:s|es)?\b', text_lower)
            )
            scores.append(score)

        if max(scores) == 0:
            return 0, 0.0  # Unknown -> default to physics

        best_domain = scores.index(max(scores))
        confidence = max(scores) / sum(scores) if sum(scores) > 0 else 0.0

        return best_domain, confidence

    def compute_loss(
        self,
        logits: torch.Tensor,
        domain_labels: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute classification loss.

        Args:
            logits: (batch, num_domains)
            domain_labels: (batch,) with domain IDs

        Returns:
            Cross-entropy loss (scalar tensor)
        """
        return nn.functional.cross_entropy(logits, domain_labels)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def test_domain_classifier():
    """Smoke-test DomainClassifier: forward shape, text fallback, loss."""
    d_model, batch_size, seq_len = 512, 4, 128

    clf = DomainClassifier(d_model)

    # Forward pass on random embeddings yields one logit per domain.
    hidden = torch.randn(batch_size, seq_len, d_model)
    logits = clf(hidden)
    print(f"Logits shape: {logits.shape}")
    assert logits.shape == (batch_size, 7)

    # Keyword-based fallback on raw strings.
    samples = (
        "The quantum mechanics of particles...",
        "Solving differential equations...",
        "Chemical reactions produce compounds...",
        "Cells contain DNA and proteins...",
    )
    for sample in samples:
        domain, conf = clf.classify_text(sample)
        print(f"Text: {sample[:30]}... -> Domain {domain}, conf {conf:.2f}")

    # Cross-entropy loss against dummy labels.
    labels = torch.tensor([0, 1, 2, 3])
    loss = clf.compute_loss(logits, labels)
    print(f"Loss: {loss.item():.4f}")

    print("DomainClassifier test passed!")
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
# Run the module's smoke test when executed directly.
if __name__ == "__main__":
    test_domain_classifier()
|
data/quality_filter.py
ADDED
|
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ScienceQualityFilter: Multi-stage quality filtering for scientific text.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import re
|
| 6 |
+
from typing import List, Tuple, Optional
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclass
class FilterStats:
    """Statistics from quality filtering."""
    # Number of documents seen by the filter.
    total: int = 0
    # Number of documents that passed every check.
    passed: int = 0
    # Per-check failure counters; a document is counted against at most one
    # of these (the first check it failed).
    failed_length: int = 0
    failed_language: int = 0
    failed_content: int = 0
    failed_equations: int = 0
    failed_repetition: int = 0
    failed_citations: int = 0
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class ScienceQualityFilter:
    """
    Multi-stage quality filtering for scientific text.

    A document must pass, in order: length, language, science-content,
    equation well-formedness, repetition, and citation-density checks.
    Each individual check is also usable on its own.
    """

    def __init__(
        self,
        min_length: int = 128,
        max_length: int = 64000,
        max_repetition_ratio: float = 0.2,
        min_equation_ratio: float = 0.0,
        max_equation_ratio: float = 0.6,
        min_citation_ratio: float = 0.0,
        max_citation_ratio: float = 0.4,
    ):
        """
        Initialize quality filter.

        Args:
            min_length: Minimum text length in characters
            max_length: Maximum text length in characters
            max_repetition_ratio: Maximum character 4-gram repetition ratio
            min_equation_ratio: Minimum equation density (stored, not enforced)
            max_equation_ratio: Maximum equation density (stored, not enforced)
            min_citation_ratio: Minimum citation density (citations per word)
            max_citation_ratio: Maximum citation density (citations per word)
        """
        self.min_length = min_length
        self.max_length = max_length
        self.max_repetition_ratio = max_repetition_ratio
        # NOTE(review): equation *density* is never computed anywhere in this
        # class -- only structural well-formedness is checked
        # (equation_validity_check). These two bounds are stored but
        # unenforced; implement a density check or drop them. TODO confirm.
        self.min_equation_ratio = min_equation_ratio
        self.max_equation_ratio = max_equation_ratio
        self.min_citation_ratio = min_citation_ratio
        self.max_citation_ratio = max_citation_ratio

    def filter(self, text: str, stats: Optional["FilterStats"] = None) -> bool:
        """
        Run all quality checks on text.

        Args:
            text: Input text
            stats: Optional stats object to update (a throwaway one is
                created when omitted)

        Returns:
            True if text passes all checks, False otherwise
        """
        if stats is None:
            stats = FilterStats()
        stats.total += 1

        # Checks run cheapest-first; a document is counted against the first
        # check it fails.
        # 1. Length check
        if not self.length_check(text):
            stats.failed_length += 1
            return False

        # 2. Language check (English only, simplified)
        if not self.language_check(text):
            stats.failed_language += 1
            return False

        # 3. Science content check
        if not self.science_content_check(text):
            stats.failed_content += 1
            return False

        # 4. Equation validity check
        if not self.equation_validity_check(text):
            stats.failed_equations += 1
            return False

        # 5. Repetition check
        if not self.repetition_check(text):
            stats.failed_repetition += 1
            return False

        # 6. Citation ratio check
        if not self.citation_ratio_check(text):
            stats.failed_citations += 1
            return False

        stats.passed += 1
        return True

    def length_check(self, text: str) -> bool:
        """Check character length against [min_length, max_length]."""
        return self.min_length <= len(text) <= self.max_length

    def language_check(self, text: str) -> bool:
        """
        Check if text is likely English.

        Simplified heuristic: more than 10% of the words must come from a
        small set of very common English words; texts with fewer than 10
        words are rejected outright.
        """
        # Common English words
        english_words = {'the', 'be', 'to', 'of', 'and', 'a', 'in', 'that', 'have', 'i'}
        # Allow one-letter words: 'a' and 'i' belong to the reference set but
        # were unmatchable under the previous pattern, which required at
        # least two letters per word.
        words = re.findall(r'\b[a-zA-Z]+\b', text.lower())
        if len(words) < 10:
            return False

        english_count = sum(1 for w in words if w in english_words)
        return english_count / len(words) > 0.1

    def science_content_check(self, text: str) -> bool:
        """
        Check if text contains scientific content.

        Passes when at least two scientific keywords occur, or at least one
        SI unit appears directly after a number (e.g. "300 K").
        """
        # Scientific keywords
        sci_keywords = [
            'experiment', 'data', 'result', 'analysis', 'method',
            'theory', 'hypothesis', 'conclusion', 'discussion',
            'figure', 'table', 'equation', 'reference', 'citation',
            'molecular', 'protein', 'gene', 'cell', 'reaction',
            'mathematical', 'derivation', 'proof', 'theorem',
        ]

        # Units
        units = ['m', 'kg', 's', 'mol', 'K', 'J', 'N', 'Pa', 'Hz', 'eV', '°C']

        text_lower = text.lower()
        keyword_count = sum(1 for kw in sci_keywords if kw in text_lower)
        # Units are case-sensitive ('K' is not 'k'), so match against the
        # original text, and require a preceding digit so ordinary words that
        # merely start or end with a unit letter are not counted.
        unit_count = sum(
            1 for u in units
            if re.search(r'\d\s*' + re.escape(u) + r'\b', text)
        )

        return keyword_count >= 2 or unit_count >= 1

    def equation_validity_check(self, text: str) -> bool:
        """
        Check if LaTeX equations are well-formed: paired ``$`` delimiters,
        balanced ``\\[ \\]`` / ``\\( \\)`` pairs, and balanced braces.
        """
        # Count dollar signs
        dollar_count = text.count('$')
        if dollar_count % 2 != 0:
            return False  # Unmatched dollar signs

        # Display-math brackets must pair up.
        if text.count('\\[') != text.count('\\]'):
            return False

        # Inline-math parentheses must pair up.
        if text.count('\\(') != text.count('\\)'):
            return False

        # Braces must balance and never dip below zero (a closing brace
        # before any opening one).
        brace_balance = 0
        for char in text:
            if char == '{':
                brace_balance += 1
            elif char == '}':
                brace_balance -= 1
                if brace_balance < 0:
                    return False  # Closing without opening

        return brace_balance == 0

    def repetition_check(self, text: str) -> bool:
        """
        Check for excessive repetition using character-level 4-grams.

        Texts shorter than 100 characters pass automatically.
        """
        if len(text) < 100:
            return True

        from collections import Counter

        # Build all character 4-grams.
        n = 4
        ngrams = [text[i:i + n] for i in range(len(text) - n + 1)]
        if not ngrams:
            return True

        # Fraction of 4-grams that are duplicates of an earlier one.
        repetition_ratio = 1 - len(Counter(ngrams)) / len(ngrams)
        return repetition_ratio <= self.max_repetition_ratio

    def citation_ratio_check(self, text: str) -> bool:
        """
        Check if citation density (citations per word) is within the
        configured [min_citation_ratio, max_citation_ratio] range.
        """
        # Count citation patterns:
        # (Author, Year)
        inline1 = len(re.findall(r'\([A-Za-z\s]+,?\s*\d{4}\)', text))
        # [1] or [1-3]
        inline2 = len(re.findall(r'\[\d+(?:[-,]\d+)*\]', text))
        # [Author, Year]
        inline3 = len(re.findall(r'\[[A-Za-z\s]+,?\s*\d{4}\]', text))

        total_citations = inline1 + inline2 + inline3

        # Estimate word count (words of 2+ letters).
        words = re.findall(r'\b[a-zA-Z]{2,}\b', text)
        if len(words) == 0:
            return True

        citation_ratio = total_citations / len(words)
        return self.min_citation_ratio <= citation_ratio <= self.max_citation_ratio

    def get_stats(self, stats: "FilterStats") -> str:
        """Return a human-readable multi-line summary of filter statistics."""
        # Guard against division by zero when no documents were seen.
        total = stats.total if stats.total > 0 else 1
        return (
            f"Quality filter stats:\n"
            f"  Total: {stats.total}\n"
            f"  Passed: {stats.passed} ({stats.passed/total*100:.1f}%)\n"
            f"  Failed - Length: {stats.failed_length}\n"
            f"  Failed - Language: {stats.failed_language}\n"
            f"  Failed - Content: {stats.failed_content}\n"
            f"  Failed - Equations: {stats.failed_equations}\n"
            f"  Failed - Repetition: {stats.failed_repetition}\n"
            f"  Failed - Citations: {stats.failed_citations}"
        )
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def test_quality_filter():
    """Exercise the quality filter on passing and failing samples."""
    # Named `qf` rather than `filter` so the builtin is not shadowed.
    qf = ScienceQualityFilter()

    # Good sample
    good_text = """
The experiment was conducted to test the hypothesis. We collected data from
100 participants and performed statistical analysis. The results show a
significant effect (p < 0.05). According to Smith et al., this confirms
the theoretical prediction. The equation E = mc^2 is fundamental.
"""
    print(f"Good text passes: {qf.filter(good_text)}")

    # Bad sample (too short)
    short_text = "This is too short."
    print(f"Short text passes: {qf.filter(short_text)}")

    # Bad sample (unmatched equations)
    bad_eq = "Here is an equation $E = mc^2 and another $F = ma."
    print(f"Unmatched $ passes: {qf.filter(bad_eq)}")

    # Bad sample (excessive repetition)
    repetitive = "test test test test test test test test test test " * 100
    print(f"Repetitive passes: {qf.filter(repetitive)}")

    print("QualityFilter test passed!")
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
# Run the module's smoke test when executed directly.
if __name__ == "__main__":
    test_quality_filter()
|
data/scraper.py
ADDED
|
@@ -0,0 +1,405 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
VortexScienceScraper: Scrapes scientific content from open access sources.
|
| 3 |
+
Respects robots.txt and rate limits.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import time
|
| 7 |
+
import requests
|
| 8 |
+
from typing import List, Dict, Optional
|
| 9 |
+
from urllib.robotparser import RobotFileParser
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
import json
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class VortexScienceScraper:
|
| 15 |
+
"""
|
| 16 |
+
Scrapes scientific content from open access sources.
|
| 17 |
+
Sources: arXiv, PubMed Central, Wikipedia, NIST, NASA.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
SOURCES = {
|
| 21 |
+
"arxiv": {
|
| 22 |
+
"base_url": "https://arxiv.org",
|
| 23 |
+
"search_url": "https://arxiv.org/search/",
|
| 24 |
+
"rate_limit": 1.0, # seconds between requests
|
| 25 |
+
"robots": "https://arxiv.org/robots.txt",
|
| 26 |
+
},
|
| 27 |
+
"pubmed": {
|
| 28 |
+
"base_url": "https://www.ncbi.nlm.nih.gov/pmc",
|
| 29 |
+
"search_url": "https://www.ncbi.nlm.nih.gov/pmc/articles/",
|
| 30 |
+
"rate_limit": 0.5,
|
| 31 |
+
"robots": "https://www.ncbi.nlm.nih.gov/robots.txt",
|
| 32 |
+
},
|
| 33 |
+
"wikipedia": {
|
| 34 |
+
"base_url": "https://en.wikipedia.org",
|
| 35 |
+
"search_url": "https://en.wikipedia.org/w/api.php",
|
| 36 |
+
"rate_limit": 0.1,
|
| 37 |
+
"robots": "https://en.wikipedia.org/robots.txt",
|
| 38 |
+
},
|
| 39 |
+
"nist": {
|
| 40 |
+
"base_url": "https://webbook.nist.gov",
|
| 41 |
+
"search_url": "https://webbook.nist.gov/cgi/cbook.cgi",
|
| 42 |
+
"rate_limit": 1.0,
|
| 43 |
+
"robots": "https://webbook.nist.gov/robots.txt",
|
| 44 |
+
},
|
| 45 |
+
"nasa": {
|
| 46 |
+
"base_url": "https://ntrs.nasa.gov",
|
| 47 |
+
"search_url": "https://ntrs.nasa.gov/api/citations/search",
|
| 48 |
+
"rate_limit": 1.0,
|
| 49 |
+
"robots": "https://ntrs.nasa.gov/robots.txt",
|
| 50 |
+
},
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
def __init__(
|
| 54 |
+
self,
|
| 55 |
+
output_dir: str = "./data/scraped",
|
| 56 |
+
respect_robots: bool = True,
|
| 57 |
+
user_agent: str = "VortexScientificBot/1.0",
|
| 58 |
+
):
|
| 59 |
+
"""
|
| 60 |
+
Initialize scraper.
|
| 61 |
+
|
| 62 |
+
Args:
|
| 63 |
+
output_dir: Directory to save scraped data
|
| 64 |
+
respect_robots: Whether to respect robots.txt
|
| 65 |
+
user_agent: User agent string for requests
|
| 66 |
+
"""
|
| 67 |
+
self.output_dir = Path(output_dir)
|
| 68 |
+
self.output_dir.mkdir(parents=True, exist_ok=True)
|
| 69 |
+
self.respect_robots = respect_robots
|
| 70 |
+
self.user_agent = user_agent
|
| 71 |
+
|
| 72 |
+
self.session = requests.Session()
|
| 73 |
+
self.session.headers.update({"User-Agent": user_agent})
|
| 74 |
+
|
| 75 |
+
# Cache for robots.txt
|
| 76 |
+
self.robots_cache = {}
|
| 77 |
+
|
| 78 |
+
# Rate limit tracking
|
| 79 |
+
self.last_request_time = {}
|
| 80 |
+
|
| 81 |
+
def _check_robots_allowed(self, url: str) -> bool:
|
| 82 |
+
"""Check if robots.txt allows scraping the URL."""
|
| 83 |
+
if not self.respect_robots:
|
| 84 |
+
return True
|
| 85 |
+
|
| 86 |
+
# Extract base domain
|
| 87 |
+
from urllib.parse import urlparse
|
| 88 |
+
parsed = urlparse(url)
|
| 89 |
+
base_url = f"{parsed.scheme}://{parsed.netloc}"
|
| 90 |
+
|
| 91 |
+
if base_url not in self.robots_cache:
|
| 92 |
+
rp = RobotFileParser()
|
| 93 |
+
rp.set_url(base_url + "/robots.txt")
|
| 94 |
+
try:
|
| 95 |
+
rp.read()
|
| 96 |
+
self.robots_cache[base_url] = rp
|
| 97 |
+
except Exception as e:
|
| 98 |
+
print(f"Could not read robots.txt for {base_url}: {e}")
|
| 99 |
+
return False
|
| 100 |
+
|
| 101 |
+
rp = self.robots_cache[base_url]
|
| 102 |
+
return rp.can_fetch(self.user_agent, url)
|
| 103 |
+
|
| 104 |
+
def _rate_limit(self, source: str):
|
| 105 |
+
"""Enforce rate limiting for a source."""
|
| 106 |
+
now = time.time()
|
| 107 |
+
last = self.last_request_time.get(source, 0)
|
| 108 |
+
delay = self.SOURCES[source]["rate_limit"]
|
| 109 |
+
if now - last < delay:
|
| 110 |
+
time.sleep(delay - (now - last))
|
| 111 |
+
self.last_request_time[source] = time.time()
|
| 112 |
+
|
| 113 |
+
def scrape_arxiv(
|
| 114 |
+
self,
|
| 115 |
+
query: str,
|
| 116 |
+
max_results: int = 100,
|
| 117 |
+
categories: Optional[List[str]] = None,
|
| 118 |
+
) -> List[Dict]:
|
| 119 |
+
"""
|
| 120 |
+
Scrape arXiv papers.
|
| 121 |
+
|
| 122 |
+
Args:
|
| 123 |
+
query: Search query
|
| 124 |
+
max_results: Maximum number of results
|
| 125 |
+
categories: Optional list of arXiv categories (e.g., ['physics', 'math'])
|
| 126 |
+
|
| 127 |
+
Returns:
|
| 128 |
+
List of paper metadata and abstracts
|
| 129 |
+
"""
|
| 130 |
+
papers = []
|
| 131 |
+
|
| 132 |
+
params = {
|
| 133 |
+
"query": query,
|
| 134 |
+
"searchtype": "all",
|
| 135 |
+
"abstracts": "show",
|
| 136 |
+
"size": min(max_results, 200), # arXiv max per page
|
| 137 |
+
"order": "-announced_date_first",
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
if categories:
|
| 141 |
+
params["filter"] = "categories:" + "+OR+".join(categories)
|
| 142 |
+
|
| 143 |
+
url = self.SOURCES["arxiv"]["search_url"]
|
| 144 |
+
|
| 145 |
+
if not self._check_robots_allowed(url):
|
| 146 |
+
print(f"Robots.txt disallows scraping {url}")
|
| 147 |
+
return papers
|
| 148 |
+
|
| 149 |
+
try:
|
| 150 |
+
self._rate_limit("arxiv")
|
| 151 |
+
response = self.session.get(url, params=params)
|
| 152 |
+
response.raise_for_status()
|
| 153 |
+
|
| 154 |
+
# Parse HTML (simplified - would use BeautifulSoup in practice)
|
| 155 |
+
# For now, return placeholder
|
| 156 |
+
print(f"Scraped arXiv query '{query}' - got response status {response.status_code}")
|
| 157 |
+
|
| 158 |
+
# Placeholder: would extract paper titles, abstracts, PDF links
|
| 159 |
+
for i in range(min(10, max_results)):
|
| 160 |
+
papers.append({
|
| 161 |
+
"source": "arxiv",
|
| 162 |
+
"title": f"Paper {i}",
|
| 163 |
+
"abstract": "Abstract placeholder...",
|
| 164 |
+
"pdf_url": f"https://arxiv.org/pdf/{i}.pdf",
|
| 165 |
+
})
|
| 166 |
+
|
| 167 |
+
except Exception as e:
|
| 168 |
+
print(f"Error scraping arXiv: {e}")
|
| 169 |
+
|
| 170 |
+
return papers
|
| 171 |
+
|
| 172 |
+
def scrape_pubmed(
|
| 173 |
+
self,
|
| 174 |
+
query: str,
|
| 175 |
+
max_results: int = 100,
|
| 176 |
+
) -> List[Dict]:
|
| 177 |
+
"""Scrape PubMed Central articles."""
|
| 178 |
+
articles = []
|
| 179 |
+
|
| 180 |
+
# PubMed API endpoint
|
| 181 |
+
url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
|
| 182 |
+
params = {
|
| 183 |
+
"db": "pmc",
|
| 184 |
+
"term": query,
|
| 185 |
+
"retmax": max_results,
|
| 186 |
+
"retmode": "json",
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
if not self._check_robots_allowed(url):
|
| 190 |
+
print(f"Robots.txt disallows {url}")
|
| 191 |
+
return articles
|
| 192 |
+
|
| 193 |
+
try:
|
| 194 |
+
self._rate_limit("pubmed")
|
| 195 |
+
response = self.session.get(url, params=params)
|
| 196 |
+
response.raise_for_status()
|
| 197 |
+
|
| 198 |
+
data = response.json()
|
| 199 |
+
pmc_ids = data.get("esearchresult", {}).get("idlist", [])
|
| 200 |
+
|
| 201 |
+
for pmc_id in pmc_ids[:10]: # Limit for demo
|
| 202 |
+
articles.append({
|
| 203 |
+
"source": "pubmed",
|
| 204 |
+
"pmc_id": pmc_id,
|
| 205 |
+
"url": f"https://www.ncbi.nlm.nih.gov/pmc/articles/PMC{pmc_id}/",
|
| 206 |
+
})
|
| 207 |
+
|
| 208 |
+
print(f"Found {len(pmc_ids)} PubMed articles")
|
| 209 |
+
|
| 210 |
+
except Exception as e:
|
| 211 |
+
print(f"Error scraping PubMed: {e}")
|
| 212 |
+
|
| 213 |
+
return articles
|
| 214 |
+
|
| 215 |
+
def scrape_wikipedia(
|
| 216 |
+
self,
|
| 217 |
+
topic: str,
|
| 218 |
+
max_pages: int = 10,
|
| 219 |
+
) -> List[Dict]:
|
| 220 |
+
"""Scrape Wikipedia science articles."""
|
| 221 |
+
pages = []
|
| 222 |
+
|
| 223 |
+
# Wikipedia API
|
| 224 |
+
url = "https://en.wikipedia.org/w/api.php"
|
| 225 |
+
params = {
|
| 226 |
+
"action": "query",
|
| 227 |
+
"format": "json",
|
| 228 |
+
"prop": "extracts",
|
| 229 |
+
"exintro": True,
|
| 230 |
+
"titles": topic,
|
| 231 |
+
"redirects": True,
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
if not self._check_robots_allowed(url):
|
| 235 |
+
print(f"Robots.txt disallows {url}")
|
| 236 |
+
return pages
|
| 237 |
+
|
| 238 |
+
try:
|
| 239 |
+
self._rate_limit("wikipedia")
|
| 240 |
+
response = self.session.get(url, params=params)
|
| 241 |
+
response.raise_for_status()
|
| 242 |
+
|
| 243 |
+
data = response.json()
|
| 244 |
+
pages_data = data.get("query", {}).get("pages", {})
|
| 245 |
+
|
| 246 |
+
for page_id, page in pages_data.items():
|
| 247 |
+
if "extract" in page:
|
| 248 |
+
pages.append({
|
| 249 |
+
"source": "wikipedia",
|
| 250 |
+
"title": page.get("title", ""),
|
| 251 |
+
"text": page.get("extract", ""),
|
| 252 |
+
})
|
| 253 |
+
|
| 254 |
+
except Exception as e:
|
| 255 |
+
print(f"Error scraping Wikipedia: {e}")
|
| 256 |
+
|
| 257 |
+
return pages
|
| 258 |
+
|
| 259 |
+
def scrape_nist(
|
| 260 |
+
self,
|
| 261 |
+
element: str,
|
| 262 |
+
) -> List[Dict]:
|
| 263 |
+
"""Scrape NIST chemistry webbook for element data."""
|
| 264 |
+
data = []
|
| 265 |
+
|
| 266 |
+
url = "https://webbook.nist.gov/cgi/cbook.cgi"
|
| 267 |
+
params = {
|
| 268 |
+
"Formula": element,
|
| 269 |
+
"Units": "SI",
|
| 270 |
+
"Submit": "Submit",
|
| 271 |
+
}
|
| 272 |
+
|
| 273 |
+
if not self._check_robots_allowed(url):
|
| 274 |
+
print(f"Robots.txt disallows {url}")
|
| 275 |
+
return data
|
| 276 |
+
|
| 277 |
+
try:
|
| 278 |
+
self._rate_limit("nist")
|
| 279 |
+
response = self.session.get(url, params=params)
|
| 280 |
+
response.raise_for_status()
|
| 281 |
+
|
| 282 |
+
# Placeholder - would parse HTML tables
|
| 283 |
+
data.append({
|
| 284 |
+
"source": "nist",
|
| 285 |
+
"element": element,
|
| 286 |
+
"html": response.text[:1000],
|
| 287 |
+
})
|
| 288 |
+
|
| 289 |
+
except Exception as e:
|
| 290 |
+
print(f"Error scraping NIST: {e}")
|
| 291 |
+
|
| 292 |
+
return data
|
| 293 |
+
|
| 294 |
+
def scrape_nasa(
|
| 295 |
+
self,
|
| 296 |
+
query: str,
|
| 297 |
+
max_results: int = 50,
|
| 298 |
+
) -> List[Dict]:
|
| 299 |
+
"""Scrape NASA technical reports."""
|
| 300 |
+
reports = []
|
| 301 |
+
|
| 302 |
+
url = "https://ntrs.nasa.gov/api/citations/search"
|
| 303 |
+
params = {
|
| 304 |
+
"q": query,
|
| 305 |
+
"page[size]": max_results,
|
| 306 |
+
}
|
| 307 |
+
|
| 308 |
+
if not self._check_robots_allowed(url):
|
| 309 |
+
print(f"Robots.txt disallows {url}")
|
| 310 |
+
return reports
|
| 311 |
+
|
| 312 |
+
try:
|
| 313 |
+
self._rate_limit("nasa")
|
| 314 |
+
response = self.session.get(url, params=params)
|
| 315 |
+
response.raise_for_status()
|
| 316 |
+
|
| 317 |
+
data = response.json()
|
| 318 |
+
for item in data.get("data", [])[:10]:
|
| 319 |
+
reports.append({
|
| 320 |
+
"source": "nasa",
|
| 321 |
+
"title": item.get("attributes", {}).get("title", ""),
|
| 322 |
+
"abstract": item.get("attributes", {}).get("abstract", ""),
|
| 323 |
+
"download_url": item.get("attributes", {}).get("downloads", {}).get("pdf", ""),
|
| 324 |
+
})
|
| 325 |
+
|
| 326 |
+
except Exception as e:
|
| 327 |
+
print(f"Error scraping NASA: {e}")
|
| 328 |
+
|
| 329 |
+
return reports
|
| 330 |
+
|
| 331 |
+
def save_results(
|
| 332 |
+
self,
|
| 333 |
+
results: List[Dict],
|
| 334 |
+
filename: str,
|
| 335 |
+
):
|
| 336 |
+
"""Save scraped results to JSON."""
|
| 337 |
+
output_path = self.output_dir / filename
|
| 338 |
+
with open(output_path, "w", encoding="utf-8") as f:
|
| 339 |
+
json.dump(results, f, indent=2, ensure_ascii=False)
|
| 340 |
+
print(f"Saved {len(results)} results to {output_path}")
|
| 341 |
+
|
| 342 |
+
def scrape_all_sources(
    self,
    queries: Dict[str, str],
    max_per_source: int = 50,
) -> Dict[str, List[Dict]]:
    """
    Scrape all sources with given queries.

    Args:
        queries: Dict mapping source name to query string
        max_per_source: Max results per source

    Returns:
        Dict mapping source to list of results
    """
    # Dispatch table: each known source name maps to a one-argument fetch.
    dispatch = {
        "arxiv": lambda q: self.scrape_arxiv(q, max_results=max_per_source),
        "pubmed": lambda q: self.scrape_pubmed(q, max_results=max_per_source),
        "wikipedia": lambda q: self.scrape_wikipedia(q, max_pages=max_per_source),
        "nist": lambda q: self.scrape_nist(q),
        "nasa": lambda q: self.scrape_nasa(q, max_results=max_per_source),
    }

    all_results = {}
    for source, query in queries.items():
        if source not in self.SOURCES:
            print(f"Unknown source: {source}")
            continue

        print(f"Scraping {source} with query: {query}")

        fetch = dispatch.get(source)
        results = fetch(query) if fetch is not None else []
        all_results[source] = results

        # Persist each source's results as soon as they arrive.
        self.save_results(results, f"{source}_results.json")

    return all_results
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
def test_scraper():
    """Smoke-test the scraper against live endpoints (kept deliberately small)."""
    scraper = VortexScienceScraper()

    # Wikipedia first: the cheapest endpoint to hit.
    print("Testing Wikipedia scrape...")
    pages = scraper.scrape_wikipedia("quantum mechanics", max_pages=2)
    print(f"Got {len(pages)} Wikipedia pages")

    # arXiv second: these requests go through the client's rate limiter.
    print("Testing arXiv scrape...")
    papers = scraper.scrape_arxiv("quantum", max_results=5)
    print(f"Got {len(papers)} arXiv papers")

    print("Scraper test passed!")


if __name__ == "__main__":
    test_scraper()
|
inference.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
"""
Inference script for Vortex models.
Supports both CUDA and MPS backends.
"""

import argparse
import sys
from pathlib import Path

import torch

from configs.vortex_7b_config import VORTEX_7B_CONFIG
from configs.vortex_13b_config import VORTEX_13B_CONFIG

from models.vortex_model import VortexModel
from tokenizer.vortex_tokenizer import VortexScienceTokenizer
# BUG FIX: cuda_optimize.py and mps_optimize.py live at the repository root,
# not in an "inference" package — and this module *is* inference.py, so
# "from inference.cuda_optimize import ..." could never resolve.
from cuda_optimize import optimize_for_cuda, profile_model
from mps_optimize import optimize_for_mps, profile_model_mps
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def parse_args():
    """Parse command-line arguments for Vortex inference.

    Returns:
        argparse.Namespace with model, optimization, and sampling options.
    """
    parser = argparse.ArgumentParser(description="Run inference with Vortex model")
    parser.add_argument("--model_path", type=str, required=True,
                        help="Path to trained model checkpoint")
    parser.add_argument("--config", type=str, default=None,
                        help="Path to model config (if not in checkpoint)")
    parser.add_argument("--tokenizer_path", type=str, default=None,
                        help="Path to tokenizer")
    parser.add_argument("--model_size", type=str, choices=["7b", "13b"], default="7b",
                        help="Model size for config")
    parser.add_argument("--device", type=str, default="cuda",
                        choices=["cuda", "mps", "cpu"],
                        help="Device to run on")
    parser.add_argument("--use_mps", action="store_true",
                        help="Use MPS backend (Apple Silicon)")
    # BUG FIX: `None` was listed in choices, but with type=str no CLI value
    # can ever equal None — omitting the flag already yields the None default.
    parser.add_argument("--quantization", type=str, choices=["int8", "int4"], default=None,
                        help="Apply quantization (CUDA only)")
    parser.add_argument("--flash_attention", action="store_true",
                        help="Use Flash Attention 2 (CUDA only)")
    parser.add_argument("--torch_compile", action="store_true",
                        help="Use torch.compile")
    parser.add_argument("--prompt", type=str, default=None,
                        help="Input prompt for generation")
    parser.add_argument("--interactive", action="store_true",
                        help="Run in interactive mode")
    parser.add_argument("--max_new_tokens", type=int, default=100,
                        help="Maximum new tokens to generate")
    parser.add_argument("--temperature", type=float, default=0.8,
                        help="Sampling temperature")
    parser.add_argument("--top_p", type=float, default=0.9,
                        help="Top-p sampling")
    parser.add_argument("--profile", action="store_true",
                        help="Profile performance")
    return parser.parse_args()
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def load_model(args):
    """Build a VortexModel, load its checkpoint, and apply backend optimizations.

    Args:
        args: Parsed CLI namespace (model_path, config, model_size, device,
              use_mps, flash_attention, torch_compile, quantization).

    Returns:
        Tuple of (model moved to the target device in eval mode, VortexConfig).
    """
    # IMPROVED: import once here instead of duplicating the same local
    # import in both branches below.
    from configuration_vortex import VortexConfig

    # Load config: an explicit path wins, otherwise fall back to the
    # built-in defaults for the requested model size.
    if args.config:
        config = VortexConfig.from_pretrained(args.config)
    else:
        config_dict = VORTEX_7B_CONFIG if args.model_size == "7b" else VORTEX_13B_CONFIG
        config = VortexConfig(**config_dict)

    # Create model
    print("Creating model...")
    model = VortexModel(config.to_dict())

    # Load checkpoint.
    # SECURITY NOTE: weights_only=False unpickles arbitrary objects; only
    # load checkpoints from trusted sources.
    print(f"Loading checkpoint from {args.model_path}")
    checkpoint = torch.load(args.model_path, map_location="cpu", weights_only=False)
    if "model_state_dict" in checkpoint:
        model.load_state_dict(checkpoint["model_state_dict"])
    else:
        # Raw state dict saved without a wrapper.
        model.load_state_dict(checkpoint)
    print("Model loaded")

    # Apply backend-specific optimizations.
    device = torch.device(args.device)
    if args.use_mps or args.device == "mps":
        print("Optimizing for MPS...")
        model = optimize_for_mps(model, config.to_dict(), use_sdpa=True)
    else:
        print("Optimizing for CUDA...")
        model = optimize_for_cuda(
            model,
            config.to_dict(),
            use_flash_attention=args.flash_attention,
            use_torch_compile=args.torch_compile,
            quantization=args.quantization,
        )

    model = model.to(device)
    model.eval()

    return model, config
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def load_tokenizer(args):
    """Load the Vortex tokenizer, falling back to a no-op dummy tokenizer.

    Args:
        args: Parsed CLI namespace (uses tokenizer_path and model_path).

    Returns:
        A tokenizer supporting __call__ and decode(); the dummy fallback also
        provides pad_token_id/eos_token_id so generation code does not crash.
    """
    tokenizer_path = args.tokenizer_path
    if not tokenizer_path:
        # Try to find the tokenizer next to the checkpoint.
        tokenizer_path = Path(args.model_path).parent / "vortex_tokenizer.json"

    if tokenizer_path and Path(tokenizer_path).exists():
        from tokenization_vortex import VortexTokenizer
        # BUG FIX: the original referenced `model_dir`, which was unbound
        # (NameError) whenever --tokenizer_path was supplied explicitly.
        # Load from the directory that contains the tokenizer file instead.
        tokenizer = VortexTokenizer.from_pretrained(str(Path(tokenizer_path).parent))
    else:
        print("Warning: No tokenizer found, using dummy tokenizer")

        class DummyTokenizer:
            # Minimal stand-in so the rest of the pipeline can run.
            # IMPROVED: pad/eos ids added — generate_text reads both.
            pad_token_id = 0
            eos_token_id = 0

            def __call__(self, text, **kwargs):
                return {"input_ids": torch.tensor([[1, 2, 3]])}

            def decode(self, ids, **kwargs):
                return "dummy"

        tokenizer = DummyTokenizer()

    return tokenizer
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def generate_text(model, tokenizer, prompt, args):
    """Generate a completion for `prompt` with the given model.

    Uses model.generate() when available; otherwise falls back to a simple
    greedy-loop with temperature sampling.

    Args:
        model: Model exposing forward (and optionally generate) plus a
               config with max_seq_len.
        tokenizer: Tokenizer with __call__, decode, pad/eos token ids.
        prompt: Input text.
        args: Namespace with max_new_tokens, temperature, top_p.

    Returns:
        The decoded text (prompt plus generated continuation).
    """
    # Tokenize, leaving room for the new tokens inside the context window.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=False,
        truncation=True,
        max_length=model.config.max_seq_len - args.max_new_tokens,
    )
    input_ids = inputs["input_ids"].to(next(model.parameters()).device)

    # Generate
    with torch.no_grad():
        if hasattr(model, 'generate'):
            output_ids = model.generate(
                input_ids,
                max_new_tokens=args.max_new_tokens,
                temperature=args.temperature,
                top_p=args.top_p,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
            )
        else:
            # Manual temperature sampling (note: top_p is not applied here).
            for _ in range(args.max_new_tokens):
                outputs = model(input_ids)
                next_token_logits = outputs["logits"][:, -1, :]
                next_token = torch.multinomial(
                    torch.softmax(next_token_logits / args.temperature, dim=-1),
                    num_samples=1,
                )
                input_ids = torch.cat([input_ids, next_token], dim=-1)

                # Stop early at end-of-sequence.
                if next_token.item() == tokenizer.eos_token_id:
                    break
            # BUG FIX: `output_ids` was never assigned on this path, so the
            # decode below raised NameError. The accumulated ids ARE the output.
            output_ids = input_ids

    # Decode
    generated = tokenizer.decode(output_ids[0].tolist(), skip_special_tokens=True)
    return generated
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def main():
    """CLI entry point: load model + tokenizer, then profile, chat, or generate."""
    args = parse_args()

    # Load model and tokenizer
    model, config = load_model(args)
    tokenizer = load_tokenizer(args)

    device = next(model.parameters()).device
    print(f"Model loaded on {device}")
    print(f"Model parameters: {model.get_num_params():,}")

    # Profiling short-circuits everything else.
    if args.profile:
        print("Profiling...")
        dummy_input = torch.randint(0, config.vocab_size, (1, 128)).to(device)
        on_mps = args.use_mps or args.device == "mps"
        stats = profile_model_mps(model, dummy_input) if on_mps else profile_model(model, dummy_input)
        print("Profile results:")
        for key, value in stats.items():
            print(f"  {key}: {value:.4f}")
        return

    # Otherwise: interactive loop, single prompt, or usage hint.
    if args.interactive:
        print("Interactive mode. Type 'quit' to exit.")
        while True:
            prompt = input("\nPrompt: ")
            if prompt.lower() == "quit":
                break
            response = generate_text(model, tokenizer, prompt, args)
            print(f"\nResponse: {response}")
    elif args.prompt:
        response = generate_text(model, tokenizer, args.prompt, args)
        print(f"Response: {response}")
    else:
        print("No prompt provided. Use --prompt or --interactive.")


if __name__ == "__main__":
    main()
|
modeling_vortex.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Vortex model implementation for HuggingFace.
|
| 3 |
+
Integrates with transformers library.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from typing import Optional, Tuple, List, Dict, Any
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from transformers import PreTrainedModel, PretrainedConfig, GenerationConfig
|
| 10 |
+
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
|
| 11 |
+
|
| 12 |
+
from configuration_vortex import VortexConfig
|
| 13 |
+
from models.vortex_model import VortexModel
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class VortexPreTrainedModel(PreTrainedModel):
    """
    Base class for Vortex models.
    Handles loading/saving in HF format.
    """
    config_class = VortexConfig
    base_model_prefix = "vortex"
    supports_gradient_checkpointing = True
    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]

    def _init_weights(self, module):
        """Initialize one submodule with the configured initializer std."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                # Keep the padding embedding at exactly zero.
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def get_input_embeddings(self):
        """Return the token embedding table of the wrapped core model."""
        return self.vortex.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding table of the wrapped core model."""
        self.vortex.embed_tokens = value

    def get_output_embeddings(self):
        """Return the LM head of the wrapped core model."""
        return self.vortex.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head of the wrapped core model."""
        self.vortex.lm_head = new_embeddings
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class VortexForCausalLM(VortexPreTrainedModel):
    """
    Vortex model for causal language modeling.

    Wraps the core VortexModel and exposes the HuggingFace CausalLM
    interface: shifted cross-entropy loss, generation plumbing, and
    optional input/output embedding weight tying.
    """
    # Declares the tied lm_head parameter so HF save/load handles the
    # duplicate weight correctly when tie_word_embeddings is on.
    _tied_weights_keys = ["vortex.lm_head.weight"]

    def __init__(self, config: VortexConfig):
        """Build the core model, initialize weights, and tie embeddings if configured."""
        super().__init__(config)
        self.config = config

        # Build core model
        self.vortex = VortexModel(config.to_dict())

        # Initialize weights
        self.apply(self._init_weights)

        # Tie weights if configured
        if self.config.tie_word_embeddings:
            self.tie_weights()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        domain_ids: Optional[torch.LongTensor] = None,
        domain_tags: Optional[torch.Tensor] = None,
        text: Optional[List[str]] = None,
    ) -> CausalLMOutputWithCrossAttentions:
        """
        Forward pass.

        Args:
            input_ids: Token IDs (batch, seq_len)
            attention_mask: Attention mask (batch, seq_len)
            labels: Labels for LM loss (batch, seq_len); positions equal to
                -100 are ignored by the loss.
            domain_ids: Domain IDs (batch,)
            domain_tags: Domain tag masks (batch, seq_len, num_domains)
            text: Original text strings (for science modules)

        NOTE(review): position_ids, past_key_values, inputs_embeds,
        use_cache, output_attentions and output_hidden_states are accepted
        for HF API compatibility but are NOT forwarded to the core model —
        confirm whether KV caching is actually supported by VortexModel.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Pass through Vortex model
        outputs = self.vortex(
            input_ids=input_ids,
            attention_mask=attention_mask,
            domain_ids=domain_ids,
            domain_tags=domain_tags,
            text=text,
            return_dict=True,
        )

        logits = outputs["logits"]
        last_hidden_state = outputs["last_hidden_state"]

        loss = None
        if labels is not None:
            # Compute cross-entropy loss with the standard causal shift:
            # position i predicts token i+1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        if not return_dict:
            # Tuple form: (loss?, logits, last_hidden_state).
            output = (logits,) + (last_hidden_state,)
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            hidden_states=last_hidden_state,
            attentions=None,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        **kwargs,
    ):
        """Prepare inputs for text generation."""
        # Omit tokens that are already past: with a cache, only the newest
        # token needs to be fed through the model.
        if past_key_values:
            input_ids = input_ids[:, -1:]

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache", True),
        }

    def generate(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        **kwargs,
    ):
        """Generate text.

        Thin wrapper over PreTrainedModel.generate that defaults the
        generation config from this model's config when none is supplied.
        """
        from transformers import GenerationConfig

        generation_config = kwargs.pop("generation_config", None)
        if generation_config is None:
            generation_config = GenerationConfig.from_model_config(self.config)

        return super().generate(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            generation_config=generation_config,
            **kwargs,
        )
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
# Register model for AutoModel.
# Module-level side effect: importing this module makes
# AutoConfig.from_pretrained / AutoModelForCausalLM.from_pretrained
# resolve model_type "vortex" to the classes defined above.
from transformers import AutoConfig, AutoModelForCausalLM

AutoConfig.register("vortex", VortexConfig)
AutoModelForCausalLM.register(VortexConfig, VortexForCausalLM)
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def test_hf_integration():
    """Smoke-test HF integration: build, forward, save, and re-load via Auto*."""
    import tempfile

    from transformers import AutoConfig, AutoModelForCausalLM

    # Deliberately tiny config so the test runs fast.
    config = VortexConfig(
        d_model=512,
        num_layers=2,
        num_heads=8,
        vocab_size=1000,
    )

    # Create model
    model = VortexForCausalLM(config)
    print(f"Model parameters: {model.get_num_parameters():,}")

    # Test forward
    batch_size = 2
    seq_len = 32
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    labels = torch.randint(0, config.vocab_size, (batch_size, seq_len))

    outputs = model(input_ids=input_ids, labels=labels)
    print(f"Loss: {outputs.loss.item():.4f}")
    print(f"Logits shape: {outputs.logits.shape}")

    # IMPROVED: the save/load round-trip now uses a temporary directory
    # instead of littering ./test_vortex_model in the working directory.
    with tempfile.TemporaryDirectory() as save_dir:
        model.save_pretrained(save_dir)
        config.save_pretrained(save_dir)

        loaded_config = AutoConfig.from_pretrained(save_dir)
        loaded_model = AutoModelForCausalLM.from_pretrained(save_dir)
        print(f"Loaded model type: {type(loaded_model)}")

    print("HF integration test passed!")


if __name__ == "__main__":
    test_hf_integration()
|
models/__pycache__/attention_layer.cpython-313.pyc
ADDED
|
Binary file (15.5 kB). View file
|
|
|
models/__pycache__/scigate_ffn.cpython-313.pyc
ADDED
|
Binary file (7.94 kB). View file
|
|
|
models/__pycache__/ssm_layer.cpython-313.pyc
ADDED
|
Binary file (10.2 kB). View file
|
|
|
models/__pycache__/vortex_model.cpython-313.pyc
ADDED
|
Binary file (14.6 kB). View file
|
|
|
models/attention_layer.py
ADDED
|
@@ -0,0 +1,370 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
VortexLocalAttention: Local windowed attention with global token support.
|
| 3 |
+
Uses a sliding window of 512 tokens for efficiency, with special handling
|
| 4 |
+
for global tokens that can attend across the entire sequence.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from typing import Optional, Tuple
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class VortexLocalAttention(nn.Module):
|
| 14 |
+
"""
|
| 15 |
+
Local windowed attention with window_size=512.
|
| 16 |
+
Science documents have strong local coherence — equations reference
|
| 17 |
+
nearby text, not distant paragraphs.
|
| 18 |
+
Global tokens (special [SCIENCE] tokens) attend to everything.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
def __init__(
    self,
    d_model: int,
    num_heads: int,
    window_size: int = 512,
    use_flash_attention: bool = True,
):
    """
    Set up the projections for local windowed attention.

    Args:
        d_model: Model dimension
        num_heads: Number of attention heads
        window_size: Size of local attention window
        use_flash_attention: Use Flash Attention 2 if available (CUDA only)
    """
    super().__init__()
    self.d_model = d_model
    self.num_heads = num_heads
    self.head_dim = d_model // num_heads
    self.window_size = window_size
    self.use_flash_attention = use_flash_attention

    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

    # Fused QKV projection for the local attention path, plus the output
    # projection back to d_model.
    self.qkv = nn.Linear(d_model, d_model * 3, bias=False)
    self.out_proj = nn.Linear(d_model, d_model, bias=False)

    # Separate QKV projection used for tokens flagged as global.
    self.global_qkv = nn.Linear(d_model, d_model * 3, bias=False)

    self._initialize_weights()
|
| 55 |
+
|
| 56 |
+
def _initialize_weights(self):
|
| 57 |
+
"""Initialize weights."""
|
| 58 |
+
for module in [self.qkv, self.global_qkv, self.out_proj]:
|
| 59 |
+
if hasattr(module, 'weight'):
|
| 60 |
+
nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 61 |
+
|
| 62 |
+
def forward(
    self,
    x: torch.Tensor,
    global_mask: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Forward pass with local windowed attention.

    Each position attends to a symmetric window of `window_size` neighbors;
    when any global tokens are flagged, every position additionally attends
    to the globally-projected keys/values of the whole sequence.

    Args:
        x: Input tensor (batch, seq_len, d_model)
        global_mask: Boolean mask indicating which tokens are global
            (attend everywhere). Shape: (batch, seq_len) or None.
        attention_mask: Optional padding mask (batch, seq_len); 0 = padded.

    Returns:
        Output tensor (batch, seq_len, d_model)

    Note:
        Reference implementation with a Python loop over positions — slow
        for long sequences; see forward_optimized for faster paths.
    """
    batch, seq_len, _ = x.shape
    device = x.device

    if global_mask is None:
        global_mask = torch.zeros(batch, seq_len, dtype=torch.bool, device=device)

    # Local QKV for all tokens, split into per-head layout.
    qkv = self.qkv(x)
    q, k, v = qkv.chunk(3, dim=-1)
    q = q.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
    k = k.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
    v = v.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    # Global-token KV via the separate projection, only when needed.
    # (The original also produced an unused global-query tensor; removed.)
    has_global = bool(global_mask.any())
    if has_global:
        global_qkv = self.global_qkv(x)
        _, gk, gv = global_qkv.chunk(3, dim=-1)
        gk = gk.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        gv = gv.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    output = torch.zeros_like(x)

    for t in range(seq_len):
        # Symmetric window around position t, clipped to the sequence.
        window_start = max(0, t - self.window_size // 2)
        window_end = min(seq_len, t + self.window_size // 2 + 1)
        window_indices = slice(window_start, window_end)

        q_t = q[:, :, t:t + 1, :]  # (batch, heads, 1, head_dim)

        k_window = k[:, :, window_indices, :]
        v_window = v[:, :, window_indices, :]

        # Append the globally-projected keys/values for ALL positions when
        # global tokens exist (first batch element decides, as before —
        # the mask is assumed identical across the batch).
        appended_global = False
        if has_global and global_mask[0].any():
            k_full = torch.cat([k_window, gk], dim=2)
            v_full = torch.cat([v_window, gv], dim=2)
            appended_global = True
        else:
            k_full = k_window
            v_full = v_window

        # Scaled dot-product attention for this single query position.
        attn_scores = torch.matmul(q_t, k_full.transpose(-2, -1)) / (self.head_dim ** 0.5)
        # (batch, heads, 1, k_len)

        if attention_mask is not None:
            # BUG FIX: the padding mask previously covered only the window
            # keys, so its length mismatched attn_scores whenever global
            # keys were appended (masked_fill shape error). Extend the mask
            # over the appended global keys too.
            mask_parts = [attention_mask[:, window_indices]]
            if appended_global:
                mask_parts.append(attention_mask)
            mask_t = torch.cat(mask_parts, dim=1).unsqueeze(1).unsqueeze(2)
            attn_scores = attn_scores.masked_fill(mask_t == 0, -1e9)

        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v_full)
        # (batch, heads, 1, head_dim)

        # Merge heads, project back to d_model, and place at position t.
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch, 1, self.d_model)
        output[:, t:t + 1, :] = self.out_proj(attn_output)

    return output
|
| 172 |
+
|
| 173 |
+
def forward_optimized(
    self,
    x: torch.Tensor,
    global_mask: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Optimized forward pass using Flash Attention or efficient windowed attention.
    This is a placeholder for actual Flash Attention integration.

    Args:
        x: Input tensor (batch, seq_len, d_model)
        global_mask: Boolean mask of global tokens, (batch, seq_len) or None
        attention_mask: Optional padding mask (batch, seq_len)

    Returns:
        Output tensor (batch, seq_len, d_model)
    """
    batch, seq_len, _ = x.shape

    # When the window covers the whole sequence, windowing adds nothing —
    # fall through to full attention.
    if self.use_flash_attention and self.window_size >= seq_len:
        # For short sequences, can use full attention
        return self._flash_attention_forward(x, attention_mask)
    else:
        # Use windowed attention
        # NOTE(review): `_windowed_attention_forward` is not defined in the
        # visible part of this class — confirm it exists elsewhere, or
        # whether this should call self.forward(x, global_mask, attention_mask).
        return self._windowed_attention_forward(x, global_mask, attention_mask)
|
| 191 |
+
|
| 192 |
+
def _flash_attention_forward(
|
| 193 |
+
self,
|
| 194 |
+
x: torch.Tensor,
|
| 195 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 196 |
+
) -> torch.Tensor:
|
| 197 |
+
"""
|
| 198 |
+
Use Flash Attention 2 if available.
|
| 199 |
+
Requires: pip install flash-attn
|
| 200 |
+
"""
|
| 201 |
+
try:
|
| 202 |
+
from flash_attn import flash_attn_func
|
| 203 |
+
|
| 204 |
+
batch, seq_len, _ = x.shape
|
| 205 |
+
qkv = self.qkv(x)
|
| 206 |
+
q, k, v = qkv.chunk(3, dim=-1)
|
| 207 |
+
|
| 208 |
+
# Reshape for flash attention
|
| 209 |
+
q = q.view(batch, seq_len, self.num_heads, self.head_dim)
|
| 210 |
+
k = k.view(batch, seq_len, self.num_heads, self.head_dim)
|
| 211 |
+
v = v.view(batch, seq_len, self.num_heads, self.head_dim)
|
| 212 |
+
|
| 213 |
+
# Flash attention expects (batch, seq_len, num_heads, head_dim)
|
| 214 |
+
# and returns same shape
|
| 215 |
+
if attention_mask is not None:
|
| 216 |
+
# Flash attention uses causal mask or padding mask
|
| 217 |
+
output = flash_attn_func(
|
| 218 |
+
q, k, v,
|
| 219 |
+
causal=False,
|
| 220 |
+
softmax_scale=1.0 / (self.head_dim ** 0.5),
|
| 221 |
+
)
|
| 222 |
+
else:
|
| 223 |
+
output = flash_attn_func(
|
| 224 |
+
q, k, v,
|
| 225 |
+
causal=False,
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
output = output.view(batch, seq_len, self.d_model)
|
| 229 |
+
return self.out_proj(output)
|
| 230 |
+
|
| 231 |
+
except ImportError:
|
| 232 |
+
print("Flash Attention not available, falling back to standard attention")
|
| 233 |
+
return self._standard_attention(x, attention_mask)
|
| 234 |
+
|
| 235 |
+
def _standard_attention(
|
| 236 |
+
self,
|
| 237 |
+
x: torch.Tensor,
|
| 238 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 239 |
+
) -> torch.Tensor:
|
| 240 |
+
"""Standard full attention (quadratic)."""
|
| 241 |
+
batch, seq_len, _ = x.shape
|
| 242 |
+
qkv = self.qkv(x)
|
| 243 |
+
q, k, v = qkv.chunk(3, dim=-1)
|
| 244 |
+
|
| 245 |
+
q = q.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 246 |
+
k = k.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 247 |
+
v = v.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 248 |
+
|
| 249 |
+
# Compute attention scores
|
| 250 |
+
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
|
| 251 |
+
|
| 252 |
+
if attention_mask is not None:
|
| 253 |
+
attn_scores = attn_scores.masked_fill(
|
| 254 |
+
attention_mask.unsqueeze(1).unsqueeze(2) == 0,
|
| 255 |
+
-1e9
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
attn_weights = F.softmax(attn_scores, dim=-1)
|
| 259 |
+
attn_output = torch.matmul(attn_weights, v)
|
| 260 |
+
|
| 261 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 262 |
+
attn_output = attn_output.view(batch, seq_len, self.d_model)
|
| 263 |
+
return self.out_proj(attn_output)
|
| 264 |
+
|
| 265 |
+
def _windowed_attention_forward(
|
| 266 |
+
self,
|
| 267 |
+
x: torch.Tensor,
|
| 268 |
+
global_mask: Optional[torch.Tensor] = None,
|
| 269 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 270 |
+
) -> torch.Tensor:
|
| 271 |
+
"""
|
| 272 |
+
Efficient windowed attention implementation.
|
| 273 |
+
Uses unfold to extract windows and batched matrix multiply.
|
| 274 |
+
"""
|
| 275 |
+
batch, seq_len, _ = x.shape
|
| 276 |
+
device = x.device
|
| 277 |
+
|
| 278 |
+
if global_mask is None:
|
| 279 |
+
global_mask = torch.zeros(batch, seq_len, dtype=torch.bool, device=device)
|
| 280 |
+
|
| 281 |
+
# Compute QKV
|
| 282 |
+
qkv = self.qkv(x)
|
| 283 |
+
q, k, v = qkv.chunk(3, dim=-1)
|
| 284 |
+
|
| 285 |
+
# Reshape: (batch, seq_len, num_heads, head_dim)
|
| 286 |
+
q = q.view(batch, seq_len, self.num_heads, self.head_dim)
|
| 287 |
+
k = k.view(batch, seq_len, self.num_heads, self.head_dim)
|
| 288 |
+
v = v.view(batch, seq_len, self.num_heads, self.head_dim)
|
| 289 |
+
|
| 290 |
+
# Pad sequence for windowing
|
| 291 |
+
pad_len = self.window_size // 2
|
| 292 |
+
k_padded = F.pad(k, (0, 0, 0, 0, pad_len, pad_len))
|
| 293 |
+
v_padded = F.pad(v, (0, 0, 0, 0, pad_len, pad_len))
|
| 294 |
+
|
| 295 |
+
# Extract windows using unfold
|
| 296 |
+
# (batch, seq_len, num_heads, head_dim) -> (batch, seq_len, window_size, num_heads, head_dim)
|
| 297 |
+
k_windows = k_padded.unfold(1, self.window_size, 1)
|
| 298 |
+
v_windows = v_padded.unfold(1, self.window_size, 1)
|
| 299 |
+
|
| 300 |
+
# Permute to (batch, seq_len, num_heads, window_size, head_dim)
|
| 301 |
+
k_windows = k_windows.permute(0, 1, 3, 2, 4)
|
| 302 |
+
v_windows = v_windows.permute(0, 1, 3, 2, 4)
|
| 303 |
+
|
| 304 |
+
# Compute attention for each position
|
| 305 |
+
# q: (batch, seq_len, num_heads, 1, head_dim)
|
| 306 |
+
q_expanded = q.unsqueeze(3)
|
| 307 |
+
k_windows = k_windows
|
| 308 |
+
|
| 309 |
+
# Scores: (batch, seq_len, num_heads, 1, window_size)
|
| 310 |
+
attn_scores = torch.matmul(q_expanded, k_windows.transpose(-2, -1)) / (self.head_dim ** 0.5)
|
| 311 |
+
attn_scores = attn_scores.squeeze(3) # (batch, seq_len, num_heads, window_size)
|
| 312 |
+
|
| 313 |
+
# Apply softmax
|
| 314 |
+
attn_weights = F.softmax(attn_scores, dim=-1)
|
| 315 |
+
|
| 316 |
+
# Weighted sum
|
| 317 |
+
attn_output = torch.matmul(attn_weights.unsqueeze(3), v_windows).squeeze(3)
|
| 318 |
+
# (batch, seq_len, num_heads, head_dim)
|
| 319 |
+
|
| 320 |
+
# Concatenate heads
|
| 321 |
+
attn_output = attn_output.view(batch, seq_len, self.d_model)
|
| 322 |
+
|
| 323 |
+
# Add global token contribution if any
|
| 324 |
+
if global_mask.any():
|
| 325 |
+
# Compute full attention for global tokens only
|
| 326 |
+
# This is a simplified version - in practice would be optimized
|
| 327 |
+
global_indices = global_mask[0].nonzero(as_tuple=True)[0]
|
| 328 |
+
if len(global_indices) > 0:
|
| 329 |
+
# For positions with global tokens, add full attention
|
| 330 |
+
# (simplified: compute full attention for all)
|
| 331 |
+
full_attn = self._standard_attention(x, attention_mask)
|
| 332 |
+
# Blend: local for most, full for global positions
|
| 333 |
+
attn_output = torch.where(
|
| 334 |
+
global_mask.unsqueeze(-1),
|
| 335 |
+
full_attn,
|
| 336 |
+
attn_output
|
| 337 |
+
)
|
| 338 |
+
|
| 339 |
+
return self.out_proj(attn_output)
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
def test_vortex_local_attention():
    """Test the VortexLocalAttention layer."""
    batch_size, seq_len = 2, 256
    d_model, num_heads, window_size = 4096, 32, 512

    layer = VortexLocalAttention(d_model, num_heads, window_size, use_flash_attention=False)
    inputs = torch.randn(batch_size, seq_len, d_model)

    # A plain forward pass must preserve the input shape.
    result = layer(inputs)
    print(f"Input shape: {inputs.shape}")
    print(f"Output shape: {result.shape}")
    assert result.shape == inputs.shape, f"Expected {inputs.shape}, got {result.shape}"

    # Forward pass with a couple of global tokens set.
    global_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
    global_mask[0, 0] = True  # First token is global
    global_mask[1, -1] = True  # Last token is global
    result_global = layer(inputs, global_mask=global_mask)
    assert result_global.shape == inputs.shape

    print("VortexLocalAttention test passed!")
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
# Allow running this file directly as a quick smoke test.
if __name__ == "__main__":
    test_vortex_local_attention()
|
models/science_modules/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Science modules package.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from .equation_module import EquationModule
|
| 6 |
+
from .numerical_module import NumericalReasoningModule
|
| 7 |
+
from .citation_module import CitationModule
|
| 8 |
+
from .molecular_module import MolecularModule
|
| 9 |
+
|
| 10 |
+
__all__ = [
|
| 11 |
+
"EquationModule",
|
| 12 |
+
"NumericalReasoningModule",
|
| 13 |
+
"CitationModule",
|
| 14 |
+
"MolecularModule",
|
| 15 |
+
]
|
models/science_modules/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (490 Bytes). View file
|
|
|
models/science_modules/__pycache__/citation_module.cpython-313.pyc
ADDED
|
Binary file (9.3 kB). View file
|
|
|
models/science_modules/__pycache__/equation_module.cpython-313.pyc
ADDED
|
Binary file (10.4 kB). View file
|
|
|
models/science_modules/__pycache__/molecular_module.cpython-313.pyc
ADDED
|
Binary file (13 kB). View file
|
|
|
models/science_modules/__pycache__/numerical_module.cpython-313.pyc
ADDED
|
Binary file (10 kB). View file
|
|
|
models/science_modules/citation_module.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
CitationModule: Understands scientific citation structure.
|
| 3 |
+
Detects citation spans, tracks provenance, and estimates claim confidence.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import re
|
| 10 |
+
from typing import Optional, Tuple, List
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class CitationModule(nn.Module):
|
| 14 |
+
"""
|
| 15 |
+
Understands scientific citation structure.
|
| 16 |
+
- Detects citation spans [Author, Year] or (1) style
|
| 17 |
+
- Learns that cited claims carry different epistemic weight
|
| 18 |
+
- Distinguishes established facts vs recent/contested findings
|
| 19 |
+
- Tracks claim provenance through the context window
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
def __init__(self, d_model: int):
|
| 23 |
+
"""
|
| 24 |
+
Initialize CitationModule.
|
| 25 |
+
|
| 26 |
+
Args:
|
| 27 |
+
d_model: Model dimension
|
| 28 |
+
"""
|
| 29 |
+
super().__init__()
|
| 30 |
+
self.d_model = d_model
|
| 31 |
+
|
| 32 |
+
# Citation span detector (3 classes: none, inline, reference)
|
| 33 |
+
# Inline: (Author, Year) or [1]
|
| 34 |
+
# Reference: full citation at end of paper
|
| 35 |
+
self.citation_detector = nn.Linear(d_model, 3)
|
| 36 |
+
|
| 37 |
+
# Provenance gate: modulates information flow based on citation context
|
| 38 |
+
self.provenance_gate = nn.Linear(d_model, d_model)
|
| 39 |
+
|
| 40 |
+
# Claim confidence head: estimates how well-supported a claim is
|
| 41 |
+
self.confidence_head = nn.Linear(d_model, 1)
|
| 42 |
+
|
| 43 |
+
# Citation type embeddings
|
| 44 |
+
self.citation_type_embedding = nn.Embedding(3, d_model)
|
| 45 |
+
|
| 46 |
+
# Initialize weights
|
| 47 |
+
self._initialize_weights()
|
| 48 |
+
|
| 49 |
+
def _initialize_weights(self):
|
| 50 |
+
"""Initialize weights."""
|
| 51 |
+
for module in [self.citation_detector, self.provenance_gate, self.confidence_head, self.citation_type_embedding]:
|
| 52 |
+
if hasattr(module, 'weight'):
|
| 53 |
+
nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 54 |
+
if hasattr(module, 'bias') and module.bias is not None:
|
| 55 |
+
nn.init.zeros_(module.bias)
|
| 56 |
+
|
| 57 |
+
def detect_citation_spans(
|
| 58 |
+
self,
|
| 59 |
+
text: str,
|
| 60 |
+
) -> List[Tuple[int, int, str]]:
|
| 61 |
+
"""
|
| 62 |
+
Detect citation spans in text.
|
| 63 |
+
Supports: (Author, Year), [1], [Author, Year], et al.
|
| 64 |
+
|
| 65 |
+
Args:
|
| 66 |
+
text: Input text string
|
| 67 |
+
|
| 68 |
+
Returns:
|
| 69 |
+
List of (start_char, end_char, citation_type)
|
| 70 |
+
citation_type: "inline" or "reference"
|
| 71 |
+
"""
|
| 72 |
+
spans = []
|
| 73 |
+
|
| 74 |
+
# Pattern 1: (Author, Year) or (Author Year)
|
| 75 |
+
for match in re.finditer(r'\([A-Za-z\s]+(?:et al\.)?,?\s*\d{4}\)', text):
|
| 76 |
+
spans.append((match.start(), match.end(), "inline"))
|
| 77 |
+
|
| 78 |
+
# Pattern 2: [1] or [1-3] or [1,2,3]
|
| 79 |
+
for match in re.finditer(r'\[\d+(?:[-,]\d+)*\]', text):
|
| 80 |
+
spans.append((match.start(), match.end(), "inline"))
|
| 81 |
+
|
| 82 |
+
# Pattern 3: [Author, Year]
|
| 83 |
+
for match in re.finditer(r'\[[A-Za-z\s]+,?\s*\d{4}\]', text):
|
| 84 |
+
spans.append((match.start(), match.end(), "inline"))
|
| 85 |
+
|
| 86 |
+
# Pattern 4: et al. (often indicates citation)
|
| 87 |
+
for match in re.finditer(r'\bet al\.\b', text):
|
| 88 |
+
spans.append((match.start(), match.end(), "inline"))
|
| 89 |
+
|
| 90 |
+
return spans
|
| 91 |
+
|
| 92 |
+
def forward(
|
| 93 |
+
self,
|
| 94 |
+
x: torch.Tensor,
|
| 95 |
+
text: Optional[List[str]] = None,
|
| 96 |
+
citation_spans: Optional[List[List[Tuple[int, int, str]]]] = None,
|
| 97 |
+
) -> torch.Tensor:
|
| 98 |
+
"""
|
| 99 |
+
Forward pass through citation module.
|
| 100 |
+
|
| 101 |
+
Args:
|
| 102 |
+
x: Input tensor (batch, seq_len, d_model)
|
| 103 |
+
text: Optional original text strings
|
| 104 |
+
citation_spans: Optional pre-computed citation spans per batch
|
| 105 |
+
|
| 106 |
+
Returns:
|
| 107 |
+
Citation-enhanced representation (batch, seq_len, d_model)
|
| 108 |
+
"""
|
| 109 |
+
batch, seq_len, d_model = x.shape
|
| 110 |
+
|
| 111 |
+
# Detect citation spans
|
| 112 |
+
if citation_spans is None and text is not None:
|
| 113 |
+
citation_spans = []
|
| 114 |
+
for b in range(batch):
|
| 115 |
+
spans = self.detect_citation_spans(text[b])
|
| 116 |
+
# Convert char spans to token spans (approximate)
|
| 117 |
+
token_spans = []
|
| 118 |
+
for start_char, end_char, ctype in spans:
|
| 119 |
+
start_tok = max(0, start_char // 4)
|
| 120 |
+
end_tok = min(seq_len, end_char // 4 + 1)
|
| 121 |
+
token_spans.append((start_tok, end_tok, ctype))
|
| 122 |
+
citation_spans.append(token_spans)
|
| 123 |
+
|
| 124 |
+
# Compute citation type logits
|
| 125 |
+
citation_logits = self.citation_detector(x) # (batch, seq_len, 3)
|
| 126 |
+
citation_probs = F.softmax(citation_logits, dim=-1)
|
| 127 |
+
|
| 128 |
+
# Apply citation-specific transformations
|
| 129 |
+
output = x.clone()
|
| 130 |
+
|
| 131 |
+
if citation_spans:
|
| 132 |
+
for b in range(batch):
|
| 133 |
+
spans_b = citation_spans[b] if b < len(citation_spans) else []
|
| 134 |
+
|
| 135 |
+
for start_tok, end_tok, ctype in spans_b:
|
| 136 |
+
if end_tok <= start_tok:
|
| 137 |
+
continue
|
| 138 |
+
|
| 139 |
+
# Get citation type embedding
|
| 140 |
+
if ctype == "inline":
|
| 141 |
+
type_id = 1
|
| 142 |
+
elif ctype == "reference":
|
| 143 |
+
type_id = 2
|
| 144 |
+
else:
|
| 145 |
+
type_id = 0
|
| 146 |
+
|
| 147 |
+
type_emb = self.citation_type_embedding(
|
| 148 |
+
torch.tensor(type_id, device=x.device)
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
# Apply provenance gate to citation span
|
| 152 |
+
span_slice = x[b, start_tok:end_tok, :]
|
| 153 |
+
gated = span_slice * torch.sigmoid(self.provenance_gate(span_slice))
|
| 154 |
+
|
| 155 |
+
# Add citation type embedding
|
| 156 |
+
gated = gated + type_emb.unsqueeze(0).unsqueeze(0)
|
| 157 |
+
|
| 158 |
+
output[b, start_tok:end_tok, :] = gated
|
| 159 |
+
|
| 160 |
+
# Compute confidence scores (for auxiliary loss)
|
| 161 |
+
confidence = torch.sigmoid(self.confidence_head(x)) # (batch, seq_len, 1)
|
| 162 |
+
|
| 163 |
+
return output, confidence
|
| 164 |
+
|
| 165 |
+
def compute_citation_loss(
|
| 166 |
+
self,
|
| 167 |
+
x: torch.Tensor,
|
| 168 |
+
citation_mask: torch.Tensor,
|
| 169 |
+
confidence: torch.Tensor,
|
| 170 |
+
) -> torch.Tensor:
|
| 171 |
+
"""
|
| 172 |
+
Compute auxiliary loss for citation detection and confidence.
|
| 173 |
+
|
| 174 |
+
Args:
|
| 175 |
+
x: Input tensor (batch, seq_len, d_model)
|
| 176 |
+
citation_mask: Ground truth citation mask (batch, seq_len), 1 if token is in citation
|
| 177 |
+
confidence: Predicted confidence scores (batch, seq_len, 1)
|
| 178 |
+
|
| 179 |
+
Returns:
|
| 180 |
+
Combined citation loss
|
| 181 |
+
"""
|
| 182 |
+
# Citation detection loss
|
| 183 |
+
logits = self.citation_detector(x) # (batch, seq_len, 3)
|
| 184 |
+
detection_loss = F.cross_entropy(
|
| 185 |
+
logits.view(-1, 3),
|
| 186 |
+
citation_mask.long().view(-1),
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
# Confidence calibration loss (encourage high confidence for true citations)
|
| 190 |
+
confidence_loss = F.mse_loss(
|
| 191 |
+
confidence.squeeze(-1),
|
| 192 |
+
citation_mask.float(),
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
return detection_loss + 0.1 * confidence_loss
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def test_citation_module():
    """Test CitationModule."""
    d_model, batch_size, seq_len = 512, 2, 128

    module = CitationModule(d_model)

    hidden = torch.randn(batch_size, seq_len, d_model)
    sample_texts = [
        "The theory of relativity (Einstein, 1905) revolutionized physics. See also [1, 2].",
        "According to Smith et al., the results are significant. Further reading: [Doe, 2020]."
    ]

    output, confidence = module(hidden, text=sample_texts)
    print(f"Input shape: {hidden.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Confidence shape: {confidence.shape}")
    assert output.shape == hidden.shape
    assert confidence.shape == (batch_size, seq_len, 1)

    # Exercise the auxiliary loss with a couple of synthetic citation spans.
    gt_mask = torch.zeros(batch_size, seq_len)
    gt_mask[0, 20:25] = 1.0  # Simulate citation span
    gt_mask[1, 10:18] = 1.0
    loss = module.compute_citation_loss(hidden, gt_mask, confidence)
    print(f"Citation loss: {loss.item():.4f}")

    print("CitationModule test passed!")
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
# Allow running this file directly as a quick smoke test.
if __name__ == "__main__":
    test_citation_module()
|
models/science_modules/equation_module.py
ADDED
|
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
EquationModule: Specialized processing for mathematical equations and LaTeX.
|
| 3 |
+
Detects equation spans, applies equation-specific attention, and learns
|
| 4 |
+
structural representations of mathematical expressions.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
import re
|
| 11 |
+
from typing import Optional, Tuple, List
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class EquationModule(nn.Module):
    """
    Specialized processing for mathematical equations and LaTeX.
    - Detects equation spans in input (between $ $ or \\[ \\] delimiters)
    - Applies equation-specific attention patterns within equation spans
    - Learns structural representations of mathematical expressions
    - Tree-aware: understands operator precedence and nesting
    """

    def __init__(self, d_model: int, num_heads: int = 8):
        """
        Initialize EquationModule.

        Args:
            d_model: Model dimension
            num_heads: Number of heads for equation-specific attention
        """
        super().__init__()
        self.d_model = d_model

        # Equation span detector (lightweight linear classifier)
        self.span_detector = nn.Linear(d_model, 1)

        # Equation-specific transformer (shallow, 2 layers)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=d_model * 4,
            activation=F.silu,
            batch_first=True,
            dropout=0.1,
        )
        self.equation_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)

        # Merge equation representations back into main stream
        self.merge = nn.Linear(d_model * 2, d_model)

        # LaTeX structure awareness (simple positional encoding for tree depth)
        self.depth_embedding = nn.Embedding(10, d_model)  # Max depth 10

        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize non-transformer submodules with N(0, 0.02) weights, zero biases."""
        for module in [self.span_detector, self.merge, self.depth_embedding]:
            if hasattr(module, 'weight'):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if hasattr(module, 'bias') and module.bias is not None:
                nn.init.zeros_(module.bias)

    def detect_equation_spans(
        self,
        text: str,
        token_ids: Optional[torch.Tensor] = None,
    ) -> List[Tuple[int, int]]:
        """
        Detect equation spans in text using delimiters.
        Supports: $...$, $$...$$, \\[...\\], \\(...\\)

        Args:
            text: Input text string
            token_ids: Optional token IDs for alignment

        Returns:
            List of (start_char, end_char) spans
        """
        spans = []

        # Patterns 1+2: $$...$$ (display) and $...$ (inline) in one
        # alternation. BUG FIX: matching $...$ independently mis-paired the
        # two leading dollars of display math (e.g. "$$x$$" matched as
        # "$$x$") and emitted overlapping duplicate spans; display math must
        # be tried first and inline math must not cross a $ delimiter.
        for match in re.finditer(r'\$\$(.+?)\$\$|\$([^$]+?)\$', text, re.DOTALL):
            spans.append((match.start(), match.end()))

        # Pattern 3: \[...\] (LaTeX display math)
        for match in re.finditer(r'\\\[(.+?)\\\]', text, re.DOTALL):
            spans.append((match.start(), match.end()))

        # Pattern 4: \(...\) (LaTeX inline math)
        for match in re.finditer(r'\\\((.+?)\\\)', text, re.DOTALL):
            spans.append((match.start(), match.end()))

        return spans

    def forward(
        self,
        x: torch.Tensor,
        text: Optional[List[str]] = None,
        token_spans: Optional[List[List[Tuple[int, int]]]] = None,
    ) -> torch.Tensor:
        """
        Forward pass through the equation module.

        Args:
            x: Input tensor (batch, seq_len, d_model)
            text: Optional original text strings (for delimiter-based detection)
            token_spans: Optional pre-computed token-level equation spans
                Each element: list of (start_token, end_token) for that batch item

        Returns:
            Equation-enhanced representation (batch, seq_len, d_model)
        """
        batch, seq_len, d_model = x.shape

        # Detect equation spans.
        if token_spans is None and text is not None:
            # Use delimiter-based detection (requires text).
            token_spans = []
            for b in range(batch):
                char_spans = self.detect_equation_spans(text[b])
                # Convert char spans to token spans (approximate: ~4 chars/token).
                # In practice, would need proper tokenization alignment.
                token_spans_b = []
                for start_char, end_char in char_spans:
                    start_token = max(0, start_char // 4)
                    end_token = min(seq_len, end_char // 4 + 1)
                    token_spans_b.append((start_token, end_token))
                token_spans.append(token_spans_b)
        elif token_spans is None:
            # Fallback: use learned detector.
            token_spans = self._learned_span_detection(x)

        # Process each batch item.
        output = x.clone()

        for b in range(batch):
            spans_b = token_spans[b] if b < len(token_spans) else []

            for start_tok, end_tok in spans_b:
                if end_tok <= start_tok:
                    continue

                # Extract equation segment: (1, seg_len, d_model).
                eq_segment = x[b:b+1, start_tok:end_tok, :]

                # Apply equation-specific transformer.
                eq_encoded = self.equation_encoder(eq_segment)

                # Merge with the original representation.
                merged = torch.cat([eq_segment, eq_encoded], dim=-1)
                merged = self.merge(merged)

                # Place back in output.
                output[b:b+1, start_tok:end_tok, :] = merged

        return output

    def _learned_span_detection(
        self,
        x: torch.Tensor,
    ) -> List[List[Tuple[int, int]]]:
        """
        Use learned detector to find equation spans when delimiters missing.
        Simple thresholding on span_detector output.

        Args:
            x: Input tensor (batch, seq_len, d_model)

        Returns:
            List of token spans per batch item
        """
        batch, seq_len, _ = x.shape

        # Compute equation probability per token.
        eq_probs = torch.sigmoid(self.span_detector(x))  # (batch, seq_len, 1)
        eq_probs = eq_probs.squeeze(-1)  # (batch, seq_len)

        threshold = 0.5
        spans = []

        for b in range(batch):
            probs = eq_probs[b]
            is_equation = (probs > threshold).cpu().numpy()

            # Find contiguous above-threshold runs.
            span_list = []
            in_span = False
            start = 0

            for t in range(seq_len):
                if is_equation[t] and not in_span:
                    start = t
                    in_span = True
                elif not is_equation[t] and in_span:
                    span_list.append((start, t))
                    in_span = False

            if in_span:
                span_list.append((start, seq_len))

            spans.append(span_list)

        return spans

    def compute_equation_loss(
        self,
        x: torch.Tensor,
        equation_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute auxiliary loss for equation detection training.

        Args:
            x: Input tensor (batch, seq_len, d_model)
            equation_mask: Ground truth equation mask (batch, seq_len), 1 if token is in equation

        Returns:
            Binary cross-entropy loss for equation detection
        """
        logits = self.span_detector(x).squeeze(-1)  # (batch, seq_len)
        loss = F.binary_cross_entropy_with_logits(
            logits,
            equation_mask.float(),
        )
        return loss
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def test_equation_module():
    """Test EquationModule."""
    d_model, batch_size, seq_len = 512, 2, 128

    module = EquationModule(d_model)

    hidden = torch.randn(batch_size, seq_len, d_model)
    sample_texts = [
        "The energy is $E = mc^2$ and momentum is $p = mv$.",
        "Equation: \[ F = ma \] and also $a^2 + b^2 = c^2$."
    ]

    output = module(hidden, text=sample_texts)
    print(f"Input shape: {hidden.shape}")
    print(f"Output shape: {output.shape}")
    assert output.shape == hidden.shape

    # Exercise the auxiliary detection loss with synthetic spans.
    gt_mask = torch.zeros(batch_size, seq_len)
    gt_mask[0, 10:15] = 1.0  # Simulate equation span
    gt_mask[1, 5:12] = 1.0
    loss = module.compute_equation_loss(hidden, gt_mask)
    print(f"Equation loss: {loss.item():.4f}")

    print("EquationModule test passed!")
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
# Allow running this file directly as a quick smoke test.
if __name__ == "__main__":
    test_equation_module()
|
models/science_modules/molecular_module.py
ADDED
|
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MolecularModule: Domain knowledge for chemistry and biology.
|
| 3 |
+
Element embeddings, SMILES understanding, bond types, amino acids.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import re
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from typing import Optional, Tuple, List
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class MolecularModule(nn.Module):
    """
    Domain knowledge for chemistry and biology.

    - All 118 elements as learned embeddings, enriched with a projected
      12-dim physical-property vector per element (atomic number, mass,
      electronegativity, valence electrons, ...).
    - Heuristic detection of chemical formulas, SMILES strings, and amino
      acid sequences in raw text.
    - Bond type embeddings (covalent, ionic, hydrogen, van der Waals).
    - Amino acid sequence embeddings for biology/zoology.
    - Graph-style self-attention over detected SMILES spans.
    """

    # Symbol -> atomic number for the simplified formula parser.
    # Hoisted to a class-level constant so the table is built once,
    # not re-created on every encode_molecule() call (was inside the loop).
    _ELEMENT_Z = {
        'H': 1, 'He': 2, 'Li': 3, 'Be': 4, 'B': 5, 'C': 6, 'N': 7, 'O': 8,
        'F': 9, 'Ne': 10, 'Na': 11, 'Mg': 12, 'Al': 13, 'Si': 14, 'P': 15,
        'S': 16, 'Cl': 17, 'Ar': 18, 'K': 19, 'Ca': 20,
        # ... extend as needed
    }

    def __init__(self, d_model: int, num_elements: int = 118):
        """
        Initialize MolecularModule.

        Args:
            d_model: Model dimension (must be divisible by 8 for the
                molecular attention heads).
            num_elements: Number of chemical elements (default 118).
        """
        super().__init__()
        self.d_model = d_model
        self.num_elements = num_elements

        # Element embeddings — index 0 is reserved for "unknown element".
        self.element_embed = nn.Embedding(num_elements + 1, d_model)

        # Element property encoder (12 properties):
        # [atomic_number, mass, electronegativity, valence_e, period, group,
        #  atomic_radius, ionization_energy, electron_affinity, density,
        #  melting_point, boiling_point]
        self.property_proj = nn.Linear(12, d_model)

        # Bond type embeddings (8 types):
        # 0: none, 1: single, 2: double, 3: triple, 4: aromatic,
        # 5: ionic, 6: hydrogen, 7: van der waals
        self.bond_embed = nn.Embedding(8, d_model)

        # Amino acid embeddings (20 standard + stop + start + unknown + special).
        self.amino_acid_vocab = 25
        self.amino_embed = nn.Embedding(self.amino_acid_vocab, d_model)

        # Molecular graph attention (treats a SMILES span as a small graph).
        self.mol_attention = nn.MultiheadAttention(
            d_model,
            num_heads=8,
            batch_first=True,
            dropout=0.1,
        )

        # Property prediction head (for auxiliary supervision).
        self.property_head = nn.Linear(d_model, 12)

        self._initialize_weights()

        # Pre-compute the element property table (registered as a buffer).
        self._init_element_properties()

    def _initialize_weights(self):
        """Normal(0, 0.02) init for module weights, zeros for biases.

        NOTE: nn.MultiheadAttention exposes no plain ``.weight`` attribute,
        so the hasattr guard intentionally skips it (PyTorch's own init
        applies there).
        """
        for module in [self.element_embed, self.property_proj, self.bond_embed,
                       self.amino_embed, self.mol_attention, self.property_head]:
            if hasattr(module, 'weight'):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if hasattr(module, 'bias') and module.bias is not None:
                nn.init.zeros_(module.bias)

    def _init_element_properties(self):
        """Initialize the element property table with approximate values.

        Simplified: only a few common elements are filled in; a real
        implementation would load a comprehensive chemistry database.
        Rows are indexed by atomic number; row 0 (unknown) stays zero.
        """
        # Properties per row:
        # [atomic_number, mass, electronegativity, valence_e, period, group,
        #  atomic_radius, ionization_energy, electron_affinity, density,
        #  melting_point, boiling_point]
        properties = torch.zeros(self.num_elements + 1, 12)

        element_data = {
            1: [1, 1.008, 2.20, 1, 1, 1, 25, 1312, 72.8, 0.0000899, 14, 20],
            6: [6, 12.011, 2.55, 4, 2, 14, 70, 1086, 153.9, 2.267, 3550, 4027],
            7: [7, 14.007, 3.04, 5, 2, 15, 65, 1402, 7.0, 0.0012506, 63, 77],
            8: [8, 15.999, 3.44, 6, 2, 16, 60, 1314, 141.0, 0.001429, 55, 90],
            # ... would fill all 118 elements
        }

        for z, props in element_data.items():
            properties[z] = torch.tensor(props)

        self.register_buffer("element_properties", properties)

    def detect_molecular_spans(
        self,
        text: str,
    ) -> List[Tuple[int, int, str]]:
        """
        Detect molecular/chemical spans in text.

        Args:
            text: Input text string.

        Returns:
            List of (start_char, end_char, span_type) tuples, where
            span_type is one of "formula", "smiles", "amino_acid".
        """
        spans = []

        # Chemical formulas: H2O, CO2, C6H12O6, NaCl, HCl
        formula_pattern = r'\b([A-Z][a-z]?\d*)+(?:[A-Z][a-z]?\d*)*\b'
        for match in re.finditer(formula_pattern, text):
            span = match.group()
            # Filter out bare single letters that are unlikely to be formulas
            # (single *uppercase* letters still pass, e.g. elemental "K").
            if len(span) > 1 or span.isupper():
                spans.append((match.start(), match.end(), "formula"))

        # SMILES patterns (simplified detection): presence of structural
        # characters such as =, #, @, [], ().
        smiles_hints = ['=', '#', '@', '[', ']', '(', ')']
        # Bug fix: the previous code used text.find(word), which always
        # returns the FIRST occurrence, so a repeated word was mapped to
        # the wrong position. A running cursor keeps occurrences aligned.
        cursor = 0
        for word in re.findall(r'\S+', text):
            pos = text.find(word, cursor)
            if pos < 0:
                continue
            cursor = pos + len(word)
            if any(hint in word for hint in smiles_hints) and len(word) > 3:
                spans.append((pos, pos + len(word), "smiles"))

        # Amino acid sequences: runs of >= 6 single-letter residue codes.
        aa_pattern = r'\b([ACDEFGHIKLMNPQRSTVWY]{6,})\b'
        for match in re.finditer(aa_pattern, text.upper()):
            spans.append((match.start(), match.end(), "amino_acid"))

        return spans

    def encode_molecule(
        self,
        formula: str,
    ) -> torch.Tensor:
        """
        Encode a molecular formula into a single embedding.

        Each element's learned embedding is combined with its projected
        property vector; the per-element vectors are then averaged,
        weighted by atom counts.

        Args:
            formula: Chemical formula string (e.g., "C6H12O6").

        Returns:
            Molecule embedding of shape (d_model,). Zero vector when
            nothing parseable is found.
        """
        # Parse formula into (element, count) pairs.
        # Simplified parser — a full one would handle nested parentheses.
        matches = re.findall(r'([A-Z][a-z]?)(\d*)', formula)

        device = self.element_embed.weight.device
        embeddings = []
        weights = []

        for element, count_str in matches:
            z = self._ELEMENT_Z.get(element, 0)  # 0 = unknown element
            count = int(count_str) if count_str else 1

            elem_emb = self.element_embed(torch.tensor(z, device=device))

            # Project the 12-dim property row into model space.
            props = self.element_properties[z].unsqueeze(0)  # (1, 12)
            props_emb = self.property_proj(props).squeeze(0)

            embeddings.append(elem_emb + props_emb)
            weights.append(count)

        if not embeddings:
            # Nothing parseable: neutral zero embedding.
            return torch.zeros(self.d_model, device=device)

        # Count-weighted average over the molecule's elements.
        embeddings = torch.stack(embeddings)
        weights = torch.tensor(weights, dtype=torch.float32, device=device)
        weights = weights / weights.sum()

        return (embeddings * weights.unsqueeze(-1)).sum(dim=0)

    def forward(
        self,
        x: torch.Tensor,
        text: Optional[List[str]] = None,
        molecular_spans: Optional[List[List[Tuple[int, int, str]]]] = None,
    ) -> torch.Tensor:
        """
        Forward pass through the molecular module.

        Args:
            x: Input tensor (batch, seq_len, d_model).
            text: Optional original text strings, one per batch element.
            molecular_spans: Optional pre-computed token spans per batch
                element, as (start_tok, end_tok, span_type) tuples.

        Returns:
            Molecular-enhanced representation (batch, seq_len, d_model).
        """
        batch, seq_len, d_model = x.shape
        device = x.device

        # Detect molecular spans from raw text when not supplied.
        if molecular_spans is None and text is not None:
            molecular_spans = []
            for b in range(batch):
                spans = self.detect_molecular_spans(text[b])
                # Convert char spans to token spans.
                # NOTE(review): assumes ~4 characters per token — TODO
                # confirm against the actual tokenizer.
                token_spans = []
                for start_char, end_char, span_type in spans:
                    start_tok = max(0, start_char // 4)
                    end_tok = min(seq_len, end_char // 4 + 1)
                    token_spans.append((start_tok, end_tok, span_type))
                molecular_spans.append(token_spans)

        # Enhance detected spans on a copy of the input (residual-style adds).
        output = x.clone()

        if molecular_spans:
            for b in range(batch):
                spans_b = molecular_spans[b] if b < len(molecular_spans) else []

                for start_tok, end_tok, span_type in spans_b:
                    if end_tok <= start_tok:
                        continue

                    span_slice = x[b, start_tok:end_tok, :]

                    if span_type == "formula":
                        if text:
                            # Rough char-range extraction (4 chars/token).
                            formula = text[b][start_tok * 4:end_tok * 4]
                            mol_emb = self.encode_molecule(formula)
                        else:
                            # NOTE(review): random placeholder when no text
                            # is available — makes this branch
                            # nondeterministic; confirm intent.
                            mol_emb = torch.randn(d_model, device=device)

                        # Inject the molecule embedding at the span start.
                        output[b, start_tok, :] += mol_emb

                    elif span_type == "amino_acid":
                        # NOTE(review): simplified placeholder — random
                        # residue IDs rather than the actual letters.
                        seq_len_span = end_tok - start_tok
                        aa_ids = torch.randint(0, 20, (seq_len_span,), device=device)
                        aa_emb = self.amino_embed(aa_ids)  # (seq_len_span, d_model)
                        output[b, start_tok:end_tok, :] += aa_emb

                    elif span_type == "smiles":
                        # Graph-style self-attention over the SMILES span,
                        # treating each token as a node.
                        seq_len_span = end_tok - start_tok
                        if seq_len_span > 1:
                            attn_out, _ = self.mol_attention(
                                span_slice.unsqueeze(0),
                                span_slice.unsqueeze(0),
                                span_slice.unsqueeze(0),
                            )
                            output[b, start_tok:end_tok, :] += attn_out.squeeze(0)

        return output

    def compute_property_loss(
        self,
        x: torch.Tensor,
        element_ids: torch.Tensor,
        target_properties: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute auxiliary loss for element property prediction.

        Args:
            x: Input tensor (batch, seq_len, d_model). Currently unused —
               prediction is made from the element embeddings directly.
            element_ids: Element IDs (batch, seq_len).
            target_properties: Target property values (batch, seq_len, 12).

        Returns:
            Scalar MSE loss for property prediction.
        """
        # Predict the 12 properties from the learned element embeddings.
        elem_emb = self.element_embed(element_ids)
        pred_props = self.property_head(elem_emb)

        loss = F.mse_loss(pred_props, target_properties)
        return loss
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def test_molecular_module():
    """Smoke-test MolecularModule on a small batch of science text."""
    dim, n_batch, n_tokens = 512, 2, 128

    mol = MolecularModule(dim)

    inputs = torch.randn(n_batch, n_tokens, dim)
    samples = [
        "Water is H2O. The DNA sequence is ACGTACGTACGT.",
        "Proteins are made of amino acids like ACDEFGH. Benzene is C6H6."
    ]

    enhanced = mol(inputs, text=samples)
    print(f"Input shape: {inputs.shape}")
    print(f"Output shape: {enhanced.shape}")
    assert enhanced.shape == inputs.shape

    print("MolecularModule test passed!")


if __name__ == "__main__":
    test_molecular_module()
|
models/science_modules/numerical_module.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
NumericalReasoningModule: Handles scientific numerical reasoning.
|
| 3 |
+
Digit-level number encoding, scientific notation, unit awareness.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import re
|
| 10 |
+
from typing import Optional, Tuple, List
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class NumericalReasoningModule(nn.Module):
    """
    Handles scientific numerical reasoning.

    - Digit-level number encoding (each digit gets a position-aware embedding).
    - Scientific notation understanding (e.g. 6.02 × 10²³).
    - Unit awareness (meters, joules, moles, kelvin, ...).
    - Order-of-magnitude reasoning via a bounded exponent embedding.
    """

    def __init__(
        self,
        d_model: int,
        max_digits: int = 20,
        num_units: int = 256,
    ):
        """
        Initialize NumericalReasoningModule.

        Args:
            d_model: Model dimension.
            max_digits: Maximum number of digits to encode per number.
            num_units: Size of the unit embedding table.
        """
        super().__init__()
        self.d_model = d_model
        self.max_digits = max_digits

        # Digit embeddings (0-9).
        self.digit_embed = nn.Embedding(10, 64)

        # Digit-position embeddings (ones, tens, hundreds, ...).
        self.position_embed = nn.Embedding(max_digits, 64)

        # Project concatenated digit+position (64+64) to model dimension.
        self.number_proj = nn.Linear(128, d_model)

        # Unit embeddings (SI units + common scientific units).
        self.unit_embed = nn.Embedding(num_units, d_model)

        # Scientific notation combiner.
        # NOTE(review): currently unused in forward() — kept for interface
        # stability; confirm whether it should be wired in.
        self.sci_notation = nn.Linear(d_model * 2, d_model)

        # Magnitude embedding for exponents clamped to [-10, +10],
        # stored at indices 0..20 (index = exponent + 10).
        self.magnitude_embed = nn.Embedding(21, d_model)

        self._initialize_weights()

    def _initialize_weights(self):
        """Normal(0, 0.02) init for weights, zeros for biases."""
        for module in [self.digit_embed, self.position_embed, self.number_proj,
                       self.unit_embed, self.sci_notation, self.magnitude_embed]:
            if hasattr(module, 'weight'):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if hasattr(module, 'bias') and module.bias is not None:
                nn.init.zeros_(module.bias)

    def encode_number(
        self,
        number_str: str,
        device: torch.device,
    ) -> torch.Tensor:
        """
        Encode a number string using digit-level encoding.

        Args:
            number_str: String representation of a number (e.g. "123.45e-6").
            device: Torch device for the produced tensors.

        Returns:
            Number embedding of shape (d_model,).
        """
        # Extract bare digits; decimal point, sign, and exponent markers
        # are handled separately (exponent via magnitude_embed in forward).
        digits = [int(d) for d in re.findall(r'\d', number_str)]
        if not digits:
            digits = [0]

        # Pad/truncate to a fixed width of max_digits.
        if len(digits) > self.max_digits:
            digits = digits[:self.max_digits]
        else:
            digits = digits + [0] * (self.max_digits - len(digits))

        digits_tensor = torch.tensor(digits, device=device)        # (max_digits,)
        positions = torch.arange(self.max_digits, device=device)   # (max_digits,)

        digit_emb = self.digit_embed(digits_tensor)    # (max_digits, 64)
        pos_emb = self.position_embed(positions)       # (max_digits, 64)

        # Concatenate, project, and mean-pool over digit positions.
        combined = torch.cat([digit_emb, pos_emb], dim=-1)  # (max_digits, 128)
        number_emb = self.number_proj(combined)             # (max_digits, d_model)
        return number_emb.mean(dim=0)                       # (d_model,)

    def detect_numbers(
        self,
        text: str,
    ) -> List[Tuple[str, int, int, Optional[str]]]:
        """
        Detect numbers in text with optional units and scientific notation.

        Matches forms like: 123, 123.45, 1.23e-4, 6.02×10^23, 100 m, 5.0 J.

        Returns:
            List of (number_str, start_char, end_char, unit_str) where
            unit_str is None when no trailing unit was found.
        """
        pattern = r'(\d+(?:\.\d+)?(?:[eE][+-]?\d+)?(?:×10\^?[+-]?\d+)?)(?:\s*([a-zA-Z°%]+))?'

        matches = []
        for match in re.finditer(pattern, text):
            number_str = match.group(1)
            unit_str = match.group(2) if match.group(2) else None
            matches.append((number_str, match.start(), match.end(), unit_str))

        return matches

    def forward(
        self,
        x: torch.Tensor,
        text: Optional[List[str]] = None,
        number_positions: Optional[List[List[Tuple[int, int, str]]]] = None,
    ) -> torch.Tensor:
        """
        Forward pass through the numerical reasoning module.

        Args:
            x: Input tensor (batch, seq_len, d_model).
            text: Optional original text strings, one per batch element.
            number_positions: Optional pre-computed
                (start_tok, end_tok, number_str, unit_str) tuples per batch.

        Returns:
            Numerical-enhanced representation (batch, seq_len, d_model).
        """
        batch, seq_len, d_model = x.shape
        device = x.device

        # Detect numbers from raw text when positions are not supplied.
        if number_positions is None and text is not None:
            number_positions = []
            for b in range(batch):
                numbers = self.detect_numbers(text[b])
                # Convert char positions to token positions.
                # NOTE(review): assumes ~4 chars per token — TODO confirm
                # against the actual tokenizer.
                token_nums = []
                for num_str, start_char, end_char, unit_str in numbers:
                    start_tok = max(0, start_char // 4)
                    end_tok = min(seq_len, end_char // 4 + 1)
                    token_nums.append((start_tok, end_tok, num_str, unit_str))
                number_positions.append(token_nums)

        # Enhance number spans on a copy of the input.
        output = x.clone()

        if number_positions:
            for b in range(batch):
                nums_b = number_positions[b] if b < len(number_positions) else []

                for start_tok, end_tok, num_str, unit_str in nums_b:
                    if end_tok <= start_tok or start_tok >= seq_len:
                        continue

                    # Clamp to sequence bounds.
                    start_tok = min(start_tok, seq_len - 1)
                    end_tok = min(end_tok, seq_len)

                    # Digit-level encoding of the number itself.
                    number_emb = self.encode_number(num_str, device)  # (d_model,)

                    if unit_str:
                        # Deterministic fold of the unit string into a table
                        # index. Bug fix: the builtin hash() is salted per
                        # process (PYTHONHASHSEED), so the unit -> embedding
                        # mapping was not reproducible across runs.
                        unit_id = sum(
                            (i + 1) * ord(c) for i, c in enumerate(unit_str)
                        ) % self.unit_embed.num_embeddings
                        unit_emb = self.unit_embed(torch.tensor(unit_id, device=device))
                        number_emb = number_emb + unit_emb

                    # Magnitude embedding for scientific notation.
                    if 'e' in num_str.lower() or '×10' in num_str:
                        exp_match = re.search(r'[eE]([+-]?\d+)|×10\^?([+-]?\d+)', num_str)
                        if exp_match:
                            exp = int(exp_match.group(1) or exp_match.group(2))
                            exp = max(-10, min(10, exp))  # clamp to table range
                            magnitude_emb = self.magnitude_embed(torch.tensor(exp + 10, device=device))
                            number_emb = number_emb + magnitude_emb

                    # Inject at the first token of the number span.
                    output[b, start_tok, :] += number_emb

        return output

    def compute_numerical_loss(
        self,
        x: torch.Tensor,
        number_mask: torch.Tensor,
        target_values: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute auxiliary loss for numerical reasoning.

        Args:
            x: Input tensor (batch, seq_len, d_model).
            number_mask: Mask for number tokens (batch, seq_len). Unused.
            target_values: Target numeric values (batch, seq_len). Unused.

        Returns:
            A scalar zero tensor (placeholder — a full implementation
            would attach a value-prediction head).
        """
        # Bug fix: previously returned the Python float 0.0 despite the
        # declared torch.Tensor return type; a zero *tensor* lets callers
        # accumulate and backprop losses uniformly.
        return x.new_zeros(())
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def test_numerical_module():
    """Smoke-test NumericalReasoningModule on scientific text."""
    dim, n_batch, n_tokens = 512, 2, 128

    num_module = NumericalReasoningModule(dim)

    hidden = torch.randn(n_batch, n_tokens, dim)
    samples = [
        "The speed of light is 2.998×10^8 m/s and Planck's constant is 6.626×10^-34 J·s.",
        "Calculate: 123.45 + 67.89 = ? The answer is 191.34."
    ]

    enhanced = num_module(hidden, text=samples)
    print(f"Input shape: {hidden.shape}")
    print(f"Output shape: {enhanced.shape}")
    assert enhanced.shape == hidden.shape

    print("NumericalReasoningModule test passed!")


if __name__ == "__main__":
    test_numerical_module()
|
models/scigate_ffn.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SciGateFFN: Science-aware gated feed-forward network.
|
| 3 |
+
Learns to activate different FFN pathways based on science domain.
|
| 4 |
+
Uses hybrid routing: explicit domain tags preferred, fallback to learned classifier.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from typing import Optional, Tuple
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class SciGateFFN(nn.Module):
    """
    Gated FFN with science-domain routing.

    Learns to activate different FFN pathways for different science domains
    (math, chemistry, biology, ...). The gate is conditioned on the detected
    domain via hybrid routing: explicit domain tags are preferred, then
    explicit domain IDs, then a learned fallback classifier.
    """

    def __init__(
        self,
        d_model: int,
        expansion: int = 4,
        num_domains: int = 7,
        use_domain_tags: bool = True,
    ):
        """
        Initialize SciGateFFN.

        Args:
            d_model: Model dimension.
            expansion: FFN expansion factor (default 4).
            num_domains: Number of science domains (default 7).
            use_domain_tags: Whether explicit domain tags are preferred
                for routing.
        """
        super().__init__()
        self.d_model = d_model
        self.expansion = expansion
        self.num_domains = num_domains
        self.use_domain_tags = use_domain_tags

        hidden_dim = d_model * expansion

        # SwiGLU-style architecture: up_proj produces the two gated paths.
        self.up_proj = nn.Linear(d_model, hidden_dim * 2, bias=False)
        self.down_proj = nn.Linear(hidden_dim, d_model, bias=False)

        # Domain -> per-channel scaling (learnable), applied multiplicatively
        # to the hidden activations.
        self.domain_gate = nn.Linear(num_domains, hidden_dim, bias=True)

        # Fallback domain classifier used when no tags/IDs are provided,
        # operating on the mean-pooled sequence representation.
        self.fallback_classifier = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.SiLU(),
            nn.Linear(d_model // 2, num_domains),
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Normal(0, 0.02) init for Linear weights, zeros for biases.

        Bug fix: the previous list contained the Sequential container
        itself, which has no ``.weight``, so the fallback classifier's
        inner Linear layers were silently skipped. Iterating into the
        container initializes them consistently (SiLU has no weight and
        is skipped by the hasattr guard).
        """
        modules = [self.up_proj, self.down_proj, self.domain_gate]
        modules.extend(self.fallback_classifier)
        for module in modules:
            if hasattr(module, 'weight'):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if hasattr(module, 'bias') and module.bias is not None:
                nn.init.zeros_(module.bias)

    def get_domain_one_hot(
        self,
        domain_ids: Optional[torch.Tensor] = None,
        domain_tags: Optional[torch.Tensor] = None,
        hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Get per-token domain routing weights.

        Hybrid strategy, in priority order:
        1. domain_tags (explicit [MATH], [CHEM], ... markers), if any set.
        2. domain_ids (from the data loader), converted to one-hot.
        3. Fallback classifier over mean-pooled hidden_states (soft probs).
        4. Uniform distribution when nothing is available.

        Args:
            domain_ids: Domain IDs, shape (batch,) or (batch, seq_len).
            domain_tags: Domain tag mask (batch, seq_len, num_domains).
            hidden_states: Hidden states for the fallback classifier
                (batch, seq_len, d_model).

        Returns:
            Routing weights of shape (batch, seq_len, num_domains).
        """
        if hidden_states is not None:
            batch, seq_len, _ = hidden_states.shape
        else:
            batch, seq_len = 0, 0

        if domain_tags is not None and domain_tags.any():
            # Explicit domain tags: already one-hot (or multi-hot).
            return domain_tags.float()
        elif domain_ids is not None:
            if domain_ids.dim() == 1:
                # One domain per sequence: broadcast over tokens.
                domain_one_hot = F.one_hot(domain_ids, num_classes=self.num_domains)
                domain_one_hot = domain_one_hot.unsqueeze(1).expand(-1, seq_len, -1)
            else:
                # Per-token domain IDs.
                domain_one_hot = F.one_hot(domain_ids, num_classes=self.num_domains)
            return domain_one_hot.float()
        elif hidden_states is not None:
            # Fallback: classify the domain from the pooled representation
            # and share the soft distribution across all tokens.
            pooled = hidden_states.mean(dim=1)                    # (batch, d_model)
            domain_logits = self.fallback_classifier(pooled)      # (batch, num_domains)
            domain_probs = F.softmax(domain_logits, dim=-1)
            return domain_probs.unsqueeze(1).expand(-1, seq_len, -1)
        else:
            # No information at all: uniform mixture. (Only reachable when
            # hidden_states is None, so batch/seq_len are 0 here.)
            uniform = torch.ones(batch, seq_len, self.num_domains)
            return uniform / self.num_domains

    def forward(
        self,
        x: torch.Tensor,
        domain_ids: Optional[torch.Tensor] = None,
        domain_tags: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass with domain-aware gating.

        Args:
            x: Input tensor (batch, seq_len, d_model).
            domain_ids: Optional domain IDs (batch,) or (batch, seq_len).
            domain_tags: Optional domain tag mask (batch, seq_len, num_domains).

        Returns:
            Output tensor (batch, seq_len, d_model).
        """
        batch, seq_len, d_model = x.shape

        # Routing weights: (batch, seq_len, num_domains).
        domain_weights = self.get_domain_one_hot(domain_ids, domain_tags, x)

        # SwiGLU: split the up-projection into value and gate paths.
        up = self.up_proj(x)                  # (batch, seq_len, hidden_dim * 2)
        up1, up2 = up.chunk(2, dim=-1)        # each (batch, seq_len, hidden_dim)
        hidden = up1 * F.silu(up2)            # (batch, seq_len, hidden_dim)

        # Domain-specific multiplicative gate. Bug fix: the previous code
        # re-implemented this Linear as an einsum over .weight only,
        # leaving the bias=True parameter allocated but never applied.
        # Applying the layer directly includes the bias (zero-initialized,
        # so behavior at init is unchanged) and lets it train.
        domain_scaling = self.domain_gate(domain_weights)  # (batch, seq_len, hidden_dim)

        hidden = hidden * domain_scaling

        # Project back to model dimension.
        output = self.down_proj(hidden)

        return output
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def test_scigate_ffn():
    """Smoke-test SciGateFFN across its three routing modes."""
    n_domains = 7
    ffn = SciGateFFN(4096, expansion=4, num_domains=n_domains)

    inputs = torch.randn(2, 128, 4096)

    # Mode 1: no domain information at all (fallback routing).
    out_plain = ffn(inputs)
    print(f"Input shape: {inputs.shape}")
    print(f"Output shape: {out_plain.shape}")
    assert out_plain.shape == inputs.shape

    # Mode 2: explicit per-example domain IDs.
    ids = torch.randint(0, n_domains, (2,))
    assert ffn(inputs, domain_ids=ids).shape == inputs.shape

    # Mode 3: dense per-token domain tags (every token tagged as domain 0, math).
    tags = torch.zeros(2, 128, n_domains)
    tags[:, :, 0] = 1.0
    assert ffn(inputs, domain_tags=tags).shape == inputs.shape

    print("SciGateFFN test passed!")


if __name__ == "__main__":
    test_scigate_ffn()
|
models/ssm_layer.py
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
VortexSSM: Selective State-Space Layer
|
| 3 |
+
Simplified Mamba-style SSM with input-dependent selection.
|
| 4 |
+
Provides O(n) complexity for long sequences, ideal for scientific documents.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from typing import Optional, Tuple
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class VortexSSM(nn.Module):
    """
    Selective state-space layer. Linear complexity O(n) vs attention's O(n²).
    Handles long scientific documents efficiently with input-dependent selection.

    Architecture based on Mamba but simplified for scientific reasoning tasks.
    """

    def __init__(
        self,
        d_model: int,
        d_state: int = 16,
        d_conv: int = 4,
        expand: int = 2,
        dt_rank: Optional[int] = None,
    ):
        """
        Initialize VortexSSM.

        Args:
            d_model: Model dimension
            d_state: State dimension (default 16 for 7B, 32 for 13B)
            d_conv: Convolution kernel size for local context
            expand: Expansion factor for inner dimension
            dt_rank: Rank for delta projection (if None, uses max(1, d_model // 16))
        """
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        # Inner (expanded) width the SSM recurrence operates in.
        self.d_inner = d_model * expand

        if dt_rank is None:
            self.dt_rank = max(1, d_model // 16)
        else:
            self.dt_rank = dt_rank

        # Input projection: produces the SSM pathway (x) and the gate pathway (z).
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)

        # Depthwise 1D convolution for local context before the SSM.
        # padding=d_conv-1 pads both sides; forward() trims the right-side
        # overhang so the effective receptive field is causal-left.
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            kernel_size=d_conv,
            padding=d_conv - 1,
            groups=self.d_inner,
            bias=False,
        )

        # Input-dependent SSM parameters: a single projection yields
        # delta (dt_rank), B (d_state) and C (d_state) per position.
        self.x_proj = nn.Linear(self.d_inner, self.dt_rank + 2 * self.d_state, bias=False)
        # Expands the low-rank delta back to d_inner.
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)

        # State transition parameter, stored in log-ish scale: it is used as
        # exp(A_log * delta), so negative values give decay factors < 1.
        # Shape: (d_inner, d_state).
        self.A_log = nn.Parameter(torch.randn(self.d_inner, self.d_state))
        # Skip/pass-through coefficient applied directly to the input.
        self.D = nn.Parameter(torch.randn(self.d_inner))

        # Output projection back to the model dimension.
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)

        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize parameters for stable discretization and small projections."""
        # Negative mean keeps exp(A_log * delta) below 1 with high probability,
        # so states decay rather than blow up over long sequences.
        # NOTE(review): this is not a hard guarantee — training could push
        # A_log positive; confirm stability is monitored elsewhere.
        nn.init.normal_(self.A_log, mean=-4.0, std=0.5)
        nn.init.normal_(self.D, mean=0.0, std=0.1)

        # Small-variance init for every projection (depthwise conv included).
        for module in [self.in_proj, self.x_proj, self.dt_proj, self.conv1d, self.out_proj]:
            if hasattr(module, 'weight'):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        x: torch.Tensor,
        state: Optional[torch.Tensor] = None,
        return_state: bool = False,
    ) -> torch.Tensor:
        """
        Forward pass through the SSM.

        Args:
            x: Input tensor (batch, seq_len, d_model)
            state: Previous hidden state (batch, d_inner, d_state)
            return_state: If True, return (output, state)

        Returns:
            Output tensor (batch, seq_len, d_model), or the tuple
            (output, final_state) when return_state is True.
        """
        batch, seq_len, _ = x.shape
        device = x.device
        dtype = x.dtype

        # Local alias; also documents that state allocation below must match
        # the A_log leading dimension.
        d_inner = self.d_inner

        # Split the input projection into SSM input (x) and gate (z).
        xz = self.in_proj(x)  # (batch, seq_len, 2 * d_inner)
        x, z = xz.chunk(2, dim=-1)

        # Local context via depthwise conv; conv1d expects (batch, channels, seq).
        x_conv = x.transpose(1, 2)
        x_conv = self.conv1d(x_conv)[..., :seq_len]  # trim right-side padding
        x = x_conv.transpose(1, 2)

        # Per-position SSM parameters: delta (selection), B (input map), C (readout).
        x_dbl = self.x_proj(x)  # (batch, seq_len, dt_rank + 2*d_state)
        (delta, B, C) = torch.split(
            x_dbl,
            [self.dt_rank, self.d_state, self.d_state],
            dim=-1,
        )

        # Expand low-rank delta to d_inner; softplus keeps step sizes positive.
        delta = self.dt_proj(delta)  # (batch, seq_len, d_inner)
        delta = F.softplus(delta)

        # Fresh zero state unless the caller carries one across chunks.
        if state is None:
            state = torch.zeros(batch, d_inner, self.d_state, device=device, dtype=dtype)

        # Sequential scan over time. Python loop — O(seq_len) launches; a fused
        # CUDA kernel could replace this without changing the recurrence.
        output = []
        for t in range(seq_len):
            x_t = x[:, t]  # (batch, d_inner)
            delta_t = delta[:, t]  # (batch, d_inner)
            B_t = B[:, t]  # (batch, d_state)
            C_t = C[:, t]  # (batch, d_state)

            # Discretized transition factors; in (0, 1) while A_log < 0.
            A_delta = torch.exp(self.A_log * delta_t.unsqueeze(-1))  # (batch, d_inner, d_state)

            # Recurrence: state <- A_delta * state + outer(x_t, B_t),
            # with B_t broadcast over d_inner and x_t over d_state.
            state = A_delta * state + B_t.unsqueeze(1) * x_t.unsqueeze(-1)

            # Readout: contract state with C_t over d_state, plus D skip path.
            y = (C_t.unsqueeze(1) * state).sum(dim=-1) + self.D * x_t
            output.append(y)

        output = torch.stack(output, dim=1)  # (batch, seq_len, d_inner)

        # SiLU gate from the z pathway.
        output = output * F.silu(z)

        # Back to the model dimension.
        output = self.out_proj(output)

        if return_state:
            return output, state
        return output

    def step(
        self,
        x: torch.Tensor,
        state: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Single-step inference for autoregressive decoding.

        NOTE(review): unlike forward(), this skips the depthwise convolution
        (it would require a rolling input cache), so step() is not numerically
        identical to running forward() one token at a time — confirm this is
        acceptable for decoding quality.

        Args:
            x: Input at current step (batch, d_model)
            state: Previous state (batch, d_inner, d_state)

        Returns:
            output: (batch, d_model)
            new_state: updated state
        """
        batch, _ = x.shape

        # Project input; temporary seq dim of 1 keeps Linear shapes uniform.
        xz = self.in_proj(x.unsqueeze(1))
        x, z = xz.chunk(2, dim=-1)
        x = x.squeeze(1)
        z = z.squeeze(1)

        # No convolution for single step (would need a cache of past inputs).

        # Per-token SSM parameters, same layout as in forward().
        x_dbl = self.x_proj(x.unsqueeze(1)).squeeze(1)
        delta, B, C = torch.split(
            x_dbl,
            [self.dt_rank, self.d_state, self.d_state],
            dim=-1,
        )
        delta = self.dt_proj(delta)
        delta = F.softplus(delta)

        # One step of the same discretized recurrence used in forward().
        A_delta = torch.exp(self.A_log * delta.unsqueeze(-1))
        state = A_delta * state + B.unsqueeze(1) * x.unsqueeze(-1)
        y = (C.unsqueeze(1) * state).sum(dim=-1) + self.D * x
        y = y * F.silu(z)
        output = self.out_proj(y)

        return output, state
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def test_vortex_ssm():
    """Exercise VortexSSM: stateless forward, stateful forward, single-step decode."""
    bsz, length, width, n_state = 2, 128, 4096, 16

    layer = VortexSSM(width, d_state=n_state)
    inp = torch.randn(bsz, length, width)

    # Plain forward pass.
    out = layer(inp)
    print(f"Input shape: {inp.shape}")
    print(f"Output shape: {out.shape}")
    assert out.shape == inp.shape, f"Expected {inp.shape}, got {out.shape}"

    # Forward pass that threads an explicit recurrent state through.
    init_state = torch.zeros(bsz, layer.d_inner, n_state)
    out_stateful, next_state = layer(inp, state=init_state, return_state=True)
    print(f"Stateful output shape: {out_stateful.shape}")
    print(f"State shape: {next_state.shape}")

    # Autoregressive single-token step.
    token = torch.randn(bsz, width)
    token_out, token_state = layer.step(token, init_state)
    print(f"Step output shape: {token_out.shape}")
    print(f"Step state shape: {token_state.shape}")

    print("VortexSSM test passed!")


if __name__ == "__main__":
    test_vortex_ssm()
|
models/vortex_model.py
ADDED
|
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
VortexModel: Main model class combining SSM, attention, science modules, and SciGate FFN.
|
| 3 |
+
Implements two block types: SSM-only and attention+science+SciGate FFN.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from typing import Optional, Tuple, List, Dict
|
| 10 |
+
|
| 11 |
+
from .ssm_layer import VortexSSM
|
| 12 |
+
from .attention_layer import VortexLocalAttention
|
| 13 |
+
from .scigate_ffn import SciGateFFN
|
| 14 |
+
from .science_modules import (
|
| 15 |
+
EquationModule,
|
| 16 |
+
NumericalReasoningModule,
|
| 17 |
+
CitationModule,
|
| 18 |
+
MolecularModule,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class VortexBlock(nn.Module):
    """
    One layer of the Vortex stack, in two flavours:

    1. SSM block: pre-norm VortexSSM with a residual connection.
    2. Attention block: pre-norm VortexLocalAttention, then optional science
       modules (each applied as a residual refinement), then a pre-norm
       SciGateFFN.

    Both flavours finish with a LayerNorm over the residual sum.
    """

    def __init__(
        self,
        config: Dict,
        is_ssm_block: bool = True,
    ):
        """
        Initialize a Vortex block.

        Args:
            config: Model configuration dict (d_model, d_state, num_heads, ...)
            is_ssm_block: If True, build an SSM-only block; otherwise build the
                attention + science modules + SciGateFFN block.
        """
        super().__init__()
        self.config = config
        self.is_ssm_block = is_ssm_block
        self.d_model = config["d_model"]

        if is_ssm_block:
            # SSM-only pathway.
            self.ssm = VortexSSM(
                d_model=config["d_model"],
                d_state=config["d_state"],
                d_conv=config["d_conv"],
            )
            self.norm = nn.LayerNorm(config["d_model"])
        else:
            # Attention + science + FFN pathway.
            self.attn = VortexLocalAttention(
                d_model=config["d_model"],
                num_heads=config["num_heads"],
                window_size=config["window_size"],
                use_flash_attention=config.get("use_flash_attention", True),
            )
            self.attn_norm = nn.LayerNorm(config["d_model"])

            # Science modules are individually switchable; each defaults to on.
            self.equation_module = None
            self.numerical_module = None
            self.citation_module = None
            self.molecular_module = None

            if config.get("enable_equation_module", True):
                self.equation_module = EquationModule(config["d_model"])

            if config.get("enable_numerical_module", True):
                self.numerical_module = NumericalReasoningModule(config["d_model"])

            if config.get("enable_citation_module", True):
                self.citation_module = CitationModule(config["d_model"])

            if config.get("enable_molecular_module", True):
                self.molecular_module = MolecularModule(config["d_model"])

            # Domain-gated feed-forward network.
            self.ffn = SciGateFFN(
                d_model=config["d_model"],
                expansion=config["ffn_expansion"],
                num_domains=config["num_domains"],
            )
            self.ffn_norm = nn.LayerNorm(config["d_model"])

        # Final layer norm, shared by both block types.
        self.final_norm = nn.LayerNorm(config["d_model"])

    def forward(
        self,
        x: torch.Tensor,
        domain_ids: Optional[torch.Tensor] = None,
        domain_tags: Optional[torch.Tensor] = None,
        text: Optional[List[str]] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass through the block.

        Args:
            x: Input tensor (batch, seq_len, d_model)
            domain_ids: Optional domain IDs for SciGate FFN
            domain_tags: Optional domain tag masks
            text: Optional original text for science module span detection
            attention_mask: Optional attention mask

        Returns:
            Output tensor (batch, seq_len, d_model)
        """
        if self.is_ssm_block:
            # Pre-norm SSM with residual, then the final norm.
            residual = x
            x = self.norm(x)
            x = self.ssm(x)
            x = residual + x
            x = self.final_norm(x)
        else:
            # Pre-norm local attention with residual.
            residual_attn = x
            x = self.attn_norm(x)
            # FIX: call _detect_global_tokens directly. The previous
            # `hasattr(self, '_detect_global_tokens')` guard was dead code —
            # the method is defined on this class, so it was always True.
            global_mask = self._detect_global_tokens(x)
            x = self.attn(x, global_mask=global_mask, attention_mask=attention_mask)
            x = residual_attn + x

            # Science modules, each applied sequentially as a residual update.
            if self.equation_module is not None:
                x = x + self.equation_module(x, text=text)

            if self.numerical_module is not None:
                x = x + self.numerical_module(x, text=text)

            if self.citation_module is not None:
                # CitationModule also returns auxiliary data; only the
                # hidden-state refinement is used here.
                x_cited, _ = self.citation_module(x, text=text)
                x = x + x_cited

            if self.molecular_module is not None:
                x = x + self.molecular_module(x, text=text)

            # Pre-norm SciGate FFN with residual, then the final norm.
            residual_ffn = x
            x = self.ffn_norm(x)
            x = self.ffn(x, domain_ids=domain_ids, domain_tags=domain_tags)
            x = residual_ffn + x

            x = self.final_norm(x)

        return x

    def _detect_global_tokens(self, x: torch.Tensor) -> torch.Tensor:
        """
        Detect global tokens that should attend across the entire sequence.

        Heuristic: within each batch row, the ~5% of tokens with the largest
        L2 norm are flagged as global (likely special/salient tokens).

        Returns:
            Boolean mask (batch, seq_len); True marks a global token.
        """
        norms = torch.norm(x, dim=-1)  # (batch, seq_len)
        # Per-row 95th percentile as the cut-off.
        threshold = torch.quantile(norms, 0.95, dim=-1, keepdim=True)
        global_mask = norms > threshold

        return global_mask
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class VortexModel(nn.Module):
    """
    Main Vortex model combining SSM and attention blocks.
    Supports both 7B and 13B configurations.
    """

    def __init__(
        self,
        config: Dict,
    ):
        """
        Initialize VortexModel.

        Args:
            config: Model configuration (from vortex_7b_config.py or vortex_13b_config.py)
        """
        super().__init__()
        self.config = config

        # Token embedding.
        self.embed_tokens = nn.Embedding(config["vocab_size"], config["d_model"])

        # Build blocks according to the configured SSM/attention layer ratio.
        self.blocks = nn.ModuleList()
        self._build_blocks()

        # Final layer norm.
        self.ln_f = nn.LayerNorm(config["d_model"])

        # Output projection (weights will be tied by HuggingFace if
        # config.tie_word_embeddings=True).
        self.lm_head = nn.Linear(config["d_model"], config["vocab_size"], bias=False)

        self._initialize_weights()

    def _build_blocks(self):
        """Build the interleaved sequence of SSM and attention blocks."""
        num_layers = self.config["num_layers"]
        ssm_ratio = self.config["ssm_ratio"]

        # Choose the tiling pattern (0 = SSM block, 1 = attention block).
        # NOTE(review): exact float equality — kept for backward compatibility
        # with existing configs that set ssm_ratio to the literal 0.6.
        if ssm_ratio == 0.6:  # 7B pattern: SSM, SSM, Attn, SSM, SSM, Attn, ...
            pattern = [0, 0, 1]
        else:  # 13B pattern: SSM, Attn, SSM, Attn, ...
            pattern = [0, 1]

        # Tile the pattern out to exactly num_layers entries.
        blocks = (pattern * (num_layers // len(pattern) + 1))[:num_layers]
        assert len(blocks) == num_layers

        # Create blocks.
        for is_attn in blocks:
            self.blocks.append(
                VortexBlock(
                    config=self.config,
                    is_ssm_block=not is_attn,
                )
            )

        # FIX: report the counts actually produced by the tiled pattern.
        # The previous version reported int(num_layers * ssm_ratio), which can
        # disagree with the pattern when it does not divide num_layers evenly
        # (e.g. ratio 0.6 with 4 layers builds 3 SSM blocks but printed 2).
        num_attn_blocks = sum(blocks)
        num_ssm_blocks = num_layers - num_attn_blocks
        print(f"Built {num_layers} layers: {num_ssm_blocks} SSM, {num_attn_blocks} Attention")

    def _initialize_weights(self):
        """Initialize embedding weights and delegate to each block's own init."""
        nn.init.normal_(self.embed_tokens.weight, mean=0.0, std=0.02)
        for block in self.blocks:
            # Only the sub-modules a given block type actually has.
            if hasattr(block, 'ssm'):
                block.ssm._initialize_weights()
            if hasattr(block, 'attn'):
                block.attn._initialize_weights()
            if hasattr(block, 'ffn'):
                block.ffn._initialize_weights()

    def forward(
        self,
        input_ids: torch.Tensor,
        domain_ids: Optional[torch.Tensor] = None,
        domain_tags: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        text: Optional[List[str]] = None,
        return_dict: bool = True,
    ) -> torch.Tensor:
        """
        Forward pass through the model.

        Args:
            input_ids: Token IDs (batch, seq_len)
            domain_ids: Optional domain IDs
            domain_tags: Optional domain tag masks
            attention_mask: Optional attention mask (batch, seq_len)
            text: Optional original text for science modules
            return_dict: If True, return {"logits", "last_hidden_state"};
                otherwise return the logits tensor directly.

        Returns:
            Logits (batch, seq_len, vocab_size), possibly wrapped in a dict.
        """
        # Embed tokens.
        x = self.embed_tokens(input_ids)

        # Pass through all blocks; every block receives the same auxiliary inputs.
        for block in self.blocks:
            x = block(
                x,
                domain_ids=domain_ids,
                domain_tags=domain_tags,
                text=text,
                attention_mask=attention_mask,
            )

        # Final norm.
        x = self.ln_f(x)

        # Project to vocabulary.
        logits = self.lm_head(x)

        if return_dict:
            return {"logits": logits, "last_hidden_state": x}
        return logits

    def get_num_params(self) -> int:
        """Get total number of parameters."""
        return sum(p.numel() for p in self.parameters())

    def get_trainable_params(self) -> int:
        """Get number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def estimate_memory_usage(
        self,
        batch_size: int,
        seq_len: int,
        use_gradient_checkpointing: bool = False,
    ) -> Dict[str, float]:
        """
        Estimate training memory usage for a given batch size and sequence length.

        Rough back-of-the-envelope numbers assuming bfloat16 weights and an
        AdamW optimizer; activations are approximated per layer and do not
        account for gradient checkpointing yet.

        Returns:
            Dictionary with memory estimates in GB.
        """
        params = self.get_num_params()
        param_bytes = params * 2  # 2 bytes/param, assuming bfloat16

        # Activation memory (rough estimate):
        # each layer holds roughly batch * seq_len * d_model values at 2 bytes.
        activations_per_layer = batch_size * seq_len * self.config["d_model"] * 2
        total_activations = activations_per_layer * self.config["num_layers"]

        # Gradients mirror the parameter memory.
        gradients = param_bytes

        # Optimizer states (AdamW keeps two moments per parameter).
        optimizer_states = params * 2 * 2

        total_memory = (param_bytes + total_activations + gradients + optimizer_states) / 1e9

        return {
            "parameters_gb": param_bytes / 1e9,
            "activations_gb": total_activations / 1e9,
            "gradients_gb": gradients / 1e9,
            "optimizer_states_gb": optimizer_states / 1e9,
            "total_gb": total_memory,
        }
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def test_vortex_model():
    """Build a shrunken VortexModel and sanity-check forward + memory estimate."""
    from configs.vortex_7b_config import VORTEX_7B_CONFIG

    cfg = VORTEX_7B_CONFIG.copy()
    # Shrink the architecture so the smoke test runs quickly.
    cfg["d_model"] = 512
    cfg["num_layers"] = 4
    cfg["num_heads"] = 8
    cfg["vocab_size"] = 1000

    model = VortexModel(cfg)

    bsz, length = 2, 128
    tokens = torch.randint(0, cfg["vocab_size"], (bsz, length))

    # Forward pass returns a dict by default; pull out the logits.
    logits = model(tokens)["logits"]

    print(f"Model parameters: {model.get_num_params():,}")
    print(f"Input shape: {tokens.shape}")
    print(f"Logits shape: {logits.shape}")
    assert logits.shape == (bsz, length, cfg["vocab_size"])

    # Memory estimate report.
    estimate = model.estimate_memory_usage(bsz, length)
    print(f"Memory estimate for batch={bsz}, seq_len={length}:")
    for key, gb in estimate.items():
        print(f"  {key}: {gb:.2f} GB")

    print("VortexModel test passed!")


if __name__ == "__main__":
    test_vortex_model()
|
mps_optimize.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MPS optimizations for Vortex model on Apple Silicon.
|
| 3 |
+
Uses PyTorch MPS backend with MPS-compatible ops only.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from typing import Optional, Dict, Any
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def optimize_for_mps(
    model: nn.Module,
    config: Dict,
    use_sdpa: bool = True,
) -> nn.Module:
    """
    Apply MPS (Apple Silicon) optimizations to a model.

    Args:
        model: VortexModel
        config: Model config
        use_sdpa: Use PyTorch scaled dot product attention (MPS compatible)

    Returns:
        Optimized model
    """
    # Move the model onto the MPS device first.
    model = model.to(torch.device("mps"))

    # Choose a compute dtype. MPS handles float32/float16 well, but its
    # bfloat16 support is limited, so a bfloat16 request is downgraded
    # to float16.
    requested = config.get("dtype", "bfloat16")
    target_dtype = torch.float16 if requested == "bfloat16" else torch.float32
    model = model.to(target_dtype)

    # Swap Flash Attention for PyTorch SDPA, which the MPS backend supports.
    if use_sdpa:
        model = _apply_sdpa(model)
        print("Applied PyTorch SDPA for MPS")

    return model
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _apply_sdpa(model: nn.Module) -> nn.Module:
|
| 51 |
+
"""
|
| 52 |
+
Replace custom attention with PyTorch SDPA.
|
| 53 |
+
SDPA is optimized for MPS backend.
|
| 54 |
+
"""
|
| 55 |
+
for name, module in model.named_modules():
|
| 56 |
+
if hasattr(module, 'attn') and hasattr(module.attn, 'forward_optimized'):
|
| 57 |
+
# Use the SDPA path
|
| 58 |
+
original_forward = module.attn.forward
|
| 59 |
+
|
| 60 |
+
def sdpa_forward(self, x, *args, **kwargs):
|
| 61 |
+
return self._standard_attention(x, kwargs.get('attention_mask'))
|
| 62 |
+
|
| 63 |
+
module.attn.forward = sdpa_forward.__get__(module.attn, type(module.attn))
|
| 64 |
+
|
| 65 |
+
return model
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def get_mps_memory_usage() -> Dict[str, float]:
    """Get current MPS memory usage in GB."""
    if not torch.backends.mps.is_available():
        return {"error": "MPS not available"}

    # There is no direct MPS memory query; Apple Silicon uses unified
    # memory, so report the process's own footprint instead.
    import psutil

    info = psutil.Process().memory_info()
    return {
        "rss_gb": info.rss / 1e9,  # Resident set size
        "vms_gb": info.vms / 1e9,  # Virtual memory size
    }
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def profile_model_mps(
    model: nn.Module,
    input_ids: torch.Tensor,
    num_warmup: int = 10,
    num_runs: int = 50,
) -> Dict[str, float]:
    """
    Profile model performance on MPS.

    Runs `num_warmup` untimed forward passes, then times `num_runs`
    passes and reports average latency and token throughput.

    Args:
        model: Model to profile
        input_ids: Example input
        num_warmup: Number of warmup runs
        num_runs: Number of profiling runs

    Returns:
        Dictionary with timing statistics
    """
    import time

    model.eval()
    target = next(model.parameters()).device
    input_ids = input_ids.to(target)
    on_mps = target.type == "mps"

    def _sync():
        # MPS executes asynchronously; drain queued work before reading clocks.
        if on_mps:
            torch.mps.synchronize()

    # Warmup passes (not timed).
    with torch.no_grad():
        for _ in range(num_warmup):
            _ = model(input_ids)
        _sync()

    # Timed passes.
    _sync()
    started = time.time()
    with torch.no_grad():
        for _ in range(num_runs):
            _ = model(input_ids)
    _sync()
    elapsed = time.time() - started

    per_run = elapsed / num_runs
    return {
        "avg_time_sec": per_run,
        "tokens_per_sec": input_ids.shape[1] / per_run,
    }
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def test_mps_optimize():
    """Test MPS optimizations."""
    if not torch.backends.mps.is_available():
        print("MPS not available, skipping test")
        return

    from models.vortex_model import VortexModel
    from configs.vortex_7b_config import VORTEX_7B_CONFIG

    # Shrink the 7B config so the smoke test builds a tiny model quickly.
    config = dict(
        VORTEX_7B_CONFIG,
        d_model=512,
        num_layers=2,
        num_heads=8,
        vocab_size=1000,
    )

    model = VortexModel(config)
    print(f"Model parameters: {model.get_num_params():,}")

    # Optimize for MPS
    model = optimize_for_mps(model, config, use_sdpa=True)

    # One forward pass to confirm the optimized model still runs.
    input_ids = torch.randint(0, config["vocab_size"], (2, 128)).to("mps")
    with torch.no_grad():
        logits = model(input_ids)["logits"]

    print(f"Output shape: {logits.shape}")
    print("MPS optimize test passed!")
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
# Allow running this module directly as a quick MPS smoke test.
if __name__ == "__main__":
    test_mps_optimize()
|
push_to_hf.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
import argparse
|
| 3 |
+
import os
|
| 4 |
+
from huggingface_hub import HfApi, create_repo
|
| 5 |
+
|
| 6 |
+
def push_to_hf(repo_id, token=None, private=False):
    """
    Upload the current working directory to a Hugging Face model repo.

    Args:
        repo_id: Target repository ID, e.g. "user/vortex-7b".
        token: Hugging Face access token; falls back to the cached login
            when omitted.
        private: Create the repository as private if it does not yet exist.

    Raises:
        Exception: Re-raises any error from the folder upload.
    """
    api = HfApi(token=token)

    # Create repo if it doesn't exist.
    # Bug fix: go through the authenticated client so an explicit `token`
    # argument is honored (the module-level create_repo() ignored it and
    # relied solely on the cached login).
    try:
        api.create_repo(repo_id, repo_type="model", private=private, exist_ok=True)
        print(f"Repository {repo_id} ready")
    except Exception as e:
        # Best-effort: with exist_ok=True this is normally benign.
        print(f"Repo creation note: {e}")

    # Upload all files in current directory
    cwd = os.getcwd()
    print(f"Uploading from {cwd} to {repo_id}...")

    try:
        api.upload_folder(
            folder_path=cwd,
            repo_id=repo_id,
            repo_type="model",
            commit_message="Upload Vortex model"
        )
        print(f"Successfully uploaded to {repo_id}")
    except Exception as e:
        print(f"Upload failed: {e}")
        raise
|
| 31 |
+
|
| 32 |
+
# CLI entry point: parse flags and push the working tree to the Hub.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--repo_id", type=str, required=True, help="HuggingFace repo ID")
    parser.add_argument("--token", type=str, help="HuggingFace token (optional if logged in)")
    parser.add_argument("--private", action="store_true", help="Make repository private")
    args = parser.parse_args()

    push_to_hf(args.repo_id, args.token, args.private)
|
requirements.txt
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Core dependencies
|
| 2 |
+
torch>=2.2.0
|
| 3 |
+
transformers>=4.40.0
|
| 4 |
+
accelerate>=0.30.0
|
| 5 |
+
datasets>=2.18.0
|
| 6 |
+
tokenizers>=0.19.0
|
| 7 |
+
|
| 8 |
+
# Quantization
|
| 9 |
+
bitsandbytes>=0.43.0
|
| 10 |
+
|
| 11 |
+
# Flash Attention (CUDA only)
|
| 12 |
+
flash-attn>=2.5.0
|
| 13 |
+
|
| 14 |
+
# Scientific computing
|
| 15 |
+
numpy>=1.26.0
|
| 16 |
+
scipy>=1.12.0
|
| 17 |
+
scikit-learn>=1.4.0
|
| 18 |
+
|
| 19 |
+
# Chemistry/Biology
|
| 20 |
+
rdkit>=2023.9.0
|
| 21 |
+
pubchempy>=1.0.4
|
| 22 |
+
|
| 23 |
+
# Web scraping
|
| 24 |
+
arxiv>=2.1.0
|
| 25 |
+
beautifulsoup4>=4.12.0
|
| 26 |
+
requests>=2.31.0
|
| 27 |
+
|
| 28 |
+
# Data processing
|
| 29 |
+
pandas>=2.0.0
|
| 30 |
+
pyarrow>=14.0.0
|
| 31 |
+
|
| 32 |
+
# LaTeX parsing
|
| 33 |
+
pylatexenc>=2.10
|
| 34 |
+
|
| 35 |
+
# Deduplication
|
| 36 |
+
minhash>=0.1.0  # NOTE(review): the maintained MinHash library on PyPI is "datasketch" — verify this package name is intended
|
| 37 |
+
|
| 38 |
+
# Utilities
|
| 39 |
+
tqdm>=4.65.0
|
| 40 |
+
psutil>=5.9.0
|
| 41 |
+
jsonlines>=3.1.0
|
| 42 |
+
|
| 43 |
+
# Optional: wandb for logging
|
| 44 |
+
# wandb>=0.16.0
|
| 45 |
+
|
| 46 |
+
# Development/testing
|
| 47 |
+
pytest>=7.0.0
|
| 48 |
+
black>=23.0.0
|
| 49 |
+
flake8>=6.0.0
|
| 50 |
+
mypy>=1.0.0
|
science_bench.py
ADDED
|
@@ -0,0 +1,360 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Science benchmarks for Vortex model.
|
| 3 |
+
Evaluates performance across 7 science domains.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from typing import Dict, List, Tuple
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
class BenchmarkResult:
    """Results from a benchmark."""
    # Science domain the benchmark covers (e.g. "physics").
    domain: str
    # Fraction of questions answered correctly, in [0, 1].
    accuracy: float
    # Number of questions evaluated.
    total_questions: int
    # Number of questions answered correctly.
    correct_answers: int
    # Per-question records: question, expected, generated, correct.
    details: List[Dict]
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class ScienceBenchmark:
    """
    Base class for science benchmarks.

    Subclasses provide question loading, prompt formatting, answer
    extraction, and answer checking; ``evaluate`` drives the shared
    generate-and-score loop.
    """

    def __init__(self, name: str, domain: str):
        self.name = name
        self.domain = domain

    def load_questions(self) -> List[Dict]:
        """Load benchmark questions."""
        raise NotImplementedError

    def evaluate(
        self,
        model,
        tokenizer,
        device: torch.device,
        max_samples: int = 100,
    ) -> BenchmarkResult:
        """
        Evaluate model on benchmark.

        Args:
            model: Vortex model
            tokenizer: Tokenizer
            device: Torch device
            max_samples: Maximum samples to evaluate

        Returns:
            BenchmarkResult
        """
        questions = self.load_questions()[:max_samples]

        records = []
        num_correct = 0
        for item in questions:
            # Build and tokenize the prompt for this question.
            encoded = tokenizer(self.format_prompt(item), return_tensors="pt").to(device)

            # Greedy decoding keeps the evaluation deterministic.
            with torch.no_grad():
                generated_ids = model.generate(
                    **encoded,
                    max_new_tokens=50,
                    temperature=0.0,  # Greedy
                    do_sample=False,
                )

            completion = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
            prediction = self.extract_answer(completion)

            matched = self.check_answer(prediction, item["answer"])
            num_correct += int(matched)

            records.append({
                "question": item["question"],
                "expected": item["answer"],
                "generated": prediction,
                "correct": matched,
            })

        return BenchmarkResult(
            domain=self.domain,
            accuracy=num_correct / len(questions) if questions else 0.0,
            total_questions=len(questions),
            correct_answers=num_correct,
            details=records,
        )

    def format_prompt(self, question: Dict) -> str:
        """Format question into prompt."""
        raise NotImplementedError

    def extract_answer(self, text: str) -> str:
        """Extract answer from generated text."""
        raise NotImplementedError

    def check_answer(self, predicted: str, expected: str) -> bool:
        """Check if predicted answer matches expected."""
        raise NotImplementedError
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class PhysicsBenchmark(ScienceBenchmark):
    """Physics benchmark (Feynman Questions style)."""

    def __init__(self):
        super().__init__("physics_benchmark", "physics")

    def load_questions(self) -> List[Dict]:
        # Placeholder - would load from dataset
        seed = [
            ("What is the formula for kinetic energy?",
             "KE = 1/2 mv^2",
             "formula"),
            ("Explain Newton's first law of motion.",
             "An object at rest stays at rest unless acted upon by a force.",
             "conceptual"),
        ]
        return [{"question": q, "answer": a, "type": t} for q, a, t in seed]

    def format_prompt(self, question: Dict) -> str:
        return f"Question: {question['question']}\nAnswer:"

    def extract_answer(self, text: str) -> str:
        # Keep only what follows the final "Answer:" marker, if present.
        marker = "Answer:"
        if marker in text:
            return text.split(marker)[-1].strip()
        return text.strip()

    def check_answer(self, predicted: str, expected: str) -> bool:
        # Loose case-insensitive match: either string may contain the other.
        p = predicted.lower()
        e = expected.lower()
        return e in p or p in e
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class MathBenchmark(ScienceBenchmark):
    """Math benchmark (MATH dataset style)."""

    def __init__(self):
        super().__init__("math_benchmark", "math")

    def load_questions(self) -> List[Dict]:
        seed = [
            ("Solve for x: 2x + 5 = 15", "x = 5", "algebra"),
            ("What is the derivative of x^2?", "2x", "calculus"),
        ]
        return [{"question": q, "answer": a, "type": t} for q, a, t in seed]

    def format_prompt(self, question: Dict) -> str:
        return f"Problem: {question['question']}\nSolution:"

    def extract_answer(self, text: str) -> str:
        marker = "Solution:"
        if marker in text:
            return text.split(marker)[-1].strip()
        return text.strip()

    def check_answer(self, predicted: str, expected: str) -> bool:
        # Exact match after collapsing whitespace and lowering case.
        def canon(s: str) -> str:
            return " ".join(s.lower().split())

        return canon(predicted) == canon(expected)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
class ChemistryBenchmark(ScienceBenchmark):
    """Chemistry benchmark."""

    def __init__(self):
        super().__init__("chemistry_benchmark", "chemistry")

    def load_questions(self) -> List[Dict]:
        seed = [
            ("What is the chemical formula for water?", "H2O", "factual"),
            ("How many protons does carbon have?", "6", "factual"),
        ]
        return [{"question": q, "answer": a, "type": t} for q, a, t in seed]

    def format_prompt(self, question: Dict) -> str:
        return f"Chemistry question: {question['question']}\nAnswer:"

    def extract_answer(self, text: str) -> str:
        marker = "Answer:"
        if marker in text:
            return text.split(marker)[-1].strip()
        return text.strip()

    def check_answer(self, predicted: str, expected: str) -> bool:
        # Case-insensitive substring match: the expected answer must appear
        # somewhere in the prediction.
        return expected.strip().lower() in predicted.strip().lower()
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
class BiologyBenchmark(ScienceBenchmark):
    """Biology benchmark."""

    def __init__(self):
        super().__init__("biology_benchmark", "biology")

    def load_questions(self) -> List[Dict]:
        seed = [
            ("What is the powerhouse of the cell?", "mitochondria", "factual"),
            ("What molecule carries genetic information?", "DNA", "factual"),
        ]
        return [{"question": q, "answer": a, "type": t} for q, a, t in seed]

    def format_prompt(self, question: Dict) -> str:
        return f"Biology: {question['question']}\nAnswer:"

    def extract_answer(self, text: str) -> str:
        marker = "Answer:"
        if marker in text:
            return text.split(marker)[-1].strip()
        return text.strip()

    def check_answer(self, predicted: str, expected: str) -> bool:
        # Case-insensitive substring match: the expected answer must appear
        # somewhere in the prediction.
        return expected.strip().lower() in predicted.strip().lower()
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
# Placeholder for other domains
|
| 251 |
+
class EarthScienceBenchmark(ScienceBenchmark):
    """Earth-science benchmark (question set not yet populated)."""

    def __init__(self):
        super().__init__("earth_science_benchmark", "earth")

    def load_questions(self) -> List[Dict]:
        # No questions wired up yet; evaluate() handles the empty case.
        return []

    def format_prompt(self, question: Dict) -> str:
        return f"Earth Science: {question['question']}\nAnswer:"

    def extract_answer(self, text: str) -> str:
        return text.strip()

    def check_answer(self, predicted: str, expected: str) -> bool:
        # Strict case-insensitive equality after trimming.
        return predicted.strip().lower() == expected.strip().lower()
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
class SpaceScienceBenchmark(ScienceBenchmark):
    """Space-science benchmark (question set not yet populated)."""

    def __init__(self):
        super().__init__("space_science_benchmark", "space")

    def load_questions(self) -> List[Dict]:
        # No questions wired up yet; evaluate() handles the empty case.
        return []

    def format_prompt(self, question: Dict) -> str:
        return f"Space Science: {question['question']}\nAnswer:"

    def extract_answer(self, text: str) -> str:
        return text.strip()

    def check_answer(self, predicted: str, expected: str) -> bool:
        # Strict case-insensitive equality after trimming.
        return predicted.strip().lower() == expected.strip().lower()
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
class ZoologyBenchmark(ScienceBenchmark):
    """Zoology benchmark (question set not yet populated)."""

    def __init__(self):
        super().__init__("zoology_benchmark", "zoology")

    def load_questions(self) -> List[Dict]:
        # No questions wired up yet; evaluate() handles the empty case.
        return []

    def format_prompt(self, question: Dict) -> str:
        return f"Zoology: {question['question']}\nAnswer:"

    def extract_answer(self, text: str) -> str:
        return text.strip()

    def check_answer(self, predicted: str, expected: str) -> bool:
        # Strict case-insensitive equality after trimming.
        return predicted.strip().lower() == expected.strip().lower()
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def run_all_benchmarks(
    model,
    tokenizer,
    device: torch.device,
    max_samples_per_domain: int = 50,
) -> Dict[str, BenchmarkResult]:
    """
    Run all benchmarks and return results.

    Args:
        model: Vortex model
        tokenizer: Tokenizer
        device: Torch device
        max_samples_per_domain: Max samples per domain

    Returns:
        Dictionary mapping domain to results
    """
    # One instance per supported science domain.
    suite = (
        PhysicsBenchmark(),
        MathBenchmark(),
        ChemistryBenchmark(),
        BiologyBenchmark(),
        EarthScienceBenchmark(),
        SpaceScienceBenchmark(),
        ZoologyBenchmark(),
    )

    results: Dict[str, BenchmarkResult] = {}
    for bench in suite:
        print(f"Running {bench.name}...")
        outcome = bench.evaluate(model, tokenizer, device, max_samples=max_samples_per_domain)
        results[bench.domain] = outcome
        print(f"  Accuracy: {outcome.accuracy:.2%} ({outcome.correct_answers}/{outcome.total_questions})")

    return results
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
def print_summary(results: Dict[str, BenchmarkResult]):
    """Print summary of benchmark results."""
    divider = "=" * 60
    print("\n" + divider)
    print("BENCHMARK RESULTS")
    print(divider)

    for domain, result in results.items():
        print(f"{domain:15} {result.accuracy:6.2%} ({result.correct_answers}/{result.total_questions})")

    # Average accuracy, counting only domains that actually had questions.
    scored = [r.accuracy for r in results.values() if r.total_questions > 0]
    if scored:
        print(f"{'OVERALL':15} {sum(scored) / len(scored):6.2%}")
    print(divider)
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
# Informational entry point: real benchmarking requires a trained model.
if __name__ == "__main__":
    # Quick test
    print("This script benchmarks the model across science domains.")
    print("To run full benchmarks, integrate with a trained model.")
|
test_model.py
ADDED
|
@@ -0,0 +1,449 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Comprehensive unit tests for Vortex model components.
|
| 4 |
+
Run with: python -m pytest test_model.py -v
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import pytest
|
| 8 |
+
import torch
|
| 9 |
+
import sys
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
# Add Vortex to path
|
| 13 |
+
sys.path.insert(0, str(Path(__file__).parent))
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def test_tokenizer():
    """Test VortexScienceTokenizer."""
    from tokenizer.vortex_tokenizer import VortexScienceTokenizer
    from configs.vortex_7b_config import VORTEX_7B_CONFIG

    tok = VortexScienceTokenizer(VORTEX_7B_CONFIG)

    # Round-trip a string mixing LaTeX math and a chemical formula.
    sample = "The equation is $E = mc^2$ and H2O is water."
    batch = tok.encode(sample, return_tensors="pt")
    assert "input_ids" in batch
    assert batch["input_ids"].shape[0] == 1  # batch dim

    text_out = tok.decode(batch["input_ids"][0].tolist())
    assert isinstance(text_out, str)
    print("✓ Tokenizer test passed")
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def test_ssm_layer():
    """Test VortexSSM."""
    from models.ssm_layer import VortexSSM

    bsz, seq, dim, d_state = 2, 64, 512, 16

    layer = VortexSSM(dim, d_state=d_state)
    hidden = torch.randn(bsz, seq, dim)

    # Stateless forward preserves shape.
    assert layer(hidden).shape == hidden.shape

    # Stateful forward returns an updated recurrent state.
    init_state = torch.zeros(bsz, layer.d_inner, d_state)
    out, updated = layer(hidden, state=init_state, return_state=True)
    assert out.shape == hidden.shape
    assert updated.shape == (bsz, layer.d_inner, d_state)

    # Single-token decode step.
    token = torch.randn(bsz, dim)
    step_out, step_state = layer.step(token, init_state)
    assert step_out.shape == (bsz, dim)
    assert step_state.shape == (bsz, layer.d_inner, d_state)

    print("✓ SSM layer test passed")
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def test_attention_layer():
    """Test VortexLocalAttention."""
    from models.attention_layer import VortexLocalAttention

    bsz, seq, dim = 2, 128, 512

    layer = VortexLocalAttention(dim, 8, window_size=64, use_flash_attention=False)
    hidden = torch.randn(bsz, seq, dim)

    # Plain local attention preserves shape.
    assert layer(hidden).shape == hidden.shape

    # Marking one token as global must not change the output shape.
    globals_mask = torch.zeros(bsz, seq, dtype=torch.bool)
    globals_mask[0, 0] = True
    assert layer(hidden, global_mask=globals_mask).shape == hidden.shape

    print("✓ Local attention test passed")
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def test_scigate_ffn():
    """Test SciGateFFN."""
    from models.scigate_ffn import SciGateFFN

    bsz, seq, dim, domains = 2, 64, 512, 7

    layer = SciGateFFN(dim, expansion=4, num_domains=domains)
    hidden = torch.randn(bsz, seq, dim)

    # Domain-agnostic path.
    assert layer(hidden).shape == hidden.shape

    # Per-sequence domain IDs.
    ids = torch.randint(0, domains, (bsz,))
    assert layer(hidden, domain_ids=ids).shape == hidden.shape

    # Per-token soft domain tags (all mass on domain 0).
    tags = torch.zeros(bsz, seq, domains)
    tags[:, :, 0] = 1.0
    assert layer(hidden, domain_tags=tags).shape == hidden.shape

    print("✓ SciGate FFN test passed")
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def test_equation_module():
    """Test EquationModule."""
    from models.science_modules.equation_module import EquationModule

    bsz, seq, dim = 2, 64, 512

    module = EquationModule(dim)
    hidden = torch.randn(bsz, seq, dim)
    samples = ["E = mc^2 is famous.", "The integral $\\int x dx = x^2/2$."]

    assert module(hidden, text=samples).shape == hidden.shape

    # Auxiliary loss over a span flagged as an equation must be non-negative.
    mask = torch.zeros(bsz, seq)
    mask[0, 5:10] = 1.0
    assert module.compute_equation_loss(hidden, mask).item() >= 0

    print("✓ Equation module test passed")
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def test_numerical_module():
    """Test NumericalReasoningModule."""
    from models.science_modules.numerical_module import NumericalReasoningModule

    bsz, seq, dim = 2, 64, 512

    module = NumericalReasoningModule(dim)
    hidden = torch.randn(bsz, seq, dim)
    samples = ["Speed of light: 2.998e8 m/s", "6.022e23 is Avogadro's number."]

    # Forward over text containing scientific notation preserves shape.
    assert module(hidden, text=samples).shape == hidden.shape

    print("✓ Numerical reasoning module test passed")
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def test_citation_module():
    """Test CitationModule."""
    from models.science_modules.citation_module import CitationModule

    bsz, seq, dim = 2, 64, 512

    module = CitationModule(dim)
    hidden = torch.randn(bsz, seq, dim)
    samples = ["(Einstein, 1905) changed physics.", "See also [1, 2] for details."]

    out, conf = module(hidden, text=samples)
    assert out.shape == hidden.shape
    assert conf.shape == (bsz, seq, 1)

    # Loss over tokens flagged as citations must be non-negative.
    mask = torch.zeros(bsz, seq)
    mask[0, 0:5] = 1.0
    assert module.compute_citation_loss(hidden, mask, conf).item() >= 0

    print("✓ Citation module test passed")
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def test_molecular_module():
    """Test MolecularModule."""
    from models.science_modules.molecular_module import MolecularModule

    bsz, seq, dim = 2, 64, 512

    module = MolecularModule(dim)
    hidden = torch.randn(bsz, seq, dim)
    samples = ["H2O is water.", "DNA sequence: ACGTACGT"]

    # Forward over chemistry/biology text preserves shape.
    assert module(hidden, text=samples).shape == hidden.shape

    print("✓ Molecular module test passed")
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def test_vortex_model():
    """Test full VortexModel."""
    from models.vortex_model import VortexModel
    from configs.vortex_7b_config import VORTEX_7B_CONFIG

    # Shrink the 7B config so the test builds a tiny model.
    config = dict(
        VORTEX_7B_CONFIG,
        d_model=256,
        num_layers=4,
        num_heads=4,
        vocab_size=1000,
    )

    model = VortexModel(config)

    bsz, seq = 2, 32
    tokens = torch.randint(0, config["vocab_size"], (bsz, seq))

    # Forward pass produces per-token logits over the vocabulary.
    logits = model(tokens)["logits"]
    assert logits.shape == (bsz, seq, config["vocab_size"])

    # Count parameters
    n_params = model.get_num_params()
    assert n_params > 0

    print(f"✓ VortexModel test passed (params: {n_params:,})")
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def test_quality_filter():
    """Verify ScienceQualityFilter accepts good text and rejects bad text."""
    from data.quality_filter import ScienceQualityFilter

    quality_filter = ScienceQualityFilter()

    # Well-formed scientific prose should pass.
    passing = """
    The experiment collected data from 100 participants. Results show a
    significant effect (p < 0.05). The equation E = mc^2 is fundamental.
    According to Smith et al., this confirms the hypothesis.
    """
    assert quality_filter.filter(passing)

    # Rejected: not enough content.
    assert not quality_filter.filter("Too short.")

    # Rejected: unbalanced $ math delimiters.
    assert not quality_filter.filter("Equation $E = mc^2 and another $F = ma.")

    print("✓ Quality filter test passed")
def test_domain_classifier():
    """Check DomainClassifier logit shape and its text-classification API."""
    from data.domain_classifier import DomainClassifier

    d_model = 256
    classifier = DomainClassifier(d_model)

    # Hidden-state path: one logit per domain (7 domains total).
    batch_size, seq_len = 4, 32
    states = torch.randn(batch_size, seq_len, d_model)
    assert classifier(states).shape == (batch_size, 7)

    # Raw-text path: returns (domain index, confidence in [0, 1]).
    domain, conf = classifier.classify_text(
        "Quantum mechanics describes particle behavior."
    )
    assert domain in range(7)
    assert 0 <= conf <= 1

    print("✓ Domain classifier test passed")
def test_deduplication():
    """Near-duplicate detection via MinHashLSH."""
    from data.deduplication import MinHashLSH

    lsh = MinHashLSH(num_permutations=32, threshold=0.7, bands=4, rows_per_band=8)

    corpus = [
        ("doc1", "The quick brown fox jumps over the lazy dog."),
        ("doc2", "The quick brown fox jumps over the lazy dog!!!"),
        ("doc3", "Completely different text about science."),
    ]
    for name, body in corpus:
        lsh.add_document(name, body)

    # doc2 differs from doc1 only in punctuation, so querying doc1's text
    # must surface it as a near-duplicate.
    matches = lsh.query(corpus[0][1])
    assert len(matches) >= 1
    assert any(match[0] == "doc2" for match in matches)

    print("✓ Deduplication test passed")
def test_losses():
    """Check VortexLoss exposes the expected loss components."""
    from training.losses import VortexLoss

    loss_weights = {
        "lm_loss": 1.0,
        "equation_loss": 0.3,
        "domain_loss": 0.1,
        "citation_loss": 0.1,
        "numerical_loss": 0.2,
    }
    loss_fn = VortexLoss({"loss_weights": loss_weights})

    batch_size, seq_len, vocab_size = 2, 32, 1000

    logits = torch.randn(batch_size, seq_len, vocab_size)
    labels = torch.randint(0, vocab_size, (batch_size, seq_len))

    components = loss_fn(logits, labels)
    assert "total_loss" in components
    assert "lm_loss" in components
    # Random logits against random labels cannot produce a zero loss.
    assert components["total_loss"].item() > 0

    print("✓ Losses test passed")
def test_curriculum():
    """Check CurriculumScheduler stage boundaries and sampler weights.

    Fix: the sampler sanity check previously used exact float equality
    (``sum(weights.values()) == 1.0``), which is brittle for weights
    produced by floating-point arithmetic; it now uses a tolerance.
    """
    from training.curriculum import CurriculumScheduler

    config = {
        "curriculum_stages": [
            {"name": "foundation", "start": 0.0, "end": 0.2},
            {"name": "domain", "start": 0.2, "end": 0.5},
            {"name": "reasoning", "start": 0.5, "end": 0.8},
            {"name": "integration", "start": 0.8, "end": 1.0},
        ]
    }

    total_steps = 1000
    scheduler = CurriculumScheduler(config, total_steps)

    # Stage lookup at representative steps across all four stages.
    assert scheduler.get_stage_name(0) == "foundation"
    assert scheduler.get_stage_name(250) == "domain"
    assert scheduler.get_stage_name(500) == "reasoning"
    assert scheduler.get_stage_name(800) == "integration"

    # Dataset sampling weights must form a probability distribution.
    weights = scheduler.get_dataset_sampler(100)
    assert isinstance(weights, dict)
    assert abs(sum(weights.values()) - 1.0) < 1e-6  # tolerant float compare

    print("✓ Curriculum test passed")
def test_hf_integration():
    """Round-trip a tiny Vortex model through the HuggingFace save/load API.

    Fix: the temporary save directory is now removed in a ``finally`` block,
    so a failing assertion no longer leaks ``./test_hf_model`` onto disk and
    poisons subsequent runs.
    """
    import shutil

    from configuration_vortex import VortexConfig
    from modeling_vortex import VortexForCausalLM
    from tokenization_vortex import VortexTokenizer  # noqa: F401 (import smoke test)

    # Tiny config so the test is fast.
    config = VortexConfig(
        d_model=128,
        num_layers=2,
        num_heads=4,
        vocab_size=100,
    )

    # Forward pass through the HF-wrapped model.
    model = VortexForCausalLM(config)
    batch_size = 2
    seq_len = 16
    input_ids = torch.randint(0, 100, (batch_size, seq_len))

    outputs = model(input_ids)
    assert outputs.logits.shape == (batch_size, seq_len, 100)

    # Save, then reload through the Auto* registry.
    save_dir = "./test_hf_model"
    try:
        model.save_pretrained(save_dir)
        config.save_pretrained(save_dir)

        from transformers import AutoConfig, AutoModelForCausalLM

        loaded_config = AutoConfig.from_pretrained(save_dir)
        loaded_model = AutoModelForCausalLM.from_pretrained(save_dir)

        assert loaded_config.model_type == "vortex"
        assert isinstance(loaded_model, VortexForCausalLM)
    finally:
        # Always clean up, even when an assertion above fails.
        shutil.rmtree(save_dir, ignore_errors=True)

    print("✓ HuggingFace integration test passed")
def run_all_tests():
    """Execute every unit test and print a pass/fail summary.

    Returns:
        True when every test passed, False otherwise.
    """
    tests = (
        test_tokenizer,
        test_ssm_layer,
        test_attention_layer,
        test_scigate_ffn,
        test_equation_module,
        test_numerical_module,
        test_citation_module,
        test_molecular_module,
        test_vortex_model,
        test_quality_filter,
        test_domain_classifier,
        test_deduplication,
        test_losses,
        test_curriculum,
        test_hf_integration,
    )

    print("Running Vortex unit tests...\n")
    passed = 0
    failed = 0

    for case in tests:
        try:
            case()
        except Exception as exc:
            print(f"✗ {case.__name__} failed: {exc}")
            failed += 1
            import traceback
            traceback.print_exc()
        else:
            passed += 1

    print(f"\n{'='*50}")
    print(f"Tests: {passed + failed} total, {passed} passed, {failed} failed")
    print(f"{'='*50}")

    return failed == 0
# Script entry point: exit nonzero when any test fails so CI can detect it.
if __name__ == "__main__":
    success = run_all_tests()
    sys.exit(0 if success else 1)
|
tokenization_vortex.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Vortex tokenizer for HuggingFace.
|
| 3 |
+
Wraps VortexScienceTokenizer for HF compatibility.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from typing import List, Optional, Dict, Any
|
| 7 |
+
import json
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class VortexTokenizer:
    """HuggingFace-compatible wrapper around ``VortexScienceTokenizer``.

    Exposes the small subset of the HF tokenizer API (``__call__``,
    ``encode``/``decode``, ``save_pretrained``/``from_pretrained``,
    ``vocab_size``, ``get_vocab``) that the rest of the project relies on,
    delegating the actual work to the underlying science tokenizer.
    """

    def __init__(
        self,
        tokenizer_file: Optional[str] = None,
        config: Optional[Dict] = None,
        **kwargs,
    ):
        """Build the wrapper.

        Args:
            tokenizer_file: Optional path to a serialized tokenizer JSON.
            config: Tokenizer configuration dictionary.
        """
        from .tokenizer.vortex_tokenizer import VortexScienceTokenizer

        self.config = config or {}
        self.special_tokens = self.config.get("special_tokens", {})

        # Load a trained tokenizer when the file exists; otherwise start
        # with an untrained one (it must be trained before use).
        if tokenizer_file and os.path.exists(tokenizer_file):
            self.tokenizer = VortexScienceTokenizer(
                self.config,
                tokenizer_path=tokenizer_file,
            )
        else:
            self.tokenizer = VortexScienceTokenizer(self.config)

        # Attributes HuggingFace-style callers expect on a tokenizer object.
        # IDs fall back to the conventional 0..3 ordering when the config
        # does not declare them — presumably matching the trained vocab;
        # NOTE(review): confirm against the tokenizer training setup.
        self.pad_token = "[PAD]"
        self.unk_token = "[UNK]"
        self.bos_token = "[BOS]"
        self.eos_token = "[EOS]"
        self.pad_token_id = self.special_tokens.get("[PAD]", 0)
        self.unk_token_id = self.special_tokens.get("[UNK]", 1)
        self.bos_token_id = self.special_tokens.get("[BOS]", 2)
        self.eos_token_id = self.special_tokens.get("[EOS]", 3)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        **kwargs,
    ):
        """Load the tokenizer (and its config, if present) from a directory."""
        tokenizer_path = os.path.join(pretrained_model_name_or_path, "vortex_tokenizer.json")
        config_path = os.path.join(pretrained_model_name_or_path, "tokenizer_config.json")

        loaded_config = {}
        if os.path.exists(config_path):
            with open(config_path, "r") as handle:
                loaded_config = json.load(handle)

        return cls(tokenizer_file=tokenizer_path, config=loaded_config, **kwargs)

    def __call__(
        self,
        text: str | List[str],
        padding: bool = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
        return_tensors: str = "pt",
        **kwargs,
    ) -> Dict[str, Any]:
        """Tokenize one text or a batch of texts.

        Args:
            text: A single string or a list of strings.
            padding: Pad the batch to a common length.
            truncation: Cut sequences down to ``max_length``.
            max_length: Length cap; defaults to the config's ``max_seq_len``.
            return_tensors: "pt" for torch tensors, "np" for numpy, None for lists.

        Returns:
            Mapping with ``input_ids`` and ``attention_mask``.
        """
        batch = [text] if isinstance(text, str) else text
        limit = max_length if max_length is not None else self.config.get("max_seq_len", 16384)

        # Delegate to the science tokenizer's batch path.
        return self.tokenizer.batch_encode(
            batch,
            padding=padding,
            truncation=truncation,
            max_length=limit,
            return_tensors=return_tensors,
        )

    def encode(
        self,
        text: str,
        add_special_tokens: bool = True,
        **kwargs,
    ) -> List[int]:
        """Return the list of token IDs for ``text``."""
        encoded = self.tokenizer.encode(
            text,
            add_special_tokens=add_special_tokens,
            return_tensors=None,
        )
        return encoded["input_ids"]

    def decode(
        self,
        token_ids: List[int],
        skip_special_tokens: bool = True,
        **kwargs,
    ) -> str:
        """Reconstruct text from a sequence of token IDs."""
        return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)

    def save_pretrained(self, save_directory: str):
        """Write the tokenizer JSON and its config into ``save_directory``."""
        os.makedirs(save_directory, exist_ok=True)
        self.tokenizer.save(os.path.join(save_directory, "vortex_tokenizer.json"))

        # Companion config so from_pretrained() can restore special tokens.
        config_path = os.path.join(save_directory, "tokenizer_config.json")
        with open(config_path, "w") as handle:
            json.dump(
                {
                    "model_type": "vortex",
                    "special_tokens": self.special_tokens,
                },
                handle,
                indent=2,
            )

    @property
    def vocab_size(self) -> int:
        """Number of tokens in the underlying vocabulary."""
        return self.tokenizer.vocab_size

    def get_vocab(self) -> Dict[str, int]:
        """Token-to-ID mapping of the underlying tokenizer."""
        return self.tokenizer.get_vocab()
|
| 157 |
+
def test_vortex_tokenizer():
    """Round-trip a science sentence through VortexTokenizer."""
    from configs.vortex_7b_config import VORTEX_7B_CONFIG

    tokenizer = VortexTokenizer(config=VORTEX_7B_CONFIG)

    sample = "The equation is $E = mc^2$ and the reaction is H2O."
    encoded = tokenizer(sample, padding=False, truncation=True, max_length=128)
    print(f"Encoded: {encoded['input_ids'][0][:10]}...")

    round_trip = tokenizer.decode(encoded["input_ids"][0])
    print(f"Decoded: {round_trip[:50]}...")

    print("VortexTokenizer test passed!")
| 173 |
+
# Allow running this module directly as a quick smoke test.
if __name__ == "__main__":
    test_vortex_tokenizer()
|
tokenizer/__pycache__/vortex_tokenizer.cpython-313.pyc
ADDED
|
Binary file (17.4 kB). View file
|
|
|
tokenizer/vortex_tokenizer.py
ADDED
|
@@ -0,0 +1,442 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
VortexScienceTokenizer: A custom BPE tokenizer optimized for scientific text.
|
| 3 |
+
Trains on science corpus and extends vocabulary with domain-specific tokens.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import json
|
| 8 |
+
import re
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import List, Dict, Optional, Tuple, Union
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
try:
|
| 14 |
+
from tokenizers import Tokenizer, models, pre_tokenizers, processors, trainers
|
| 15 |
+
from tokenizers.normalizers import Lowercase, NFD, StripAccents
|
| 16 |
+
except ImportError:
|
| 17 |
+
print("Please install tokenizers: pip install tokenizers")
|
| 18 |
+
raise
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class VortexScienceTokenizer:
|
| 22 |
+
"""
|
| 23 |
+
Science-optimized BPE tokenizer with domain extensions.
|
| 24 |
+
|
| 25 |
+
Features:
|
| 26 |
+
- Base BPE vocabulary (40,000 tokens) trained on scientific corpus
|
| 27 |
+
- Extended science vocabulary (10,000 tokens) for LaTeX, chemistry, units, etc.
|
| 28 |
+
- Special tokens for equation/citation/molecule spans
|
| 29 |
+
- Domain tags for science areas
|
| 30 |
+
- Digit-level number handling (optional, can be toggled)
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
def __init__(
|
| 34 |
+
self,
|
| 35 |
+
config: Dict,
|
| 36 |
+
tokenizer_path: Optional[str] = None,
|
| 37 |
+
vocab_size: int = 50000,
|
| 38 |
+
base_vocab_size: int = 40000,
|
| 39 |
+
extension_vocab_size: int = 10000,
|
| 40 |
+
):
|
| 41 |
+
"""
|
| 42 |
+
Initialize the tokenizer.
|
| 43 |
+
|
| 44 |
+
Args:
|
| 45 |
+
config: Model configuration with special tokens
|
| 46 |
+
tokenizer_path: Path to pre-trained tokenizer (if loading)
|
| 47 |
+
vocab_size: Total vocabulary size
|
| 48 |
+
base_vocab_size: Size of base BPE vocabulary
|
| 49 |
+
extension_vocab_size: Size of science extension vocabulary
|
| 50 |
+
"""
|
| 51 |
+
self.config = config
|
| 52 |
+
self.base_vocab_size = base_vocab_size
|
| 53 |
+
self.extension_vocab_size = extension_vocab_size
|
| 54 |
+
self._vocab_size = vocab_size
|
| 55 |
+
|
| 56 |
+
self.special_tokens = config.get("special_tokens", {})
|
| 57 |
+
self.domain_tags = config.get("domain_tags", [])
|
| 58 |
+
|
| 59 |
+
if tokenizer_path and os.path.exists(tokenizer_path):
|
| 60 |
+
self.tokenizer = Tokenizer.from_file(tokenizer_path)
|
| 61 |
+
print(f"Loaded tokenizer from {tokenizer_path}")
|
| 62 |
+
else:
|
| 63 |
+
# Initialize empty BPE tokenizer
|
| 64 |
+
self.tokenizer = Tokenizer(models.BPE())
|
| 65 |
+
self._setup_pre_tokenizer()
|
| 66 |
+
print("Initialized empty BPE tokenizer")
|
| 67 |
+
|
| 68 |
+
def _setup_pre_tokenizer(self):
|
| 69 |
+
"""Configure pre-tokenization rules."""
|
| 70 |
+
# Use byte-level pre-tokenization for robustness
|
| 71 |
+
self.tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
|
| 72 |
+
self.tokenizer.normalizer = None # Keep original casing for science terms
|
| 73 |
+
|
| 74 |
+
def train(
|
| 75 |
+
self,
|
| 76 |
+
file_paths: List[str],
|
| 77 |
+
min_frequency: int = 2,
|
| 78 |
+
special_tokens: Optional[List[str]] = None,
|
| 79 |
+
):
|
| 80 |
+
"""
|
| 81 |
+
Train the BPE tokenizer on scientific text files.
|
| 82 |
+
|
| 83 |
+
Args:
|
| 84 |
+
file_paths: List of text file paths for training
|
| 85 |
+
min_frequency: Minimum token frequency to keep
|
| 86 |
+
special_tokens: Additional special tokens to add
|
| 87 |
+
"""
|
| 88 |
+
if special_tokens is None:
|
| 89 |
+
special_tokens = list(self.special_tokens.keys()) + self.domain_tags
|
| 90 |
+
|
| 91 |
+
print(f"Training tokenizer on {len(file_paths)} files...")
|
| 92 |
+
print(f"Base vocab size: {self.base_vocab_size}")
|
| 93 |
+
print(f"Special tokens: {special_tokens}")
|
| 94 |
+
|
| 95 |
+
trainer = trainers.BpeTrainer(
|
| 96 |
+
vocab_size=self.base_vocab_size,
|
| 97 |
+
min_frequency=min_frequency,
|
| 98 |
+
special_tokens=special_tokens,
|
| 99 |
+
show_progress=True,
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
self.tokenizer.train(file_paths, trainer=trainer)
|
| 103 |
+
print(f"Training complete. Vocabulary size: {self.tokenizer.get_vocab_size()}")
|
| 104 |
+
|
| 105 |
+
# Extend with science-specific tokens
|
| 106 |
+
self._extend_science_vocabulary()
|
| 107 |
+
|
| 108 |
+
def _extend_science_vocabulary(self):
|
| 109 |
+
"""Add science-specific tokens to the vocabulary."""
|
| 110 |
+
current_vocab = self.tokenizer.get_vocab()
|
| 111 |
+
new_tokens = []
|
| 112 |
+
|
| 113 |
+
# LaTeX math symbols (common ones)
|
| 114 |
+
latex_symbols = [
|
| 115 |
+
"\\alpha", "\\beta", "\\gamma", "\\delta", "\\epsilon", "\\zeta",
|
| 116 |
+
"\\eta", "\\theta", "\\iota", "\\kappa", "\\lambda", "\\mu",
|
| 117 |
+
"\\nu", "\\xi", "\\pi", "\\rho", "\\sigma", "\\tau",
|
| 118 |
+
"\\upsilon", "\\phi", "\\chi", "\\psi", "\\omega",
|
| 119 |
+
"\\Gamma", "\\Delta", "\\Theta", "\\Lambda", "\\Xi", "\\Pi",
|
| 120 |
+
"\\Sigma", "\\Phi", "\\Psi", "\\Omega",
|
| 121 |
+
"\\sum", "\\prod", "\\int", "\\partial", "\\nabla", "\\infty",
|
| 122 |
+
"\\leq", "\\geq", "\\neq", "\\approx", "\\equiv", "\\sim",
|
| 123 |
+
"\\in", "\\notin", "\\subset", "\\supset", "\\cup", "\\cap",
|
| 124 |
+
"\\forall", "\\exists", "\\neg", "\\land", "\\lor", "\\rightarrow",
|
| 125 |
+
"\\leftarrow", "\\Rightarrow", "\\Leftarrow", "\\leftrightarrow",
|
| 126 |
+
"\\frac", "\\sqrt", "\\binom", "\\begin", "\\end", "\\mathbf",
|
| 127 |
+
"\\mathcal", "\\mathrm", "\\mathbb", "\\mathfrak",
|
| 128 |
+
]
|
| 129 |
+
new_tokens.extend(latex_symbols)
|
| 130 |
+
|
| 131 |
+
# Greek letters (Unicode)
|
| 132 |
+
greek_letters = [
|
| 133 |
+
"α", "β", "γ", "δ", "ε", "ζ", "η", "θ", "ι", "κ", "λ", "μ",
|
| 134 |
+
"ν", "ξ", "ο", "π", "ρ", "σ", "τ", "υ", "φ", "χ", "ψ", "ω",
|
| 135 |
+
"Γ", "Δ", "Θ", "Λ", "Ξ", "Π", "Σ", "Φ", "Ψ", "Ω",
|
| 136 |
+
]
|
| 137 |
+
new_tokens.extend(greek_letters)
|
| 138 |
+
|
| 139 |
+
# SI units and derived units
|
| 140 |
+
si_units = [
|
| 141 |
+
"m", "kg", "s", "mol", "K", "A", "cd", "mol",
|
| 142 |
+
"Hz", "N", "Pa", "J", "W", "C", "V", "F", "Ω", "S",
|
| 143 |
+
"Wb", "T", "H", "lm", "lx", "Bq", "Gy", "Sv", "kat",
|
| 144 |
+
"eV", "u", "Da", "Å", "°C", "%", "‰",
|
| 145 |
+
"M", "mM", "μM", "nM", "pM",
|
| 146 |
+
"g", "mg", "μg", "ng", "pg",
|
| 147 |
+
"km", "m", "cm", "mm", "μm", "nm", "pm",
|
| 148 |
+
"L", "mL", "μL", "nL",
|
| 149 |
+
"h", "min", "s", "ms", "μs", "ns",
|
| 150 |
+
]
|
| 151 |
+
new_tokens.extend(si_units)
|
| 152 |
+
|
| 153 |
+
# Common scientific abbreviations
|
| 154 |
+
sci_abbrevs = [
|
| 155 |
+
"DNA", "RNA", "mRNA", "tRNA", "rRNA", "cDNA", "gDNA",
|
| 156 |
+
"ATP", "ADP", "AMP", "NAD", "NADP", "FAD", "CoA",
|
| 157 |
+
"pH", "pKa", "pKb", "pI",
|
| 158 |
+
"PCR", "RT", "qPCR", "NGS", "WGS",
|
| 159 |
+
"IC50", "EC50", "KD", "Ki",
|
| 160 |
+
"XRD", "NMR", "IR", "UV", "VIS", "MS", "GC", "HPLC",
|
| 161 |
+
"SEM", "TEM", "AFM", "STM",
|
| 162 |
+
"S/N", "SNR", "RMS", "Std", "Var", "Cov",
|
| 163 |
+
"et al.", "vs.", "cf.", "viz.",
|
| 164 |
+
"Fig", "Eq", "Ref", "Tab", "Suppl",
|
| 165 |
+
]
|
| 166 |
+
new_tokens.extend(sci_abbrevs)
|
| 167 |
+
|
| 168 |
+
# Chemical element symbols
|
| 169 |
+
elements = [
|
| 170 |
+
"H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne",
|
| 171 |
+
"Na", "Mg", "Al", "Si", "P", "S", "Cl", "Ar",
|
| 172 |
+
"K", "Ca", "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn",
|
| 173 |
+
"Ga", "Ge", "As", "Se", "Br", "Kr",
|
| 174 |
+
"Rb", "Sr", "Y", "Zr", "Nb", "Mo", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd",
|
| 175 |
+
"In", "Sn", "Sb", "Te", "I", "Xe",
|
| 176 |
+
"Cs", "Ba", "La", "Ce", "Pr", "Nd", "Pm", "Sm", "Eu", "Gd", "Tb",
|
| 177 |
+
"Dy", "Ho", "Er", "Tm", "Yb", "Lu",
|
| 178 |
+
"Hf", "Ta", "W", "Re", "Os", "Ir", "Pt", "Au", "Hg", "Tl", "Pb",
|
| 179 |
+
"Bi", "Po", "At", "Rn",
|
| 180 |
+
"Fr", "Ra", "Ac", "Th", "Pa", "U", "Np", "Pu", "Am", "Cm", "Bk",
|
| 181 |
+
"Cf", "Es", "Fm", "Md", "No", "Lr",
|
| 182 |
+
"Rf", "Db", "Sg", "Bh", "Hs", "Mt", "Ds", "Rg", "Cn", "Nh",
|
| 183 |
+
"Fl", "Mc", "Lv", "Ts", "Og",
|
| 184 |
+
]
|
| 185 |
+
new_tokens.extend(elements)
|
| 186 |
+
|
| 187 |
+
# Amino acid single-letter codes
|
| 188 |
+
amino_acids = ["A", "R", "N", "D", "C", "Q", "E", "G", "H", "I",
|
| 189 |
+
"L", "K", "M", "F", "P", "S", "T", "W", "Y", "V"]
|
| 190 |
+
new_tokens.extend(amino_acids)
|
| 191 |
+
|
| 192 |
+
# Mathematical operators (Unicode)
|
| 193 |
+
math_ops = [
|
| 194 |
+
"±", "∓", "×", "÷", "∈", "∉", "∋", "∏", "∑", "∧", "∨", "¬",
|
| 195 |
+
"≤", "≥", "≠", "≈", "≡", "≅", "≆", "≇", "≉", "≊", "≋",
|
| 196 |
+
"⊂", "⊃", "⊆", "⊇", "⊄", "⊅", "⊈", "⊉",
|
| 197 |
+
"∞", "∂", "∇", "√", "∛", "∜",
|
| 198 |
+
"∫", "∬", "∭", "∮", "∯", "∰",
|
| 199 |
+
"∴", "∵", "∶", "∷", "∼", "∽", "≈", "≋",
|
| 200 |
+
"⟨", "⟩", "|", "‖", "‵", "′", "″", "‴",
|
| 201 |
+
"•", "·", "‣", "⁂", "※", "‼", "⁇", "⁈",
|
| 202 |
+
]
|
| 203 |
+
new_tokens.extend(math_ops)
|
| 204 |
+
|
| 205 |
+
# Add tokens that aren't already in vocabulary
|
| 206 |
+
for token in new_tokens:
|
| 207 |
+
if token not in current_vocab:
|
| 208 |
+
self.tokenizer.add_tokens([token])
|
| 209 |
+
|
| 210 |
+
print(f"Extended vocabulary with {len(new_tokens)} science tokens")
|
| 211 |
+
print(f"Final vocabulary size: {self.tokenizer.get_vocab_size()}")
|
| 212 |
+
|
| 213 |
+
def save(self, path: str):
|
| 214 |
+
"""Save tokenizer to disk."""
|
| 215 |
+
self.tokenizer.save(path)
|
| 216 |
+
print(f"Tokenizer saved to {path}")
|
| 217 |
+
|
| 218 |
+
def encode(
|
| 219 |
+
self,
|
| 220 |
+
text: str,
|
| 221 |
+
add_special_tokens: bool = True,
|
| 222 |
+
return_tensors: str = "pt",
|
| 223 |
+
) -> Union[Dict, torch.Tensor]:
|
| 224 |
+
"""
|
| 225 |
+
Encode text to token IDs.
|
| 226 |
+
|
| 227 |
+
Args:
|
| 228 |
+
text: Input text
|
| 229 |
+
add_special_tokens: Add BOS/EOS tokens
|
| 230 |
+
return_tensors: "pt" for PyTorch tensors, "np" for numpy, None for list
|
| 231 |
+
|
| 232 |
+
Returns:
|
| 233 |
+
Dictionary with input_ids and attention_mask, or tensors/list
|
| 234 |
+
"""
|
| 235 |
+
encoding = self.tokenizer.encode(text, add_special_tokens=add_special_tokens)
|
| 236 |
+
|
| 237 |
+
result = {
|
| 238 |
+
"input_ids": encoding.ids,
|
| 239 |
+
"attention_mask": encoding.attention_mask,
|
| 240 |
+
}
|
| 241 |
+
|
| 242 |
+
if return_tensors == "pt":
|
| 243 |
+
result = {k: torch.tensor(v).unsqueeze(0) for k, v in result.items()}
|
| 244 |
+
elif return_tensors == "np":
|
| 245 |
+
import numpy as np
|
| 246 |
+
result = {k: np.array(v) for k, v in result.items()}
|
| 247 |
+
|
| 248 |
+
return result
|
| 249 |
+
|
| 250 |
+
def decode(self, token_ids: List[int], skip_special_tokens: bool = True) -> str:
|
| 251 |
+
"""Decode token IDs back to text."""
|
| 252 |
+
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
|
| 253 |
+
|
| 254 |
+
def batch_encode(
|
| 255 |
+
self,
|
| 256 |
+
texts: List[str],
|
| 257 |
+
padding: bool = True,
|
| 258 |
+
truncation: bool = True,
|
| 259 |
+
max_length: Optional[int] = None,
|
| 260 |
+
return_tensors: str = "pt",
|
| 261 |
+
) -> Dict:
|
| 262 |
+
"""
|
| 263 |
+
Encode a batch of texts.
|
| 264 |
+
|
| 265 |
+
Args:
|
| 266 |
+
texts: List of input texts
|
| 267 |
+
padding: Pad to same length
|
| 268 |
+
truncation: Truncate to max_length
|
| 269 |
+
max_length: Maximum sequence length
|
| 270 |
+
return_tensors: Tensor format
|
| 271 |
+
|
| 272 |
+
Returns:
|
| 273 |
+
Batch encoded dictionary
|
| 274 |
+
"""
|
| 275 |
+
if max_length is None:
|
| 276 |
+
max_length = self.config.get("max_seq_len", 16384)
|
| 277 |
+
|
| 278 |
+
encodings = self.tokenizer.encode_batch(
|
| 279 |
+
texts,
|
| 280 |
+
add_special_tokens=True,
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
# Manual padding/truncation
|
| 284 |
+
input_ids = []
|
| 285 |
+
attention_masks = []
|
| 286 |
+
|
| 287 |
+
for enc in encodings:
|
| 288 |
+
ids = enc.ids
|
| 289 |
+
mask = enc.attention_mask
|
| 290 |
+
|
| 291 |
+
if truncation and len(ids) > max_length:
|
| 292 |
+
ids = ids[:max_length]
|
| 293 |
+
mask = mask[:max_length]
|
| 294 |
+
|
| 295 |
+
input_ids.append(ids)
|
| 296 |
+
attention_masks.append(mask)
|
| 297 |
+
|
| 298 |
+
# Pad to same length if requested
|
| 299 |
+
if padding:
|
| 300 |
+
max_len = max(len(ids) for ids in input_ids)
|
| 301 |
+
padded_ids = []
|
| 302 |
+
padded_masks = []
|
| 303 |
+
|
| 304 |
+
for ids, mask in zip(input_ids, attention_masks):
|
| 305 |
+
pad_len = max_len - len(ids)
|
| 306 |
+
padded_ids.append(ids + [self.special_tokens["[PAD]"]] * pad_len)
|
| 307 |
+
padded_masks.append(mask + [0] * pad_len)
|
| 308 |
+
|
| 309 |
+
input_ids = padded_ids
|
| 310 |
+
attention_masks = padded_masks
|
| 311 |
+
|
| 312 |
+
result = {
|
| 313 |
+
"input_ids": input_ids,
|
| 314 |
+
"attention_mask": attention_masks,
|
| 315 |
+
}
|
| 316 |
+
|
| 317 |
+
if return_tensors == "pt":
|
| 318 |
+
result = {k: torch.tensor(v) for k, v in result.items()}
|
| 319 |
+
|
| 320 |
+
return result
|
| 321 |
+
|
| 322 |
+
@property
|
| 323 |
+
def vocab_size(self) -> int:
|
| 324 |
+
"""Get vocabulary size."""
|
| 325 |
+
return self.tokenizer.get_vocab_size()
|
| 326 |
+
|
| 327 |
+
def get_vocab(self) -> Dict[str, int]:
|
| 328 |
+
"""Get vocabulary dictionary."""
|
| 329 |
+
return self.tokenizer.get_vocab()
|
| 330 |
+
|
| 331 |
+
def token_to_id(self, token: str) -> int:
|
| 332 |
+
"""Convert token to ID."""
|
| 333 |
+
return self.tokenizer.token_to_id(token)
|
| 334 |
+
|
| 335 |
+
def id_to_token(self, id: int) -> str:
    """Map a vocabulary ID back to its surface token string.

    NOTE: the parameter name shadows the builtin ``id``; it is kept
    unchanged because it is part of the public signature.
    """
    return self.tokenizer.id_to_token(id)
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
def build_science_vocabulary_file(output_path: str):
    """
    Write a seed-vocabulary text file (one term per line) for BPE training.

    The file guarantees that common scientific surface forms -- LaTeX
    commands, chemical formulas, math expressions, and unit strings --
    are seen by the tokenizer trainer and therefore survive as tokens.

    Args:
        output_path: Destination path of the UTF-8 text file.
    """
    # Grouped by category; iteration order (dict insertion order) matches
    # the order the terms are written to the file.
    term_groups = {
        # LaTeX commands
        "latex": [
            "\\alpha", "\\beta", "\\gamma", "\\delta", "\\epsilon", "\\zeta",
            "\\eta", "\\theta", "\\iota", "\\kappa", "\\lambda", "\\mu",
            "\\nu", "\\xi", "\\pi", "\\rho", "\\sigma", "\\tau",
            "\\upsilon", "\\phi", "\\chi", "\\psi", "\\omega",
            "\\sum", "\\prod", "\\int", "\\partial", "\\nabla", "\\infty",
            "\\frac", "\\sqrt", "\\binom", "\\begin", "\\end",
            "\\mathbf", "\\mathcal", "\\mathrm", "\\mathbb",
            "\\in", "\\subset", "\\cup", "\\cap", "\\forall", "\\exists",
            "\\rightarrow", "\\leftarrow", "\\Rightarrow", "\\Leftarrow",
            "\\leq", "\\geq", "\\neq", "\\approx", "\\equiv",
        ],
        # Chemical formulas
        "chemistry": [
            "H2O", "CO2", "O2", "N2", "H2", "CH4", "C2H6", "C3H8",
            "C6H12O6", "C12H22O11", "HCl", "H2SO4", "HNO3", "H3PO4",
            "NaOH", "KOH", "CaCO3", "NaCl", "KCl", "MgCl2",
            "Fe2O3", "Fe3O4", "CuO", "Cu2O", "ZnO", "Al2O3",
            "SiO2", "TiO2", "MnO2", "NH3", "NO", "NO2", "N2O",
            "SO2", "SO3", "CO", "CH3COOH", "C2H5OH",
        ],
        # Mathematical expressions
        "math": [
            "x^2", "x^3", "e^x", "ln(x)", "log(x)", "sin(x)", "cos(x)",
            "tan(x)", "arcsin(x)", "arccos(x)", "arctan(x)",
            "f(x)", "g(x)", "h(x)", "F(x)", "G(x)",
            "dx", "dy", "dz", "dt", "∂x", "∂y", "∂z",
            "∫", "∬", "∭", "∮", "∑_{i=1}^{n}", "∏_{i=1}^{n}",
        ],
        # Units and magnitudes
        "units": [
            "10^6", "10^9", "10^12", "10^15", "10^18",
            "10^-3", "10^-6", "10^-9", "10^-12",
            "m/s", "km/h", "cm/s", "mm/s",
            "J/mol", "kJ/mol", "cal", "kcal",
            "eV", "MeV", "GeV", "TeV",
            "Hz", "kHz", "MHz", "GHz",
            "Pa", "kPa", "MPa", "GPa",
            "°C", "K", "°F",
        ],
    }

    seed_terms = [term for group in term_groups.values() for term in group]

    # One term per line; identical bytes to writing term + "\n" per term.
    with open(output_path, "w", encoding="utf-8") as out_file:
        out_file.write("\n".join(seed_terms) + "\n")

    print(f"Science vocabulary seed file written to {output_path}")
    print(f"Total seed terms: {len(seed_terms)}")
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
if __name__ == "__main__":
    # Standalone CLI: train a science tokenizer from a raw-text corpus.
    import sys

    if len(sys.argv) < 2:
        print("Usage: python vortex_tokenizer.py <train_data.txt> [output_dir]")
        sys.exit(1)

    corpus_path = sys.argv[1]
    out_dir = sys.argv[2] if len(sys.argv) > 2 else "."

    # Minimal standalone config mirroring the model config's tokenizer section.
    config = {
        "special_tokens": {
            "[PAD]": 0, "[UNK]": 1, "[BOS]": 2, "[EOS]": 3,
            "[EQUATION]": 4, "[/EQUATION]": 5,
            "[CITATION]": 6, "[/CITATION]": 7,
            "[MOLECULE]": 8, "[/MOLECULE]": 9,
            "[FIGURE]": 10, "[TABLE]": 11,
            "[MATH]": 12, "[CHEM]": 13, "[BIO]": 14,
            "[PHYS]": 15, "[EARTH]": 16, "[SPACE]": 17, "[ZOO]": 18,
        },
        "domain_tags": ["[MATH]", "[CHEM]", "[BIO]", "[PHYS]", "[EARTH]", "[SPACE]", "[ZOO]"],
        "max_seq_len": 16384,
    }

    # Seed file ensures scientific terms survive BPE training.
    seed_vocab_path = os.path.join(out_dir, "science_seed_vocab.txt")
    build_science_vocabulary_file(seed_vocab_path)

    # Train a fresh tokenizer on the supplied corpus.
    tokenizer = VortexScienceTokenizer(config)
    tokenizer.train([corpus_path])

    # Persist the trained tokenizer next to the seed vocabulary.
    tokenizer_path = os.path.join(out_dir, "vortex_tokenizer.json")
    tokenizer.save(tokenizer_path)
    print(f"Tokenizer saved to {tokenizer_path}")
|
train.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Main training entry point for Vortex models.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import sys
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
from configs.vortex_7b_config import VORTEX_7B_CONFIG
|
| 13 |
+
from configs.vortex_13b_config import VORTEX_13B_CONFIG
|
| 14 |
+
from configs.training_config import TRAINING_CONFIG, TRAINING_CONFIG_7B_CUDA, TRAINING_CONFIG_13B_CUDA, TRAINING_CONFIG_MPS
|
| 15 |
+
|
| 16 |
+
from models.vortex_model import VortexModel
|
| 17 |
+
from tokenizer.vortex_tokenizer import VortexScienceTokenizer
|
| 18 |
+
from training.trainer import VortexTrainer, VortexDataset
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def parse_args():
    """Parse command-line options for Vortex training and return the namespace."""
    ap = argparse.ArgumentParser(description="Train Vortex scientific language model")
    # Model / device selection
    ap.add_argument("--model_size", type=str, choices=["7b", "13b"], default="7b",
                    help="Model size to train")
    ap.add_argument("--device", type=str, default="cuda",
                    choices=["cuda", "mps", "cpu"],
                    help="Device to train on")
    ap.add_argument("--use_mps", action="store_true",
                    help="Use MPS backend (Apple Silicon)")
    # Data / checkpoint paths
    ap.add_argument("--data_dir", type=str, default="./data/processed",
                    help="Directory with processed data shards")
    ap.add_argument("--tokenizer_path", type=str, default=None,
                    help="Path to pretrained tokenizer")
    ap.add_argument("--resume_from_checkpoint", type=str, default=None,
                    help="Resume training from checkpoint")
    ap.add_argument("--output_dir", type=str, default="./checkpoints",
                    help="Output directory for checkpoints")
    # Training overrides
    ap.add_argument("--max_steps", type=int, default=None,
                    help="Override max training steps")
    ap.add_argument("--micro_batch_size", type=int, default=None,
                    help="Override micro batch size")
    ap.add_argument("--quantization", type=str, choices=[None, "int8", "int4"], default=None,
                    help="Quantization for 13B on 8GB")
    return ap.parse_args()
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def main():
    """Entry point: build configs, tokenizer, model, datasets, and run training.

    Exits with status 1 when no processed training shards are found.
    """
    args = parse_args()

    # Load model + training configs for the requested size.
    if args.model_size == "7b":
        model_config = VORTEX_7B_CONFIG.copy()
        train_config = TRAINING_CONFIG_7B_CUDA.copy()
    else:
        model_config = VORTEX_13B_CONFIG.copy()
        train_config = TRAINING_CONFIG_13B_CUDA.copy()

    # Override with MPS config if needed.
    if args.use_mps or args.device == "mps":
        train_config = TRAINING_CONFIG_MPS.copy()
        train_config["use_mps"] = True
        # BUGFIX: --use_mps previously left args.device at its default
        # ("cuda"), so the model was placed on the wrong device even though
        # the MPS training config was selected.
        args.device = "mps"

    # Apply CLI overrides. Use `is not None` so an explicit 0 is honored
    # (the old truthiness checks silently ignored zero values).
    if args.max_steps is not None:
        train_config["max_steps"] = args.max_steps
    if args.micro_batch_size is not None:
        train_config["micro_batch_size"] = args.micro_batch_size
    if args.quantization is not None:
        train_config["quantization"] = args.quantization

    # Set device.
    device = torch.device(args.device)
    train_config["device"] = args.device

    print(f"Training Vortex-{args.model_size.upper()}")
    print(f"Device: {device}")
    print(f"Max steps: {train_config['max_steps']}")
    print(f"Micro batch size: {train_config['micro_batch_size']}")

    # Create tokenizer.
    print("Loading tokenizer...")
    tokenizer = VortexScienceTokenizer(
        model_config,
        tokenizer_path=args.tokenizer_path,
    )
    print(f"Tokenizer vocab size: {tokenizer.vocab_size}")

    # Create model.
    print("Creating model...")
    model = VortexModel(model_config)
    print(f"Model parameters: {model.get_num_params():,}")

    # Estimate memory so OOMs surface before training starts.
    mem = model.estimate_memory_usage(
        train_config["micro_batch_size"],
        model_config["max_seq_len"],
    )
    print("Memory estimate:")
    for k, v in mem.items():
        print(f"  {k}: {v:.2f} GB")

    # Load dataset shards produced by the data pipeline.
    print("Loading dataset...")
    data_dir = Path(args.data_dir)
    shard_files = sorted(data_dir.glob("train_*.parquet"))
    if not shard_files:
        print(f"No training shards found in {data_dir}")
        print("Please run data pipeline first.")
        sys.exit(1)

    train_dataset = VortexDataset(
        shard_files,
        tokenizer,
        max_seq_len=model_config["max_seq_len"],
    )
    print(f"Training dataset size: {len(train_dataset)} samples")

    # Eval dataset: reuse the first shard.
    # NOTE(review): this overlaps with the training data, so eval loss is
    # optimistic — consider holding out a dedicated shard.
    eval_shard_files = shard_files[:1]
    eval_dataset = VortexDataset(
        eval_shard_files,
        tokenizer,
        max_seq_len=model_config["max_seq_len"],
    )

    # Create trainer.
    trainer = VortexTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        config=train_config,
        eval_dataset=eval_dataset,
    )

    # Resume from checkpoint if specified.
    if args.resume_from_checkpoint:
        trainer.load_checkpoint(args.resume_from_checkpoint)

    # Train.
    trainer.train()

    print("Training complete!")
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
# Script entry point: run training when executed directly.
if __name__ == "__main__":
    main()
|
training/__pycache__/curriculum.cpython-313.pyc
ADDED
|
Binary file (5 kB). View file
|
|
|
training/__pycache__/losses.cpython-313.pyc
ADDED
|
Binary file (5.84 kB). View file
|
|
|
training/curriculum.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Curriculum learning for Vortex model.
|
| 3 |
+
Progresses through stages: Foundation → Domain → Reasoning → Integration.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from typing import List, Dict, Optional
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class CurriculumScheduler:
    """
    Schedules curriculum stages during training.

    Each stage covers a ``[start, end)`` fraction of the total training
    steps (Foundation -> Domain -> Reasoning -> Integration); fractions
    are converted to absolute step indices at construction time.
    """

    STAGES = ["foundation", "domain", "reasoning", "integration"]

    # Loss components active in each stage; later stages progressively
    # enable the auxiliary science losses.
    STAGE_COMPONENTS = {
        "foundation": ["lm_loss"],  # only language modeling
        "domain": ["lm_loss", "equation_loss", "domain_loss"],
        "reasoning": ["lm_loss", "equation_loss", "domain_loss", "citation_loss"],
        "integration": ["lm_loss", "equation_loss", "domain_loss", "citation_loss", "numerical_loss"],
    }

    def __init__(
        self,
        config: Dict,
        total_steps: int,
    ):
        """
        Initialize curriculum scheduler.

        Args:
            config: Training config; reads the optional "curriculum_stages"
                list of {"name", "start", "end"} dicts (fractions in [0, 1]).
            total_steps: Total number of training steps.
        """
        self.config = config
        self.total_steps = total_steps
        self.stages = config.get("curriculum_stages", [
            {"name": "foundation", "start": 0.0, "end": 0.2},
            {"name": "domain", "start": 0.2, "end": 0.5},
            {"name": "reasoning", "start": 0.5, "end": 0.8},
            {"name": "integration", "start": 0.8, "end": 1.0},
        ])

        # Convert fractional boundaries into absolute step indices.
        # NOTE: this mutates the stage dicts in place (also in the caller's
        # config when "curriculum_stages" was supplied).
        for stage in self.stages:
            stage["start_step"] = int(stage["start"] * total_steps)
            stage["end_step"] = int(stage["end"] * total_steps)

    def get_stage(
        self,
        current_step: int,
    ) -> Optional[Dict]:
        """
        Get current curriculum stage.

        Args:
            current_step: Current training step.

        Returns:
            Stage dictionary whose ``[start_step, end_step)`` interval
            contains ``current_step``, or None if training is complete.
        """
        for stage in self.stages:
            if stage["start_step"] <= current_step < stage["end_step"]:
                return stage
        return None

    def get_stage_name(self, current_step: int) -> str:
        """Get current stage name, or "complete" past the last stage."""
        stage = self.get_stage(current_step)
        return stage["name"] if stage else "complete"

    def get_stage_weight(
        self,
        current_step: int,
        base_weight: float,
        component: str = "lm_loss",
    ) -> float:
        """
        Get weight for a curriculum loss component based on stage.

        BUGFIX: the previous version ignored which component was being
        asked about and returned ``base_weight`` whenever any stage was
        active (its own comment admitted the caller had to re-check).
        The new backward-compatible ``component`` parameter lets this
        method do the check itself; the default "lm_loss" is active in
        every stage, preserving the old return values for old callers.

        Args:
            current_step: Current training step.
            base_weight: Base weight for the component.
            component: Loss component name (e.g. "equation_loss").

        Returns:
            ``base_weight`` when the component is active in the current
            stage, else 0.0 (also 0.0 once training is complete).
        """
        stage = self.get_stage(current_step)
        if not stage:
            return 0.0

        active_components = self.STAGE_COMPONENTS.get(stage["name"], ["lm_loss"])
        return base_weight if component in active_components else 0.0

    def get_dataset_sampler(
        self,
        current_step: int,
    ):
        """
        Get dataset mixing proportions for the current stage.

        Args:
            current_step: Current training step.

        Returns:
            Dict mapping dataset name -> sampling proportion (sums to 1.0),
            or None once training is complete.
        """
        stage = self.get_stage(current_step)
        if not stage:
            return None

        stage_name = stage["name"]

        # Dataset mixing proportions per stage.
        mixing_proportions = {
            "foundation": {
                "pile_scientific": 0.3,
                "s2orc": 0.3,
                "automath": 0.2,
                "pubmed_qa": 0.2,
            },
            "domain": {
                "pile_scientific": 0.2,
                "s2orc": 0.2,
                "automath": 0.2,
                "pubmed_qa": 0.2,
                "deepmind_math": 0.2,
            },
            "reasoning": {
                "pile_scientific": 0.15,
                "s2orc": 0.15,
                "automath": 0.3,
                "deepmind_math": 0.3,
                "pubmed_qa": 0.1,
            },
            "integration": {
                "pile_scientific": 0.2,
                "s2orc": 0.2,
                "automath": 0.2,
                "deepmind_math": 0.2,
                "pubmed_qa": 0.2,
            },
        }

        return mixing_proportions.get(stage_name, {"pile_scientific": 1.0})
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def test_curriculum():
    """Smoke-test the curriculum scheduler across stage boundaries."""
    config = {
        "curriculum_stages": [
            {"name": "foundation", "start": 0.0, "end": 0.2},
            {"name": "domain", "start": 0.2, "end": 0.5},
            {"name": "reasoning", "start": 0.5, "end": 0.8},
            {"name": "integration", "start": 0.8, "end": 1.0},
        ]
    }

    scheduler = CurriculumScheduler(config, 1000)

    # Probe representative steps, including both ends of the schedule.
    for probe_step in (0, 100, 250, 500, 750, 999):
        stage_name = scheduler.get_stage_name(probe_step)
        print(f"Step {probe_step}: {stage_name}")

    print("Curriculum test passed!")
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
# Module self-test entry point.
if __name__ == "__main__":
    test_curriculum()
|
training/losses.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Science-aware losses for Vortex model training.
|
| 3 |
+
Combines standard language modeling with auxiliary tasks.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from typing import Dict, Optional, Tuple
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class VortexLoss(nn.Module):
    """
    Combined loss for Vortex model with science-aware components.

    total_loss = (
        lm_loss * 1.0
        + equation_loss * 0.3
        + domain_loss * 0.1
        + citation_loss * 0.1
        + numerical_loss * 0.2
    )
    (defaults; actual weights come from config["loss_weights"]).
    """

    def __init__(self, config: Dict):
        """
        Initialize loss.

        Args:
            config: Training config; reads the optional "loss_weights" mapping
                of component name -> weight.
        """
        super().__init__()
        self.loss_weights = config.get("loss_weights", {
            "lm_loss": 1.0,
            "equation_loss": 0.3,
            "domain_loss": 0.1,
            "citation_loss": 0.1,
            "numerical_loss": 0.2,
        })

    def forward(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        equation_module: Optional[nn.Module] = None,
        equation_mask: Optional[torch.Tensor] = None,
        domain_logits: Optional[torch.Tensor] = None,
        domain_labels: Optional[torch.Tensor] = None,
        citation_module: Optional[nn.Module] = None,
        citation_mask: Optional[torch.Tensor] = None,
        citation_confidence: Optional[torch.Tensor] = None,
        numerical_module: Optional[nn.Module] = None,
        numerical_mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Compute total loss.

        Args:
            logits: (batch, seq_len, vocab_size)
            labels: (batch, seq_len) with token IDs; -100 marks ignored positions
            equation_module: EquationModule for equation loss
            equation_mask: (batch, seq_len) 1 if token in equation
            domain_logits: (batch, num_domains)
            domain_labels: (batch,)
            citation_module: CitationModule for citation loss
            citation_mask: (batch, seq_len)
            citation_confidence: (batch, seq_len, 1)
            numerical_module: NumericalReasoningModule
            numerical_mask: (batch, seq_len)

        Returns:
            Dictionary with total loss and component losses
        """
        losses = {}
        zero = torch.tensor(0.0, device=logits.device)

        # 1. Language modeling loss (next-token prediction).
        # BUGFIX: logits/labels must be shifted by one position so that
        # position t predicts token t+1. The previous version compared each
        # position with its own token, which trains a copy task, not a
        # causal language model.
        if labels.size(1) > 1:
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            lm_loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,  # ignore padding
            )
        else:
            # A single-token sequence has nothing to predict; avoid the
            # NaN that cross_entropy returns on an empty batch.
            lm_loss = zero
        losses["lm_loss"] = lm_loss

        # 2. Equation detection loss.
        # NOTE(review): placeholder — computing this properly requires the
        # hidden states from the equation module (forward pass change).
        if equation_module is not None and equation_mask is not None:
            losses["equation_loss"] = zero
        else:
            losses["equation_loss"] = zero

        # 3. Domain classification loss.
        if domain_logits is not None and domain_labels is not None:
            losses["domain_loss"] = F.cross_entropy(domain_logits, domain_labels)
        else:
            losses["domain_loss"] = zero

        # 4. Citation detection loss.
        # NOTE(review): fed dummy zeros in place of real hidden states — see
        # the placeholder comment above; wire in hidden states before relying
        # on this term.
        if citation_module is not None and citation_mask is not None and citation_confidence is not None:
            losses["citation_loss"] = citation_module.compute_citation_loss(
                torch.zeros_like(logits[:, :, :1]),  # dummy hidden states
                citation_mask,
                citation_confidence,
            )
        else:
            losses["citation_loss"] = zero

        # 5. Numerical reasoning loss (same placeholder caveat as above).
        if numerical_module is not None and numerical_mask is not None:
            losses["numerical_loss"] = numerical_module.compute_numerical_loss(
                torch.zeros_like(logits),  # dummy hidden states
                numerical_mask,
                None,  # target values
            )
        else:
            losses["numerical_loss"] = zero

        # Weighted sum of all components; unknown components default to 1.0.
        total_loss = zero
        for name, loss in losses.items():
            weight = self.loss_weights.get(name, 1.0)
            total_loss = total_loss + loss * weight

        losses["total_loss"] = total_loss

        return losses
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def test_vortex_loss():
    """Smoke-test VortexLoss on random logits/labels."""
    config = {"loss_weights": {
        "lm_loss": 1.0,
        "equation_loss": 0.3,
        "domain_loss": 0.1,
        "citation_loss": 0.1,
        "numerical_loss": 0.2,
    }}

    criterion = VortexLoss(config)

    n_batch, n_seq, n_vocab = 2, 128, 1000
    fake_logits = torch.randn(n_batch, n_seq, n_vocab)
    fake_labels = torch.randint(0, n_vocab, (n_batch, n_seq))

    component_losses = criterion(fake_logits, fake_labels)
    print("Losses:")
    for loss_name, loss_value in component_losses.items():
        print(f"  {loss_name}: {loss_value.item():.4f}")

    assert "total_loss" in component_losses
    print("VortexLoss test passed!")
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
# Module self-test entry point.
if __name__ == "__main__":
    test_vortex_loss()
|
training/trainer.py
ADDED
|
@@ -0,0 +1,442 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Trainer: Main training loop for Vortex model.
|
| 3 |
+
Handles gradient accumulation, mixed precision, checkpointing.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import json
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
from torch.utils.data import DataLoader, Dataset
|
| 11 |
+
from typing import Optional, Dict, List, Callable
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
import logging
|
| 14 |
+
|
| 15 |
+
from ..training.losses import VortexLoss
|
| 16 |
+
from ..training.curriculum import CurriculumScheduler
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class VortexDataset(Dataset):
    """In-memory dataset over tokenized parquet shards."""

    def __init__(
        self,
        shard_files: List[str],
        tokenizer,
        max_seq_len: int = 16384,
    ):
        """
        Initialize dataset.

        Args:
            shard_files: List of parquet shard files.
            tokenizer: Tokenizer used to encode each sample's text.
            max_seq_len: Hard cap on tokenized sequence length.
        """
        self.shard_files = shard_files
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len

        # All shards are materialized up front for simplicity; a production
        # pipeline would stream them instead.
        self.samples = []
        self._load_shards()

    def _load_shards(self):
        """Read every shard into self.samples as plain dicts."""
        import pandas as pd

        for shard_path in self.shard_files:
            frame = pd.read_parquet(shard_path)
            for _, record in frame.iterrows():
                self.samples.append({
                    "text": record["text"],
                    "dataset": record.get("dataset", ""),
                    "domain": record.get("domain", ""),
                })

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx) -> Dict:
        entry = self.samples[idx]

        # Tokenize the raw text into tensors.
        encoded = self.tokenizer.encode(
            entry["text"],
            add_special_tokens=True,
            return_tensors="pt",
        )

        token_ids = encoded["input_ids"].squeeze(0)
        mask = encoded["attention_mask"].squeeze(0)

        # Enforce the hard length cap.
        if len(token_ids) > self.max_seq_len:
            token_ids = token_ids[:self.max_seq_len]
            mask = mask[:self.max_seq_len]

        return {
            "input_ids": token_ids,
            "attention_mask": mask,
            # Causal LM: labels mirror the inputs.
            "labels": token_ids.clone(),
            "domain": entry["domain"],
        }
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class VortexTrainer:
    """
    Main trainer for Vortex model.

    Owns the optimization loop: mixed precision (AMP), gradient
    accumulation and clipping, LR scheduling, periodic evaluation,
    console logging, and checkpointing.
    """

    def __init__(
        self,
        model: nn.Module,
        tokenizer,
        train_dataset: Dataset,
        config: Dict,
        eval_dataset: Optional[Dataset] = None,
        optimizer: Optional[torch.optim.Optimizer] = None,
        scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
    ):
        """
        Initialize trainer.

        Args:
            model: VortexModel
            tokenizer: VortexScienceTokenizer
            train_dataset: Training dataset
            config: Training configuration dictionary
            eval_dataset: Optional evaluation dataset
            optimizer: Optional optimizer (AdamW is created if None)
            scheduler: Optional LR scheduler (cosine annealing if None)
        """
        self.model = model
        self.tokenizer = tokenizer
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.config = config

        self.device = torch.device(config["device"])
        self.use_amp = config.get("use_amp", True)
        # Resolve e.g. "bfloat16" -> torch.bfloat16 for autocast.
        self.amp_dtype = getattr(torch, config.get("amp_dtype", "bfloat16"))

        # Move model to device before the optimizer captures its parameters.
        self.model.to(self.device)

        # Setup optimizer
        if optimizer is None:
            self.optimizer = self._create_optimizer()
        else:
            self.optimizer = optimizer

        # Setup scheduler (cosine annealing over the full run by default)
        if scheduler is None:
            self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer,
                T_max=config["max_steps"],
            )
        else:
            self.scheduler = scheduler

        # Setup AMP loss scaler (only meaningful on CUDA)
        self.scaler = (
            torch.cuda.amp.GradScaler()
            if self.use_amp and self.device.type == "cuda"
            else None
        )

        # Loss function
        self.loss_fn = VortexLoss(config)

        # Curriculum scheduler
        self.curriculum = CurriculumScheduler(config, config["max_steps"])

        # Logging
        self.log_dir = Path(config.get("log_dir", "logs"))
        self.log_dir.mkdir(parents=True, exist_ok=True)
        self.log_interval = config.get("log_interval", 100)

        # Checkpointing
        self.checkpoint_dir = Path(config.get("checkpoint_dir", "checkpoints"))
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.save_interval = config.get("save_interval", 5000)

        # Training state
        self.global_step = 0
        self.best_eval_loss = float("inf")

        # Data loaders
        self.train_loader = DataLoader(
            train_dataset,
            batch_size=config["micro_batch_size"],
            shuffle=True,
            num_workers=config.get("num_workers", 4),
            pin_memory=config.get("pin_memory", True),
            prefetch_factor=config.get("prefetch_factor", 2),
        )

        if eval_dataset:
            self.eval_loader = DataLoader(
                eval_dataset,
                batch_size=config["micro_batch_size"],
                shuffle=False,
                num_workers=config.get("num_workers", 4),
            )

    def _create_optimizer(self) -> torch.optim.Optimizer:
        """Create an AdamW optimizer from the config hyperparameters."""
        return torch.optim.AdamW(
            self.model.parameters(),
            lr=self.config["learning_rate"],
            betas=(self.config["beta1"], self.config["beta2"]),
            weight_decay=self.config["weight_decay"],
        )

    def train_step(
        self,
        batch: Dict,
        current_step: int,
    ) -> Dict[str, torch.Tensor]:
        """
        Single training step (forward + backward; no optimizer update).

        Args:
            batch: Batch dictionary with input_ids / attention_mask / labels
            current_step: Current step number

        Returns:
            Dictionary of (unnormalized) losses for logging
        """
        self.model.train()

        # Move batch to device
        input_ids = batch["input_ids"].to(self.device)
        attention_mask = batch["attention_mask"].to(self.device)
        labels = batch["labels"].to(self.device)

        # Domain info (placeholder - would extract from batch)
        domain_ids = None
        domain_tags = None

        # Forward pass with AMP.
        # BUGFIX: pass the configured dtype — previously autocast fell back to
        # its fp16 default even when config requested bfloat16.
        with torch.cuda.amp.autocast(
            enabled=self.use_amp and self.device.type == "cuda",
            dtype=self.amp_dtype,
        ):
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                domain_ids=domain_ids,
                domain_tags=domain_tags,
                return_dict=True,
            )
            logits = outputs["logits"]

        # Compute losses (kept outside autocast for full-precision reductions)
        losses = self.loss_fn(
            logits=logits,
            labels=labels,
            # Pass modules and masks for auxiliary losses
        )

        # Backward pass.
        # BUGFIX: normalize by the accumulation count so the accumulated
        # gradient is the MEAN over micro-batches, not the sum — otherwise the
        # effective learning rate silently scales with accumulation steps.
        accum_steps = self.config.get("gradient_accumulation_steps", 1)
        scaled_loss = losses["total_loss"] / accum_steps
        if self.scaler:
            self.scaler.scale(scaled_loss).backward()
        else:
            scaled_loss.backward()

        return losses

    def train_epoch(self):
        """Train for one epoch (stops early once max_steps is reached)."""
        self.model.train()

        for batch_idx, batch in enumerate(self.train_loader):
            # Forward/backward for one micro-batch
            losses = self.train_step(batch, self.global_step)

            # Optimizer update once every gradient_accumulation_steps micro-steps
            if (self.global_step + 1) % self.config["gradient_accumulation_steps"] == 0:
                # Gradient clipping (gradients must be unscaled first under AMP)
                if self.config.get("clip_grad_norm", 0) > 0:
                    if self.scaler:
                        self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(),
                        self.config["clip_grad_norm"],
                    )

                # Optimizer step
                if self.scaler:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    self.optimizer.step()

                self.optimizer.zero_grad()
                self.scheduler.step()

            # Logging
            if self.global_step % self.log_interval == 0:
                self._log_losses(losses, batch_idx)

            # Evaluation
            if self.eval_dataset and self.global_step % self.config.get("eval_interval", 1000) == 0:
                self.evaluate()

            # Checkpointing
            if self.global_step % self.save_interval == 0:
                self.save_checkpoint()

            self.global_step += 1

            if self.global_step >= self.config["max_steps"]:
                print("Reached max steps")
                return

    def evaluate(self) -> Dict[str, float]:
        """Run evaluation; returns {'eval_loss': mean cross-entropy}."""
        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for batch in self.eval_loader:
                input_ids = batch["input_ids"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)
                labels = batch["labels"].to(self.device)

                with torch.cuda.amp.autocast(
                    enabled=self.use_amp and self.device.type == "cuda",
                    dtype=self.amp_dtype,
                ):
                    outputs = self.model(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        return_dict=True,
                    )
                    logits = outputs["logits"]
                    loss = F.cross_entropy(
                        logits.view(-1, logits.size(-1)),
                        labels.view(-1),
                        ignore_index=-100,
                    )

                total_loss += loss.item()
                num_batches += 1

        avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
        print(f"Evaluation at step {self.global_step}: loss = {avg_loss:.4f}")

        # BUGFIX: best_eval_loss was saved/restored in checkpoints but never
        # updated; track the best score so is_best checkpoints are meaningful.
        if num_batches > 0 and avg_loss < self.best_eval_loss:
            self.best_eval_loss = avg_loss

        return {"eval_loss": avg_loss}

    def save_checkpoint(self, is_best: bool = False):
        """Save model/optimizer/scheduler (and scaler) state to disk."""
        checkpoint = {
            "step": self.global_step,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "config": self.config,
            "best_eval_loss": self.best_eval_loss,
        }

        if self.scaler:
            checkpoint["scaler_state_dict"] = self.scaler.state_dict()

        # Save step-stamped checkpoint
        checkpoint_path = self.checkpoint_dir / f"checkpoint_{self.global_step:06d}.pt"
        torch.save(checkpoint, checkpoint_path)
        print(f"Saved checkpoint to {checkpoint_path}")

        # Save best
        if is_best:
            best_path = self.checkpoint_dir / "best_model.pt"
            torch.save(checkpoint, best_path)
            print(f"Saved best model to {best_path}")

        # Save latest (fixed name, overwritten every save)
        latest_path = self.checkpoint_dir / "latest.pt"
        torch.save(checkpoint, latest_path)

    def load_checkpoint(self, checkpoint_path: str):
        """Restore trainer state from a checkpoint file.

        NOTE: weights_only=False unpickles arbitrary objects — only load
        checkpoints from trusted sources.
        """
        checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        self.global_step = checkpoint["step"]
        self.best_eval_loss = checkpoint.get("best_eval_loss", float("inf"))

        if self.scaler and "scaler_state_dict" in checkpoint:
            self.scaler.load_state_dict(checkpoint["scaler_state_dict"])

        print(f"Loaded checkpoint from {checkpoint_path} at step {self.global_step}")

    def _log_losses(self, losses: Dict[str, torch.Tensor], batch_idx: int):
        """Log losses to console (tolerates plain floats as well as tensors)."""
        loss_str = " | ".join([f"{k}: {float(v):.4f}" for k, v in losses.items()])
        print(f"Step {self.global_step} | {loss_str}")

    def train(self):
        """Main training loop; always saves a checkpoint on exit."""
        print("Starting training...")
        print(f"Total steps: {self.config['max_steps']}")
        print(f"Device: {self.device}")
        print(f"Batch size: {self.config['micro_batch_size']}")
        print(f"Gradient accumulation steps: {self.config['gradient_accumulation_steps']}")

        try:
            self.train_epoch()
        except KeyboardInterrupt:
            print("Training interrupted")
        finally:
            self.save_checkpoint()
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
def test_trainer():
    """Smoke-test the trainer end to end on a tiny CPU model."""
    from models.vortex_model import VortexModel
    from tokenizer.vortex_tokenizer import VortexScienceTokenizer
    from configs.vortex_7b_config import VORTEX_7B_CONFIG

    # Shrink the 7B config to something that runs in seconds on CPU.
    config = VORTEX_7B_CONFIG.copy()
    config["d_model"] = 256
    config["num_layers"] = 2
    config["num_heads"] = 4
    config["vocab_size"] = 1000
    config["max_steps"] = 10
    config["device"] = "cpu"

    # Tiny model under test
    model = VortexModel(config)

    # Stub tokenizer: returns fixed-shape random encodings.
    class DummyTokenizer:
        def encode(self, text, add_special_tokens=True, return_tensors="pt"):
            return {
                "input_ids": torch.randint(0, 1000, (1, 10)),
                "attention_mask": torch.ones(1, 10),
            }

    tokenizer = DummyTokenizer()

    # Stub dataset: ten random 32-token causal-LM samples.
    class DummyDataset(torch.utils.data.Dataset):
        def __len__(self):
            return 10

        def __getitem__(self, idx):
            return {
                "input_ids": torch.randint(0, 1000, (32,)),
                "attention_mask": torch.ones(32),
                "labels": torch.randint(0, 1000, (32,)),
                "domain": "physics",
            }

    train_dataset = DummyDataset()
    eval_dataset = DummyDataset()

    # Wire everything together and run a handful of steps.
    trainer = VortexTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        config=config,
        eval_dataset=eval_dataset,
    )
    trainer.train()

    print("Trainer test passed!")
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
# Manual smoke test: run this module directly to exercise the trainer.
if __name__ == "__main__":
    test_trainer()
|
vortex_config.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Vortex-7B model configuration.
|
| 3 |
+
Optimized for 8GB VRAM (4060 laptop) and MacBook Pro M2/M3.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
VORTEX_7B_CONFIG = {
|
| 7 |
+
# Model dimensions
|
| 8 |
+
"d_model": 4096,
|
| 9 |
+
"num_layers": 32,
|
| 10 |
+
"num_heads": 32,
|
| 11 |
+
"head_dim": 128, # d_model // num_heads
|
| 12 |
+
|
| 13 |
+
# State-space layer parameters
|
| 14 |
+
"d_state": 16, # SSM state dimension
|
| 15 |
+
"d_conv": 4, # SSM convolution width
|
| 16 |
+
|
| 17 |
+
# Attention parameters
|
| 18 |
+
"window_size": 512, # Local attention window
|
| 19 |
+
"use_flash_attention": True, # CUDA only
|
| 20 |
+
|
| 21 |
+
# Feed-forward parameters
|
| 22 |
+
"ffn_expansion": 4, # Hidden dim = d_model * expansion
|
| 23 |
+
"num_domains": 7, # Physics, Math, Chemistry, Biology, Earth, Space, Zoology
|
| 24 |
+
|
| 25 |
+
# Tokenizer parameters
|
| 26 |
+
"vocab_size": 50000,
|
| 27 |
+
"max_seq_len": 16384,
|
| 28 |
+
|
| 29 |
+
# Layer ratio: 60% SSM, 40% attention
|
| 30 |
+
"ssm_ratio": 0.6,
|
| 31 |
+
|
| 32 |
+
# Data types
|
| 33 |
+
"dtype": "bfloat16",
|
| 34 |
+
|
| 35 |
+
# Special tokens
|
| 36 |
+
"special_tokens": {
|
| 37 |
+
"[PAD]": 0,
|
| 38 |
+
"[UNK]": 1,
|
| 39 |
+
"[BOS]": 2,
|
| 40 |
+
"[EOS]": 3,
|
| 41 |
+
"[EQUATION]": 4,
|
| 42 |
+
"[/EQUATION]": 5,
|
| 43 |
+
"[CITATION]": 6,
|
| 44 |
+
"[/CITATION]": 7,
|
| 45 |
+
"[MOLECULE]": 8,
|
| 46 |
+
"[/MOLECULE]": 9,
|
| 47 |
+
"[FIGURE]": 10,
|
| 48 |
+
"[TABLE]": 11,
|
| 49 |
+
"[MATH]": 12,
|
| 50 |
+
"[CHEM]": 13,
|
| 51 |
+
"[BIO]": 14,
|
| 52 |
+
"[PHYS]": 15,
|
| 53 |
+
"[EARTH]": 16,
|
| 54 |
+
"[SPACE]": 17,
|
| 55 |
+
"[ZOO]": 18,
|
| 56 |
+
},
|
| 57 |
+
|
| 58 |
+
# Domain tags
|
| 59 |
+
"domain_tags": ["[MATH]", "[CHEM]", "[BIO]", "[PHYS]", "[EARTH]", "[SPACE]", "[ZOO]"],
|
| 60 |
+
|
| 61 |
+
# Science module flags (enable/disable for ablation)
|
| 62 |
+
"enable_equation_module": True,
|
| 63 |
+
"enable_numerical_module": True,
|
| 64 |
+
"enable_citation_module": True,
|
| 65 |
+
"enable_molecular_module": True,
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def get_config():
    """Expose the module-level Vortex-7B configuration dictionary."""
    config = VORTEX_7B_CONFIG
    return config
|