| |
|
| | """
|
| | Comprehensive unit tests for Vortex model components.
|
| | Run with: python -m pytest test_model.py -v
|
| | """
|
| |
|
| | import pytest
|
| | import torch
|
| | import sys
|
| | from pathlib import Path
|
| |
|
| |
|
| | sys.path.insert(0, str(Path(__file__).parent))
|
| |
|
| |
|
def test_tokenizer():
    """Test VortexScienceTokenizer: encode a math/chemistry sentence and decode it back."""
    from tokenizer.vortex_tokenizer import VortexScienceTokenizer
    from configs.vortex_7b_config import VORTEX_7B_CONFIG

    tokenizer = VortexScienceTokenizer(VORTEX_7B_CONFIG)

    sample = "The equation is $E = mc^2$ and H2O is water."
    batch = tokenizer.encode(sample, return_tensors="pt")
    assert "input_ids" in batch
    assert batch["input_ids"].shape[0] == 1

    # Round-trip: decoding the first row of ids must yield a string.
    round_trip = tokenizer.decode(batch["input_ids"][0].tolist())
    assert isinstance(round_trip, str)
    print("✓ Tokenizer test passed")
|
| |
|
| |
|
def test_ssm_layer():
    """Test VortexSSM: stateless forward, stateful forward, and single-token step."""
    from models.ssm_layer import VortexSSM

    n_batch, n_seq, width, n_state = 2, 64, 512, 16

    ssm = VortexSSM(width, d_state=n_state)
    inputs = torch.randn(n_batch, n_seq, width)

    # Stateless forward keeps the input shape.
    assert ssm(inputs).shape == inputs.shape

    # Forward with an explicit recurrent state returns (output, new_state).
    init_state = torch.zeros(n_batch, ssm.d_inner, n_state)
    out_seq, next_state = ssm(inputs, state=init_state, return_state=True)
    assert out_seq.shape == inputs.shape
    assert next_state.shape == (n_batch, ssm.d_inner, n_state)

    # Single-token decode step consumes a (batch, d_model) slice.
    token = torch.randn(n_batch, width)
    out_tok, tok_state = ssm.step(token, init_state)
    assert out_tok.shape == (n_batch, width)
    assert tok_state.shape == (n_batch, ssm.d_inner, n_state)

    print("✓ SSM layer test passed")
|
| |
|
| |
|
def test_attention_layer():
    """Test VortexLocalAttention with and without a global-token mask."""
    from models.attention_layer import VortexLocalAttention

    n_batch, n_seq, width, heads = 2, 128, 512, 8

    attn = VortexLocalAttention(width, heads, window_size=64, use_flash_attention=False)
    hidden = torch.randn(n_batch, n_seq, width)

    # Plain windowed attention preserves the input shape.
    assert attn(hidden).shape == hidden.shape

    # Promote one position to a global token and check the shape again.
    is_global = torch.zeros(n_batch, n_seq, dtype=torch.bool)
    is_global[0, 0] = True
    assert attn(hidden, global_mask=is_global).shape == hidden.shape

    print("✓ Local attention test passed")
|
| |
|
| |
|
def test_scigate_ffn():
    """Test SciGateFFN under its three gating modes: none, per-batch ids, per-token tags."""
    from models.scigate_ffn import SciGateFFN

    n_batch, n_seq, width = 2, 64, 512
    n_domains = 7

    ffn = SciGateFFN(width, expansion=4, num_domains=n_domains)
    hidden = torch.randn(n_batch, n_seq, width)

    # No domain conditioning.
    assert ffn(hidden).shape == hidden.shape

    # One domain id per batch element.
    ids = torch.randint(0, n_domains, (n_batch,))
    assert ffn(hidden, domain_ids=ids).shape == hidden.shape

    # Soft per-token domain tags (all mass on domain 0 here).
    tags = torch.zeros(n_batch, n_seq, n_domains)
    tags[:, :, 0] = 1.0
    assert ffn(hidden, domain_tags=tags).shape == hidden.shape

    print("✓ SciGate FFN test passed")
|
| |
|
| |
|
def test_equation_module():
    """Test EquationModule: forward pass on equation-bearing text and the auxiliary loss."""
    from models.science_modules.equation_module import EquationModule

    width, n_batch, n_seq = 512, 2, 64

    module = EquationModule(width)
    hidden = torch.randn(n_batch, n_seq, width)
    text = ["E = mc^2 is famous.", "The integral $\\int x dx = x^2/2$."]

    assert module(hidden, text=text).shape == hidden.shape

    # Auxiliary loss over a small equation span must be non-negative.
    equation_mask = torch.zeros(n_batch, n_seq)
    equation_mask[0, 5:10] = 1.0
    loss = module.compute_equation_loss(hidden, equation_mask)
    assert loss.item() >= 0

    print("✓ Equation module test passed")
|
| |
|
| |
|
def test_numerical_module():
    """Test NumericalReasoningModule forward pass on number-dense text."""
    from models.science_modules.numerical_module import NumericalReasoningModule

    width, n_batch, n_seq = 512, 2, 64

    module = NumericalReasoningModule(width)
    hidden = torch.randn(n_batch, n_seq, width)
    text = ["Speed of light: 2.998e8 m/s", "6.022e23 is Avogadro's number."]

    # The module is residual-style: output shape matches input shape.
    assert module(hidden, text=text).shape == hidden.shape

    print("✓ Numerical reasoning module test passed")
|
| |
|
| |
|
def test_citation_module():
    """Test CitationModule: forward (output + confidence) and the citation loss."""
    from models.science_modules.citation_module import CitationModule

    width, n_batch, n_seq = 512, 2, 64

    module = CitationModule(width)
    hidden = torch.randn(n_batch, n_seq, width)
    text = ["(Einstein, 1905) changed physics.", "See also [1, 2] for details."]

    out, confidence = module(hidden, text=text)
    assert out.shape == hidden.shape
    # Per-token scalar confidence.
    assert confidence.shape == (n_batch, n_seq, 1)

    # Loss over a marked citation span must be non-negative.
    citation_mask = torch.zeros(n_batch, n_seq)
    citation_mask[0, 0:5] = 1.0
    loss = module.compute_citation_loss(hidden, citation_mask, confidence)
    assert loss.item() >= 0

    print("✓ Citation module test passed")
|
| |
|
| |
|
def test_molecular_module():
    """Test MolecularModule forward pass on formula/sequence text."""
    from models.science_modules.molecular_module import MolecularModule

    width, n_batch, n_seq = 512, 2, 64

    module = MolecularModule(width)
    hidden = torch.randn(n_batch, n_seq, width)
    text = ["H2O is water.", "DNA sequence: ACGTACGT"]

    # Residual-style module: output shape matches input shape.
    assert module(hidden, text=text).shape == hidden.shape

    print("✓ Molecular module test passed")
|
| |
|
| |
|
def test_vortex_model():
    """Test the full VortexModel with a shrunken copy of the 7B config."""
    from models.vortex_model import VortexModel
    from configs.vortex_7b_config import VORTEX_7B_CONFIG

    # Shrink the reference config so the test builds a tiny model.
    config = VORTEX_7B_CONFIG.copy()
    config.update(d_model=256, num_layers=4, num_heads=4, vocab_size=1000)

    model = VortexModel(config)

    n_batch, n_seq = 2, 32
    input_ids = torch.randint(0, config["vocab_size"], (n_batch, n_seq))

    logits = model(input_ids)["logits"]
    assert logits.shape == (n_batch, n_seq, config["vocab_size"])

    # Sanity-check the parameter counter.
    num_params = model.get_num_params()
    assert num_params > 0

    print(f"✓ VortexModel test passed (params: {num_params:,})")
|
| |
|
| |
|
def test_quality_filter():
    """Test ScienceQualityFilter: accepts good science prose, rejects short and
    malformed-equation text.

    Fix: the local variable was named ``filter``, shadowing the builtin; renamed
    to ``quality_filter``.
    """
    from data.quality_filter import ScienceQualityFilter

    quality_filter = ScienceQualityFilter()

    # Substantial scientific prose with stats, an equation, and a citation.
    good_text = """
    The experiment collected data from 100 participants. Results show a
    significant effect (p < 0.05). The equation E = mc^2 is fundamental.
    According to Smith et al., this confirms the hypothesis.
    """
    assert quality_filter.filter(good_text)

    # Too short to be a useful document.
    assert not quality_filter.filter("Too short.")

    # Unbalanced $ delimiters should be rejected.
    bad_eq = "Equation $E = mc^2 and another $F = ma."
    assert not quality_filter.filter(bad_eq)

    print("✓ Quality filter test passed")
|
| |
|
| |
|
def test_domain_classifier():
    """Test DomainClassifier: hidden-state logits shape and text classification."""
    from data.domain_classifier import DomainClassifier

    width = 256
    classifier = DomainClassifier(width)

    # Hidden states -> one 7-way logit vector per batch element.
    n_batch, n_seq = 4, 32
    hidden = torch.randn(n_batch, n_seq, width)
    logits = classifier(hidden)
    assert logits.shape == (n_batch, 7)

    # Raw-text path returns a valid domain index and a probability.
    sample = "Quantum mechanics describes particle behavior."
    domain, conf = classifier.classify_text(sample)
    assert domain in range(7)
    assert 0 <= conf <= 1

    print("✓ Domain classifier test passed")
|
| |
|
| |
|
def test_deduplication():
    """Test MinHashLSH: near-duplicate documents are found, distinct ones are not required."""
    from data.deduplication import MinHashLSH

    lsh = MinHashLSH(num_permutations=32, threshold=0.7, bands=4, rows_per_band=8)

    corpus = [
        ("doc1", "The quick brown fox jumps over the lazy dog."),
        ("doc2", "The quick brown fox jumps over the lazy dog!!!"),
        ("doc3", "Completely different text about science."),
    ]
    for doc_id, body in corpus:
        lsh.add_document(doc_id, body)

    # Querying doc1's text must surface its near-duplicate doc2.
    matches = lsh.query(corpus[0][1])
    assert len(matches) >= 1
    assert any(match[0] == "doc2" for match in matches)

    print("✓ Deduplication test passed")
|
| |
|
| |
|
def test_losses():
    """Test VortexLoss: LM loss on random logits/labels produces a positive total."""
    from training.losses import VortexLoss

    weights = {
        "lm_loss": 1.0,
        "equation_loss": 0.3,
        "domain_loss": 0.1,
        "citation_loss": 0.1,
        "numerical_loss": 0.2,
    }
    loss_fn = VortexLoss({"loss_weights": weights})

    n_batch, n_seq, n_vocab = 2, 32, 1000
    logits = torch.randn(n_batch, n_seq, n_vocab)
    labels = torch.randint(0, n_vocab, (n_batch, n_seq))

    losses = loss_fn(logits, labels)
    assert "total_loss" in losses
    assert "lm_loss" in losses
    # Cross-entropy on random logits is strictly positive.
    assert losses["total_loss"].item() > 0

    print("✓ Losses test passed")
|
| |
|
| |
|
def test_curriculum():
    """Test CurriculumScheduler: stage lookup by step and sampler weight normalization.

    Fix: the sampler-weight check used exact float equality
    (``sum(...) == 1.0``), which is fragile when the scheduler normalizes with
    floating-point division; compare with ``math.isclose`` instead.
    """
    import math

    from training.curriculum import CurriculumScheduler

    config = {
        "curriculum_stages": [
            {"name": "foundation", "start": 0.0, "end": 0.2},
            {"name": "domain", "start": 0.2, "end": 0.5},
            {"name": "reasoning", "start": 0.5, "end": 0.8},
            {"name": "integration", "start": 0.8, "end": 1.0},
        ]
    }

    total_steps = 1000
    scheduler = CurriculumScheduler(config, total_steps)

    # Step -> stage boundaries (fractions of total_steps).
    assert scheduler.get_stage_name(0) == "foundation"
    assert scheduler.get_stage_name(250) == "domain"
    assert scheduler.get_stage_name(500) == "reasoning"
    assert scheduler.get_stage_name(800) == "integration"

    # Sampler weights form a probability distribution over datasets.
    weights = scheduler.get_dataset_sampler(100)
    assert isinstance(weights, dict)
    assert math.isclose(sum(weights.values()), 1.0, rel_tol=0.0, abs_tol=1e-9)

    print("✓ Curriculum test passed")
|
| |
|
| |
|
def test_hf_integration():
    """Test HuggingFace integration: forward pass plus save/load round-trip.

    Fix: the original saved into "./test_hf_model" and only removed it after
    all assertions passed, leaking the directory (and polluting later runs) on
    any failure. Use a TemporaryDirectory so cleanup always happens.
    """
    import tempfile

    from configuration_vortex import VortexConfig
    from modeling_vortex import VortexForCausalLM
    from tokenization_vortex import VortexTokenizer  # noqa: F401 (registers tokenizer)

    config = VortexConfig(
        d_model=128,
        num_layers=2,
        num_heads=4,
        vocab_size=100,
    )

    model = VortexForCausalLM(config)
    batch_size, seq_len = 2, 16
    input_ids = torch.randint(0, 100, (batch_size, seq_len))

    outputs = model(input_ids)
    assert outputs.logits.shape == (batch_size, seq_len, 100)

    # Save/load round-trip; the temp dir is removed even if an assert fails.
    with tempfile.TemporaryDirectory() as save_dir:
        model.save_pretrained(save_dir)
        config.save_pretrained(save_dir)

        from transformers import AutoConfig, AutoModelForCausalLM
        loaded_config = AutoConfig.from_pretrained(save_dir)
        loaded_model = AutoModelForCausalLM.from_pretrained(save_dir)

        assert loaded_config.model_type == "vortex"
        assert isinstance(loaded_model, VortexForCausalLM)

    print("✓ HuggingFace integration test passed")
|
| |
|
| |
|
def run_all_tests():
    """Run every test function in sequence and report a pass/fail summary.

    Returns:
        bool: True when every test passed, False otherwise.

    Fix: ``import traceback`` was executed inside the except handler on every
    failure; hoisted to the top of the function.
    """
    import traceback

    tests = [
        test_tokenizer,
        test_ssm_layer,
        test_attention_layer,
        test_scigate_ffn,
        test_equation_module,
        test_numerical_module,
        test_citation_module,
        test_molecular_module,
        test_vortex_model,
        test_quality_filter,
        test_domain_classifier,
        test_deduplication,
        test_losses,
        test_curriculum,
        test_hf_integration,
    ]

    print("Running Vortex unit tests...\n")
    passed = 0
    failed = 0

    for test in tests:
        try:
            test()
            passed += 1
        except Exception as e:
            # Report the failure but keep running the remaining tests.
            print(f"✗ {test.__name__} failed: {e}")
            failed += 1
            traceback.print_exc()

    print(f"\n{'='*50}")
    print(f"Tests: {passed + failed} total, {passed} passed, {failed} failed")
    print(f"{'='*50}")

    return failed == 0
|
| |
|
| |
|
if __name__ == "__main__":
    # Exit code 0 on an all-green run, 1 otherwise.
    sys.exit(0 if run_all_tests() else 1)
|
| |
|