# Vortex-7b-V1 / test_model.py
# Uploaded by Zandy-Wandy ("Upload Vortex model", commit bf64b03, verified)
# NOTE: the four lines above were web-page residue from the model hub listing,
# not Python; they are preserved here as comments so the module parses.
#!/usr/bin/env python3
"""
Comprehensive unit tests for Vortex model components.
Run with: python -m pytest test_model.py -v
"""
import pytest
import torch
import sys
from pathlib import Path
# Add Vortex to path
sys.path.insert(0, str(Path(__file__).parent))
def test_tokenizer():
    """Exercise VortexScienceTokenizer's encode/decode round trip."""
    from tokenizer.vortex_tokenizer import VortexScienceTokenizer
    from configs.vortex_7b_config import VORTEX_7B_CONFIG

    tok = VortexScienceTokenizer(VORTEX_7B_CONFIG)

    # Encode a sentence containing LaTeX math and a chemical formula.
    sample = "The equation is $E = mc^2$ and H2O is water."
    batch = tok.encode(sample, return_tensors="pt")

    # encode() must return a dict-like with a batched input_ids tensor.
    assert "input_ids" in batch
    assert batch["input_ids"].shape[0] == 1  # batch dim
    round_trip = tok.decode(batch["input_ids"][0].tolist())
    assert isinstance(round_trip, str)
    print("✓ Tokenizer test passed")
def test_ssm_layer():
    """Exercise VortexSSM in stateless, stateful, and single-step modes."""
    from models.ssm_layer import VortexSSM

    n_batch, n_seq, width, n_state = 2, 64, 512, 16
    layer = VortexSSM(width, d_state=n_state)
    inputs = torch.randn(n_batch, n_seq, width)

    # Stateless forward is shape-preserving.
    assert layer(inputs).shape == inputs.shape

    # Stateful forward returns (output, updated_state).
    init_state = torch.zeros(n_batch, layer.d_inner, n_state)
    out, next_state = layer(inputs, state=init_state, return_state=True)
    assert out.shape == inputs.shape
    assert next_state.shape == (n_batch, layer.d_inner, n_state)

    # Single-token step consumes a (batch, d_model) slice.
    token = torch.randn(n_batch, width)
    step_out, stepped = layer.step(token, init_state)
    assert step_out.shape == (n_batch, width)
    assert stepped.shape == (n_batch, layer.d_inner, n_state)
    print("✓ SSM layer test passed")
def test_attention_layer():
    """Exercise VortexLocalAttention with and without a global-token mask."""
    from models.attention_layer import VortexLocalAttention

    n_batch, n_seq, width, heads = 2, 128, 512, 8
    layer = VortexLocalAttention(width, heads, window_size=64, use_flash_attention=False)
    inputs = torch.randn(n_batch, n_seq, width)

    # Plain forward is shape-preserving.
    assert layer(inputs).shape == inputs.shape

    # Marking one position as global must not change the output shape.
    is_global = torch.zeros(n_batch, n_seq, dtype=torch.bool)
    is_global[0, 0] = True
    assert layer(inputs, global_mask=is_global).shape == inputs.shape
    print("✓ Local attention test passed")
def test_scigate_ffn():
    """Exercise SciGateFFN unconditioned, with domain IDs, and with domain tags."""
    from models.scigate_ffn import SciGateFFN

    n_batch, n_seq, width, n_domains = 2, 64, 512, 7
    layer = SciGateFFN(width, expansion=4, num_domains=n_domains)
    inputs = torch.randn(n_batch, n_seq, width)

    # No domain conditioning.
    assert layer(inputs).shape == inputs.shape

    # Per-example integer domain labels.
    labels = torch.randint(0, n_domains, (n_batch,))
    assert layer(inputs, domain_ids=labels).shape == inputs.shape

    # Per-token soft domain tags (all mass on domain 0 here).
    tags = torch.zeros(n_batch, n_seq, n_domains)
    tags[:, :, 0] = 1.0
    assert layer(inputs, domain_tags=tags).shape == inputs.shape
    print("✓ SciGate FFN test passed")
def test_equation_module():
    """Exercise EquationModule's forward pass and auxiliary equation loss."""
    from models.science_modules.equation_module import EquationModule

    width, n_batch, n_seq = 512, 2, 64
    eq_mod = EquationModule(width)
    hidden = torch.randn(n_batch, n_seq, width)
    samples = ["E = mc^2 is famous.", "The integral $\\int x dx = x^2/2$."]

    # Text-conditioned forward is shape-preserving.
    assert eq_mod(hidden, text=samples).shape == hidden.shape

    # Auxiliary loss over a span flagged as equation tokens is non-negative.
    eq_mask = torch.zeros(n_batch, n_seq)
    eq_mask[0, 5:10] = 1.0
    assert eq_mod.compute_equation_loss(hidden, eq_mask).item() >= 0
    print("✓ Equation module test passed")
def test_numerical_module():
    """Exercise NumericalReasoningModule's text-conditioned forward."""
    from models.science_modules.numerical_module import NumericalReasoningModule

    width, n_batch, n_seq = 512, 2, 64
    num_mod = NumericalReasoningModule(width)
    hidden = torch.randn(n_batch, n_seq, width)
    # Inputs containing scientific-notation numbers.
    samples = ["Speed of light: 2.998e8 m/s", "6.022e23 is Avogadro's number."]
    assert num_mod(hidden, text=samples).shape == hidden.shape
    print("✓ Numerical reasoning module test passed")
def test_citation_module():
    """Exercise CitationModule's forward (output + confidence) and its loss."""
    from models.science_modules.citation_module import CitationModule

    width, n_batch, n_seq = 512, 2, 64
    cite_mod = CitationModule(width)
    hidden = torch.randn(n_batch, n_seq, width)
    samples = ["(Einstein, 1905) changed physics.", "See also [1, 2] for details."]

    out, conf = cite_mod(hidden, text=samples)
    assert out.shape == hidden.shape
    assert conf.shape == (n_batch, n_seq, 1)  # per-token scalar confidence

    # Loss over a span flagged as citation tokens is non-negative.
    cite_mask = torch.zeros(n_batch, n_seq)
    cite_mask[0, 0:5] = 1.0
    assert cite_mod.compute_citation_loss(hidden, cite_mask, conf).item() >= 0
    print("✓ Citation module test passed")
def test_molecular_module():
    """Exercise MolecularModule's text-conditioned forward."""
    from models.science_modules.molecular_module import MolecularModule

    width, n_batch, n_seq = 512, 2, 64
    mol_mod = MolecularModule(width)
    hidden = torch.randn(n_batch, n_seq, width)
    # Inputs containing a chemical formula and a DNA sequence.
    samples = ["H2O is water.", "DNA sequence: ACGTACGT"]
    assert mol_mod(hidden, text=samples).shape == hidden.shape
    print("✓ Molecular module test passed")
def test_vortex_model():
    """Exercise the full VortexModel end to end using a shrunken config."""
    from models.vortex_model import VortexModel
    from configs.vortex_7b_config import VORTEX_7B_CONFIG

    # Shrink the 7B config so the test builds a tiny, fast model.
    cfg = VORTEX_7B_CONFIG.copy()
    cfg.update(d_model=256, num_layers=4, num_heads=4, vocab_size=1000)

    net = VortexModel(cfg)
    n_batch, n_seq = 2, 32
    tokens = torch.randint(0, cfg["vocab_size"], (n_batch, n_seq))

    # Forward pass produces per-token vocabulary logits.
    logits = net(tokens)["logits"]
    assert logits.shape == (n_batch, n_seq, cfg["vocab_size"])

    # Parameter count is positive for a constructed model.
    num_params = net.get_num_params()
    assert num_params > 0
    print(f"✓ VortexModel test passed (params: {num_params:,})")
def test_quality_filter():
    """Test ScienceQualityFilter accept/reject decisions.

    Fix: the local variable was named ``filter``, shadowing the builtin
    ``filter()``; renamed to ``quality_filter``.
    """
    from data.quality_filter import ScienceQualityFilter

    quality_filter = ScienceQualityFilter()  # renamed: don't shadow builtin filter()
    # Good text: long enough, with statistics, an equation, and a citation.
    good_text = """
    The experiment collected data from 100 participants. Results show a
    significant effect (p < 0.05). The equation E = mc^2 is fundamental.
    According to Smith et al., this confirms the hypothesis.
    """
    assert quality_filter.filter(good_text)
    # Bad: too short
    assert not quality_filter.filter("Too short.")
    # Bad: unmatched $ equation delimiters
    bad_eq = "Equation $E = mc^2 and another $F = ma."
    assert not quality_filter.filter(bad_eq)
    print("✓ Quality filter test passed")
def test_domain_classifier():
    """Exercise DomainClassifier on hidden states and on raw text."""
    from data.domain_classifier import DomainClassifier

    width = 256
    clf = DomainClassifier(width)

    # Hidden-state path: one logit per domain (7 domains).
    n_batch, n_seq = 4, 32
    states = torch.randn(n_batch, n_seq, width)
    assert clf(states).shape == (n_batch, 7)

    # Text path: a domain index in [0, 7) plus a confidence in [0, 1].
    sample = "Quantum mechanics describes particle behavior."
    domain_idx, confidence = clf.classify_text(sample)
    assert domain_idx in range(7)
    assert 0 <= confidence <= 1
    print("✓ Domain classifier test passed")
def test_deduplication():
    """Exercise MinHashLSH near-duplicate retrieval."""
    from data.deduplication import MinHashLSH

    index = MinHashLSH(num_permutations=32, threshold=0.7, bands=4, rows_per_band=8)
    corpus = [
        ("doc1", "The quick brown fox jumps over the lazy dog."),
        ("doc2", "The quick brown fox jumps over the lazy dog!!!"),
        ("doc3", "Completely different text about science."),
    ]
    for name, body in corpus:
        index.add_document(name, body)

    # Querying doc1's text should surface the near-duplicate doc2.
    hits = index.query(corpus[0][1])
    assert len(hits) >= 1
    assert any(hit[0] == "doc2" for hit in hits)
    print("✓ Deduplication test passed")
def test_losses():
    """Exercise VortexLoss on random logits and labels."""
    from training.losses import VortexLoss

    weights = {
        "lm_loss": 1.0,
        "equation_loss": 0.3,
        "domain_loss": 0.1,
        "citation_loss": 0.1,
        "numerical_loss": 0.2,
    }
    criterion = VortexLoss({"loss_weights": weights})

    n_batch, n_seq, n_vocab = 2, 32, 1000
    fake_logits = torch.randn(n_batch, n_seq, n_vocab)
    fake_labels = torch.randint(0, n_vocab, (n_batch, n_seq))

    parts = criterion(fake_logits, fake_labels)
    assert "total_loss" in parts
    assert "lm_loss" in parts
    # Cross-entropy against random logits is strictly positive.
    assert parts["total_loss"].item() > 0
    print("✓ Losses test passed")
def test_curriculum():
    """Test CurriculumScheduler stage lookup and dataset sampler weights.

    Fix: the sampler check used exact float equality
    (``sum(weights.values()) == 1.0``); normalized float weights need not
    sum to exactly 1.0, so compare within a small tolerance instead.
    """
    from training.curriculum import CurriculumScheduler

    config = {
        "curriculum_stages": [
            {"name": "foundation", "start": 0.0, "end": 0.2},
            {"name": "domain", "start": 0.2, "end": 0.5},
            {"name": "reasoning", "start": 0.5, "end": 0.8},
            {"name": "integration", "start": 0.8, "end": 1.0},
        ]
    }
    total_steps = 1000
    scheduler = CurriculumScheduler(config, total_steps)
    # Step -> stage-name mapping at a representative step in each stage.
    assert scheduler.get_stage_name(0) == "foundation"
    assert scheduler.get_stage_name(250) == "domain"
    assert scheduler.get_stage_name(500) == "reasoning"
    assert scheduler.get_stage_name(800) == "integration"
    # Sampler weights form a (near-)normalized distribution over datasets.
    weights = scheduler.get_dataset_sampler(100)
    assert isinstance(weights, dict)
    assert abs(sum(weights.values()) - 1.0) < 1e-6  # tolerance, not == on floats
    print("✓ Curriculum test passed")
def test_hf_integration():
    """Test HuggingFace config/model save-and-reload round trip.

    Fix: the scratch directory was only removed on the success path, so any
    failing assertion leaked ``./test_hf_model`` on disk. Cleanup now runs
    in a ``finally`` block regardless of outcome.
    """
    import shutil

    from configuration_vortex import VortexConfig
    from modeling_vortex import VortexForCausalLM
    from tokenization_vortex import VortexTokenizer
    from transformers import AutoConfig, AutoModelForCausalLM

    # Tiny config so the model builds quickly.
    config = VortexConfig(
        d_model=128,
        num_layers=2,
        num_heads=4,
        vocab_size=100,
    )
    model = VortexForCausalLM(config)
    batch_size = 2
    seq_len = 16
    input_ids = torch.randint(0, 100, (batch_size, seq_len))
    outputs = model(input_ids)
    assert outputs.logits.shape == (batch_size, seq_len, 100)

    # Save then reload through the Auto* classes; clean up even on failure.
    save_dir = "./test_hf_model"
    model.save_pretrained(save_dir)
    config.save_pretrained(save_dir)
    try:
        loaded_config = AutoConfig.from_pretrained(save_dir)
        loaded_model = AutoModelForCausalLM.from_pretrained(save_dir)
        assert loaded_config.model_type == "vortex"
        assert isinstance(loaded_model, VortexForCausalLM)
    finally:
        shutil.rmtree(save_dir, ignore_errors=True)
    print("✓ HuggingFace integration test passed")
def run_all_tests():
    """Run every test function and print a pass/fail summary.

    Returns True when all tests passed, False otherwise.
    """
    import traceback  # hoisted out of the failure branch; behavior unchanged

    suite = [
        test_tokenizer,
        test_ssm_layer,
        test_attention_layer,
        test_scigate_ffn,
        test_equation_module,
        test_numerical_module,
        test_citation_module,
        test_molecular_module,
        test_vortex_model,
        test_quality_filter,
        test_domain_classifier,
        test_deduplication,
        test_losses,
        test_curriculum,
        test_hf_integration,
    ]
    print("Running Vortex unit tests...\n")
    n_pass = 0
    n_fail = 0
    for case in suite:
        try:
            case()
        except Exception as err:
            print(f"✗ {case.__name__} failed: {err}")
            n_fail += 1
            traceback.print_exc()
        else:
            n_pass += 1
    print(f"\n{'='*50}")
    print(f"Tests: {n_pass + n_fail} total, {n_pass} passed, {n_fail} failed")
    print(f"{'='*50}")
    return n_fail == 0
if __name__ == "__main__":
    # Script entry point: exit code 0 when every test passed, 1 otherwise
    # (CI-friendly).
    success = run_all_tests()
    sys.exit(0 if success else 1)