#!/usr/bin/env python3
"""
Comprehensive unit tests for Vortex model components.

Run with: python -m pytest test_model.py -v
"""
import math
import shutil
import sys
import traceback
from pathlib import Path

import pytest
import torch

# Add Vortex to path so the project-local packages resolve when run directly.
sys.path.insert(0, str(Path(__file__).parent))


def test_tokenizer():
    """Test VortexScienceTokenizer."""
    from tokenizer.vortex_tokenizer import VortexScienceTokenizer
    from configs.vortex_7b_config import VORTEX_7B_CONFIG

    tokenizer = VortexScienceTokenizer(VORTEX_7B_CONFIG)

    # Test encoding/decoding round trip on mixed prose/math/chemistry text.
    text = "The equation is $E = mc^2$ and H2O is water."
    encoded = tokenizer.encode(text, return_tensors="pt")
    assert "input_ids" in encoded
    assert encoded["input_ids"].shape[0] == 1  # batch dim

    decoded = tokenizer.decode(encoded["input_ids"][0].tolist())
    assert isinstance(decoded, str)
    print("✓ Tokenizer test passed")


def test_ssm_layer():
    """Test VortexSSM."""
    from models.ssm_layer import VortexSSM

    batch_size = 2
    seq_len = 64
    d_model = 512
    d_state = 16

    ssm = VortexSSM(d_model, d_state=d_state)
    x = torch.randn(batch_size, seq_len, d_model)

    # Forward pass: shape must be preserved.
    output = ssm(x)
    assert output.shape == x.shape

    # Stateful forward: explicitly threaded recurrent state.
    state = torch.zeros(batch_size, ssm.d_inner, d_state)
    output2, new_state = ssm(x, state=state, return_state=True)
    assert output2.shape == x.shape
    assert new_state.shape == (batch_size, ssm.d_inner, d_state)

    # Single decoding step (no sequence dimension).
    x_step = torch.randn(batch_size, d_model)
    output_step, state_step = ssm.step(x_step, state)
    assert output_step.shape == (batch_size, d_model)
    assert state_step.shape == (batch_size, ssm.d_inner, d_state)
    print("✓ SSM layer test passed")


def test_attention_layer():
    """Test VortexLocalAttention."""
    from models.attention_layer import VortexLocalAttention

    batch_size = 2
    seq_len = 128
    d_model = 512
    num_heads = 8

    # Flash attention disabled so the test runs on CPU-only hosts.
    attn = VortexLocalAttention(d_model, num_heads, window_size=64, use_flash_attention=False)
    x = torch.randn(batch_size, seq_len, d_model)

    # Forward pass
    output = attn(x)
    assert output.shape == x.shape

    # With a global-attention mask marking one token as globally attended.
    global_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
    global_mask[0, 0] = True
    output2 = attn(x, global_mask=global_mask)
    assert output2.shape == x.shape
    print("✓ Local attention test passed")


def test_scigate_ffn():
    """Test SciGateFFN."""
    from models.scigate_ffn import SciGateFFN

    batch_size = 2
    seq_len = 64
    d_model = 512
    num_domains = 7

    ffn = SciGateFFN(d_model, expansion=4, num_domains=num_domains)
    x = torch.randn(batch_size, seq_len, d_model)

    # Without domain info
    output = ffn(x)
    assert output.shape == x.shape

    # With per-example domain IDs
    domain_ids = torch.randint(0, num_domains, (batch_size,))
    output2 = ffn(x, domain_ids=domain_ids)
    assert output2.shape == x.shape

    # With per-token soft domain tags (one-hot on domain 0 here).
    domain_tags = torch.zeros(batch_size, seq_len, num_domains)
    domain_tags[:, :, 0] = 1.0
    output3 = ffn(x, domain_tags=domain_tags)
    assert output3.shape == x.shape
    print("✓ SciGate FFN test passed")


def test_equation_module():
    """Test EquationModule."""
    from models.science_modules.equation_module import EquationModule

    d_model = 512
    batch_size = 2
    seq_len = 64

    module = EquationModule(d_model)
    x = torch.randn(batch_size, seq_len, d_model)
    text = ["E = mc^2 is famous.", "The integral $\\int x dx = x^2/2$."]

    output = module(x, text=text)
    assert output.shape == x.shape

    # Test equation loss on a synthetic mask covering a few positions.
    equation_mask = torch.zeros(batch_size, seq_len)
    equation_mask[0, 5:10] = 1.0
    loss = module.compute_equation_loss(x, equation_mask)
    assert loss.item() >= 0
    print("✓ Equation module test passed")


def test_numerical_module():
    """Test NumericalReasoningModule."""
    from models.science_modules.numerical_module import NumericalReasoningModule

    d_model = 512
    batch_size = 2
    seq_len = 64

    module = NumericalReasoningModule(d_model)
    x = torch.randn(batch_size, seq_len, d_model)
    text = ["Speed of light: 2.998e8 m/s", "6.022e23 is Avogadro's number."]

    output = module(x, text=text)
    assert output.shape == x.shape
    print("✓ Numerical reasoning module test passed")


def test_citation_module():
    """Test CitationModule."""
    from models.science_modules.citation_module import CitationModule

    d_model = 512
    batch_size = 2
    seq_len = 64

    module = CitationModule(d_model)
    x = torch.randn(batch_size, seq_len, d_model)
    text = ["(Einstein, 1905) changed physics.", "See also [1, 2] for details."]

    output, confidence = module(x, text=text)
    assert output.shape == x.shape
    assert confidence.shape == (batch_size, seq_len, 1)

    # Test loss with a synthetic citation-span mask.
    citation_mask = torch.zeros(batch_size, seq_len)
    citation_mask[0, 0:5] = 1.0
    loss = module.compute_citation_loss(x, citation_mask, confidence)
    assert loss.item() >= 0
    print("✓ Citation module test passed")


def test_molecular_module():
    """Test MolecularModule."""
    from models.science_modules.molecular_module import MolecularModule

    d_model = 512
    batch_size = 2
    seq_len = 64

    module = MolecularModule(d_model)
    x = torch.randn(batch_size, seq_len, d_model)
    text = ["H2O is water.", "DNA sequence: ACGTACGT"]

    output = module(x, text=text)
    assert output.shape == x.shape
    print("✓ Molecular module test passed")


def test_vortex_model():
    """Test full VortexModel."""
    from models.vortex_model import VortexModel
    from configs.vortex_7b_config import VORTEX_7B_CONFIG

    # Shrink the 7B config so the test builds a tiny model quickly.
    config = VORTEX_7B_CONFIG.copy()
    config["d_model"] = 256
    config["num_layers"] = 4
    config["num_heads"] = 4
    config["vocab_size"] = 1000

    model = VortexModel(config)

    batch_size = 2
    seq_len = 32
    input_ids = torch.randint(0, config["vocab_size"], (batch_size, seq_len))

    # Forward pass
    output = model(input_ids)
    logits = output["logits"]
    assert logits.shape == (batch_size, seq_len, config["vocab_size"])

    # Count parameters
    num_params = model.get_num_params()
    assert num_params > 0
    print(f"✓ VortexModel test passed (params: {num_params:,})")


def test_quality_filter():
    """Test ScienceQualityFilter."""
    from data.quality_filter import ScienceQualityFilter

    # NOTE: named quality_filter (not `filter`) to avoid shadowing the builtin.
    quality_filter = ScienceQualityFilter()

    # Good text
    good_text = """
    The experiment collected data from 100 participants.
    Results show a significant effect (p < 0.05).
    The equation E = mc^2 is fundamental.
    According to Smith et al., this confirms the hypothesis.
    """
    assert quality_filter.filter(good_text)

    # Bad: too short
    assert not quality_filter.filter("Too short.")

    # Bad: unmatched equations
    bad_eq = "Equation $E = mc^2 and another $F = ma."
    assert not quality_filter.filter(bad_eq)
    print("✓ Quality filter test passed")


def test_domain_classifier():
    """Test DomainClassifier."""
    from data.domain_classifier import DomainClassifier

    d_model = 256
    classifier = DomainClassifier(d_model)

    # Test with random hidden states
    batch_size = 4
    seq_len = 32
    hidden = torch.randn(batch_size, seq_len, d_model)
    logits = classifier(hidden)
    assert logits.shape == (batch_size, 7)

    # Test text classification
    text = "Quantum mechanics describes particle behavior."
    domain, conf = classifier.classify_text(text)
    assert domain in range(7)
    assert 0 <= conf <= 1
    print("✓ Domain classifier test passed")


def test_deduplication():
    """Test MinHashLSH."""
    from data.deduplication import MinHashLSH

    lsh = MinHashLSH(num_permutations=32, threshold=0.7, bands=4, rows_per_band=8)

    docs = [
        ("doc1", "The quick brown fox jumps over the lazy dog."),
        ("doc2", "The quick brown fox jumps over the lazy dog!!!"),
        ("doc3", "Completely different text about science."),
    ]
    for doc_id, text in docs:
        lsh.add_document(doc_id, text)

    # Query similar
    results = lsh.query(docs[0][1])
    # Should find doc2 as similar
    assert len(results) >= 1
    assert any(r[0] == "doc2" for r in results)
    print("✓ Deduplication test passed")


def test_losses():
    """Test VortexLoss."""
    from training.losses import VortexLoss

    config = {"loss_weights": {
        "lm_loss": 1.0,
        "equation_loss": 0.3,
        "domain_loss": 0.1,
        "citation_loss": 0.1,
        "numerical_loss": 0.2,
    }}
    loss_fn = VortexLoss(config)

    batch_size = 2
    seq_len = 32
    vocab_size = 1000
    logits = torch.randn(batch_size, seq_len, vocab_size)
    labels = torch.randint(0, vocab_size, (batch_size, seq_len))

    losses = loss_fn(logits, labels)
    assert "total_loss" in losses
    assert "lm_loss" in losses
    assert losses["total_loss"].item() > 0
    print("✓ Losses test passed")


def test_curriculum():
    """Test CurriculumScheduler."""
    from training.curriculum import CurriculumScheduler

    config = {
        "curriculum_stages": [
            {"name": "foundation", "start": 0.0, "end": 0.2},
            {"name": "domain", "start": 0.2, "end": 0.5},
            {"name": "reasoning", "start": 0.5, "end": 0.8},
            {"name": "integration", "start": 0.8, "end": 1.0},
        ]
    }
    total_steps = 1000
    scheduler = CurriculumScheduler(config, total_steps)

    # Test stage at different steps
    assert scheduler.get_stage_name(0) == "foundation"
    assert scheduler.get_stage_name(250) == "domain"
    assert scheduler.get_stage_name(500) == "reasoning"
    assert scheduler.get_stage_name(800) == "integration"

    # Test sampler: weights must form a probability distribution.
    # Use isclose, not ==, since a float sum need not be exactly 1.0.
    weights = scheduler.get_dataset_sampler(100)
    assert isinstance(weights, dict)
    assert math.isclose(sum(weights.values()), 1.0, rel_tol=1e-9, abs_tol=1e-9)
    print("✓ Curriculum test passed")


def test_hf_integration():
    """Test HuggingFace integration."""
    from configuration_vortex import VortexConfig
    from modeling_vortex import VortexForCausalLM
    from tokenization_vortex import VortexTokenizer

    # Config
    config = VortexConfig(
        d_model=128,
        num_layers=2,
        num_heads=4,
        vocab_size=100,
    )

    # Model
    model = VortexForCausalLM(config)

    batch_size = 2
    seq_len = 16
    input_ids = torch.randint(0, 100, (batch_size, seq_len))
    outputs = model(input_ids)
    assert outputs.logits.shape == (batch_size, seq_len, 100)

    # Save and load. try/finally guarantees the scratch directory is
    # removed even when an assertion (or save/load) fails mid-test.
    save_dir = "./test_hf_model"
    try:
        model.save_pretrained(save_dir)
        config.save_pretrained(save_dir)

        from transformers import AutoConfig, AutoModelForCausalLM
        loaded_config = AutoConfig.from_pretrained(save_dir)
        loaded_model = AutoModelForCausalLM.from_pretrained(save_dir)
        assert loaded_config.model_type == "vortex"
        assert isinstance(loaded_model, VortexForCausalLM)
    finally:
        # Cleanup
        shutil.rmtree(save_dir, ignore_errors=True)
    print("✓ HuggingFace integration test passed")


def run_all_tests():
    """Run all tests.

    Returns True when every test passed, False otherwise.
    """
    tests = [
        test_tokenizer,
        test_ssm_layer,
        test_attention_layer,
        test_scigate_ffn,
        test_equation_module,
        test_numerical_module,
        test_citation_module,
        test_molecular_module,
        test_vortex_model,
        test_quality_filter,
        test_domain_classifier,
        test_deduplication,
        test_losses,
        test_curriculum,
        test_hf_integration,
    ]

    print("Running Vortex unit tests...\n")
    passed = 0
    failed = 0
    for test in tests:
        try:
            test()
            passed += 1
        except Exception as e:
            print(f"✗ {test.__name__} failed: {e}")
            failed += 1
            traceback.print_exc()

    print(f"\n{'='*50}")
    print(f"Tests: {passed + failed} total, {passed} passed, {failed} failed")
    print(f"{'='*50}")
    return failed == 0


if __name__ == "__main__":
    success = run_all_tests()
    sys.exit(0 if success else 1)