# fin-moe-latent-encoder / test_model.py
# nahommohan's picture
# Upload test_model.py
# 0232e6f verified
"""
End-to-end test for the Fin-MoE Latent Encoder.
Tests:
1. Model instantiation and parameter counting
2. Forward pass with synthetic data
3. Loss computation and backward pass
4. Gradient flow through all components
5. Individual component tests (MLOFI, Active Depth, Sentiment, FMHCA, MoE, Hybrid Loss)
6. Inference mode (get_latent)
7. Complete training step (forward, backward, gradient clipping, optimizer step)
"""
import torch
import sys
import traceback
def test_model_instantiation():
    """Build the encoder with its default config and print parameter counts.

    Returns:
        The newly constructed ``FinMoELatentEncoder`` so that later tests can
        reuse it.
    """
    rule = "=" * 60
    print(rule, "TEST 1: Model Instantiation", rule, sep="\n")

    from fin_moe.model import FinMoELatentEncoder

    encoder = FinMoELatentEncoder()
    param_counts = encoder.count_parameters()
    print("Model created successfully!")
    print("\nParameter counts:")
    for part, n_params in param_counts.items():
        print(f" {part:20s}: {n_params:>10,d}")

    # The "total" entry is reported a second time, scaled to millions.
    n_total = param_counts["total"]
    in_millions = n_total / 1e6
    print(f"\n Total: {n_total:,d} ({in_millions:.1f}M parameters)")
    print(f" Config: {encoder.config}")
    return encoder
def test_forward_pass(model):
    """Run one training-mode forward pass on a synthetic batch and print a report.

    Args:
        model: Callable model that accepts the synthetic batch as keyword
            arguments and returns a mapping containing ``"loss"`` and
            ``"loss_dict"``.

    Returns:
        The mapping returned by ``model(**batch)``; its ``"loss"`` entry is
        still attached to the autograd graph for the backward-pass test.
    """
    rule = "=" * 60
    print("\n" + rule, "TEST 2: Forward Pass", rule, sep="\n")

    from fin_moe.data_pipeline import generate_synthetic_batch

    batch = generate_synthetic_batch(batch_size=4, micro_seq_len=50, macro_seq_len=16)
    print("Input shapes:")
    for key, value in batch.items():
        if not isinstance(value, torch.Tensor):
            continue
        print(f" {key:20s}: {list(value.shape)}")

    # Training mode is switched on before the call, as a real training step would do.
    model.train()
    outputs = model(**batch)

    print("\nOutput shapes:")
    for key, value in outputs.items():
        if isinstance(value, torch.Tensor):
            lowest = value.min().item()
            highest = value.max().item()
            print(f" {key:20s}: {list(value.shape)} | range: [{lowest:.4f}, {highest:.4f}]")
        elif isinstance(value, dict):
            print(f" {key:20s}: dict with {len(value)} keys")

    loss_value = outputs['loss'].item()
    print(f"\n Loss: {loss_value:.4f}")
    print(" Loss components:")
    for term, amount in outputs['loss_dict'].items():
        print(f" {term:25s}: {amount:.4f}")
    return outputs
def test_backward_pass(model, outputs):
    """Back-propagate the forward-pass loss and report gradient flow per component.

    Args:
        model: Model exposing the sub-modules ``micro_encoder``,
            ``macro_encoder``, ``fmhca``, ``moe``, ``heads`` and ``loss_fn``,
            plus a ``zero_grad()`` method.
        outputs: Mapping from the forward pass; ``outputs["loss"]`` must be a
            tensor on which ``backward()`` can be called.

    Returns:
        bool: True when every listed component has at least one trainable
        parameter with a non-zero gradient, False otherwise.
    """
    print("\n" + "=" * 60)
    print("TEST 3: Backward Pass & Gradient Flow")
    print("=" * 60)

    outputs["loss"].backward()

    def has_nonzero_grad(param):
        # A parameter that never took part in the loss keeps grad == None.
        return param.grad is not None and bool(param.grad.abs().sum() > 0)

    def describe(param):
        # Report a missing gradient as "n/a" instead of crashing on None.item(),
        # so the diagnostic still prints when gradient flow is broken.
        grad_text = "n/a" if param.grad is None else f"{param.grad.item():.6f}"
        return f"{param.data.item():.4f} (grad: {grad_text})"

    # Check gradients exist for all components
    components = {
        "micro_encoder": model.micro_encoder,
        "macro_encoder": model.macro_encoder,
        "fmhca": model.fmhca,
        "moe": model.moe,
        "heads": model.heads,
        "loss_fn": model.loss_fn,
    }
    all_good = True
    for name, component in components.items():
        # Walk the parameters once and derive both counts from the same list.
        trainable = [p for p in component.parameters() if p.requires_grad]
        grad_params = sum(1 for p in trainable if has_nonzero_grad(p))
        has_grad = grad_params > 0
        if not has_grad:
            all_good = False
        status = "✓" if has_grad else "✗"
        print(f" {status} {name:20s}: {grad_params}/{len(trainable)} params have gradients")

    # Check specific learnable parameters
    print("\n Key learnable parameters:")
    print(f" FMHCA alpha (tanh gate): {describe(model.fmhca.micro_to_macro.alpha)}")
    print(f" Loss log_var_mse: {describe(model.loss_fn.log_var_mse)}")
    print(f" Loss log_var_direction: {describe(model.loss_fn.log_var_direction)}")
    print(f" Loss log_var_toxicity: {describe(model.loss_fn.log_var_toxicity)}")

    # Leave the model without stale gradients for whichever test runs next.
    model.zero_grad()
    return all_good
def test_mlofi_extractor():
    """Exercise ``MLOFIExtractor`` on random non-negative volumes and print a report.

    Prints the feature shape, the extractor's depth weights, and the shape,
    mean and standard deviation of the per-level OFI it computes.
    """
    rule = "=" * 60
    print("\n" + rule, "TEST 4: MLOFI Feature Extractor", rule, sep="\n")

    from fin_moe.feature_extractors import MLOFIExtractor

    extractor = MLOFIExtractor(n_levels=10, d_model=128)

    # Absolute values of Gaussian noise, scaled by 100, stand in for book volumes.
    book_shape = (2, 50, 10)
    bid_vol = torch.randn(book_shape).abs() * 100
    ask_vol = torch.randn(book_shape).abs() * 100

    features = extractor(bid_vol, ask_vol)
    bid_dims, ask_dims = list(bid_vol.shape), list(ask_vol.shape)
    print(f" Input: bid_vol {bid_dims}, ask_vol {ask_dims}")
    print(f" Output: {list(features.shape)}")

    depth_weights = extractor.compute_depth_weights().detach().numpy().round(3)
    print(f" Depth weights (learned): {depth_weights}")

    # Verify OFI computation
    ofi = extractor.compute_per_level_ofi(bid_vol, ask_vol)
    print(f" OFI shape: {list(ofi.shape)}, mean: {ofi.mean():.4f}, std: {ofi.std():.4f}")
def test_active_depth():
    """Exercise ``ActiveDepthExtractor`` on random non-negative quantities.

    Prints the feature shape, then the shape and mean of the kinetic-energy
    and market-temperature metrics computed on the summed bid + ask quantities.
    """
    rule = "=" * 60
    print("\n" + rule, "TEST 5: Active Depth Features", rule, sep="\n")

    from fin_moe.feature_extractors import ActiveDepthExtractor

    extractor = ActiveDepthExtractor(n_levels=10, d_model=128)

    book_shape = (2, 50, 10)
    bid_q = torch.randn(book_shape).abs() * 100
    ask_q = torch.randn(book_shape).abs() * 100

    features = extractor(bid_q, ask_q)
    print(f" Input: bid {list(bid_q.shape)}, ask {list(ask_q.shape)}")
    print(f" Output: {list(features.shape)}")

    # Test individual metrics; each one receives its own freshly summed tensor.
    kinetic = extractor.compute_kinetic_energy(bid_q + ask_q)
    temperature = extractor.compute_market_temperature(bid_q + ask_q)
    print(f" Kinetic Energy: {list(kinetic.shape)}, mean: {kinetic.mean():.2f}")
    print(f" Market Temperature: {list(temperature.shape)}, mean: {temperature.mean():.4f}")
def test_sentiment_tokenizer():
    """Exercise ``EconomicSentimentTokenizer`` on random text embeddings.

    Prints the input and token shapes and every named sentiment score the
    tokenizer returns.
    """
    rule = "=" * 60
    print("\n" + rule, "TEST 6: Economic Sentiment Tokenizer", rule, sep="\n")

    from fin_moe.feature_extractors import EconomicSentimentTokenizer

    tokenizer = EconomicSentimentTokenizer(text_dim=768, d_model=128, n_tokens=4)
    text_emb = torch.randn(2, 32, 768)

    # The tokenizer returns a pair: the tokens and a mapping of named scores.
    tokens, scores = tokenizer(text_emb)

    print(f" Input: {list(text_emb.shape)}")
    print(f" Output tokens: {list(tokens.shape)}")
    print(" Sentiment dimensions:")
    for label, value in scores.items():
        rounded = value.detach().numpy().round(3)
        print(f" {label:25s}: {rounded}")
def test_cross_attention():
    """Exercise the FMHCA fusion module on random micro and macro inputs.

    Builds a ``FinancialMultiHeadCrossAttention`` module, feeds it a
    ``(2, 50, 128)`` micro tensor and a ``(2, 4, 128)`` macro tensor, and
    prints the shapes of the three returned tensors together with the current
    values of the two gate parameters.
    """
    print("\n" + "=" * 60)
    print("TEST 7: Financial Multi-Head Cross-Attention")
    print("=" * 60)

    from fin_moe.cross_attention import FinancialMultiHeadCrossAttention

    fmhca = FinancialMultiHeadCrossAttention(
        d_model=128, n_heads=8, n_fusion_layers=2, n_bottleneck=4
    )
    micro = torch.randn(2, 50, 128)
    macro = torch.randn(2, 4, 128)
    latent, micro_e, macro_e = fmhca(micro, macro)

    print(f" Micro input: {list(micro.shape)}")
    print(f" Macro input: {list(macro.shape)}")
    print(f" Latent output: {list(latent.shape)}")
    print(f" Micro enriched: {list(micro_e.shape)}")
    print(f" Macro enriched: {list(macro_e.shape)}")

    # The gate labels are written with the real Greek letters; the previous
    # text was the mis-decoded UTF-8 byte sequence for the same characters.
    alpha_gate = fmhca.micro_to_macro.alpha.item()
    beta_gate = fmhca.macro_to_micro.alpha.item()
    print(f" Gate values: α={alpha_gate:.4f}, β={beta_gate:.4f}")
def test_moe_layer():
    """Exercise ``SparseMoELayer`` in training mode and print routing statistics.

    Prints input/output shapes and the auxiliary loss after one forward call,
    then runs ten more random batches before printing the utilization figures
    returned by ``get_expert_utilization()``.
    """
    rule = "=" * 60
    print("\n" + rule, "TEST 8: Sparse Mixture of Experts", rule, sep="\n")

    from fin_moe.moe_layer import SparseMoELayer

    moe = SparseMoELayer(d_model=128, d_ff=512, n_experts=8, top_k=2)
    moe.train()

    batch_shape = (4, 8, 128)
    x = torch.randn(batch_shape)
    out = moe(x)
    print(f" Input: {list(x.shape)}")
    print(f" Output: {list(out.shape)}")
    print(f" Aux loss: {moe.aux_loss.item():.4f}")

    # Run a few batches to accumulate utilization stats
    n_extra_batches = 10
    for _ in range(n_extra_batches):
        moe(torch.randn(batch_shape))

    util = moe.get_expert_utilization()
    per_expert = [f'{share:.3f}' for share in util['utilization']]
    print(f" Expert utilization: {per_expert}")
    print(f" Utilization entropy: {util['entropy']:.4f}")
def test_hybrid_loss():
    """Exercise ``SakumaDMLHybridLoss`` on random predictions and targets.

    Prints the total loss and each entry of the returned loss dictionary,
    then back-propagates and prints the gradients that reach the three
    ``log_var_*`` parameters of the loss module.
    """
    rule = "=" * 60
    print("\n" + rule, "TEST 9: Sakuma DML Hybrid Loss", rule, sep="\n")

    from fin_moe.hybrid_loss import SakumaDMLHybridLoss

    loss_fn = SakumaDMLHybridLoss()
    batch_size = 8

    # Random draws are made in a fixed order so a seeded run is reproducible.
    return_pred = torch.randn(batch_size)
    direction_logits = torch.randn(batch_size, 3)
    toxicity_first = torch.sigmoid(torch.randn(batch_size, 1))
    toxicity_second = torch.randn(batch_size, 1)
    toxicity_pred = torch.cat([toxicity_first, toxicity_second], dim=-1)
    true_returns = torch.randn(batch_size) * 0.01
    true_toxicity = torch.rand(batch_size)

    total_loss, loss_dict = loss_fn(
        return_pred,
        direction_logits,
        toxicity_pred,
        true_returns,
        true_toxicity,
        moe_aux_loss=torch.tensor(0.1),
    )
    print(f" Total loss: {total_loss.item():.4f}")
    for term, amount in loss_dict.items():
        print(f" {term:25s}: {amount:.4f}")

    # Check that loss weights are learnable
    total_loss.backward()
    print("\n Gradients on log-var parameters:")
    for attr in ("log_var_mse", "log_var_direction", "log_var_toxicity"):
        grad_value = getattr(loss_fn, attr).grad.item()
        print(f" {attr} grad: {grad_value:.6f}")
def test_inference_mode():
    """Call ``get_latent`` on a fresh model in eval mode with gradients disabled.

    Prints the latent's shape, plus the mean over the remaining dimensions of
    its last-dimension norm and last-dimension standard deviation.
    """
    rule = "=" * 60
    print("\n" + rule, "TEST 10: Inference Mode (get_latent)", rule, sep="\n")

    from fin_moe.model import FinMoELatentEncoder
    from fin_moe.data_pipeline import generate_synthetic_batch

    model = FinMoELatentEncoder()
    model.eval()
    batch = generate_synthetic_batch(batch_size=4, micro_seq_len=50, macro_seq_len=16)

    # get_latent is called with only these three inputs taken from the batch.
    latent_inputs = {
        key: batch[key]
        for key in ("bid_volumes", "ask_volumes", "text_embeddings")
    }
    with torch.no_grad():
        latent = model.get_latent(**latent_inputs)

    mean_norm = latent.norm(dim=-1).mean().item()
    mean_std = latent.std(dim=-1).mean().item()
    print(f" Latent shape: {list(latent.shape)}")
    print(f" Latent norm: {mean_norm:.4f}")
    print(f" Latent std: {mean_std:.4f}")
def test_training_step():
    """Run five optimizer steps on synthetic batches and print the loss trend.

    A model built with ``d_model=128`` and ``n_experts=4`` is trained with
    AdamW (lr 1e-4) and gradient-norm clipping at 1.0. The final line reports
    whether the last loss is below the first; that line is informational only,
    since the batches are random and a decrease is not guaranteed.
    """
    print("\n" + "=" * 60)
    print("TEST 11: Complete Training Step")
    print("=" * 60)

    from fin_moe.model import FinMoELatentEncoder
    from fin_moe.data_pipeline import generate_synthetic_batch

    model = FinMoELatentEncoder({"d_model": 128, "n_experts": 4})
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    model.train()

    losses = []
    for step in range(5):
        batch = generate_synthetic_batch(batch_size=4, micro_seq_len=50, macro_seq_len=16)
        outputs = model(**batch)
        loss = outputs["loss"]

        optimizer.zero_grad()
        loss.backward()
        # Gradient clipping; the call returns the norm measured before clipping.
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        loss_value = loss.item()
        losses.append(loss_value)
        print(f" Step {step+1}: loss={loss_value:.4f}, grad_norm={grad_norm:.4f}")

    # The arrow and check mark are the real characters; the previous text was
    # the mis-decoded UTF-8 byte sequence for them.
    print(f"\n Loss trend: {losses[0]:.4f} → {losses[-1]:.4f}")
    decreasing = losses[-1] < losses[0]
    print(f" Loss decreasing: {'✓' if decreasing else '~ (expected with random data)'}")
def main():
    """Run every test, print a pass/fail summary, and exit with a status code.

    All tests run inside the same try/except harness, including the three
    that hand objects to each other (instantiation, forward pass, backward
    pass). A failure in any of them is reported and counted instead of
    aborting the run before the summary is printed. A backward pass that
    returns False (some component received no gradient) counts as a failure.

    Exits the process with status 0 when every test passed and 1 otherwise.
    """
    print("╔" + "═" * 58 + "╗")
    print("║ Fin-MoE Latent Encoder — End-to-End Test Suite ║")
    print("╚" + "═" * 58 + "╝")

    # Objects handed from one dependent test to the next.
    shared = {}

    def require(key, producer):
        # Fail with a readable message when an earlier test did not finish.
        if key not in shared:
            raise RuntimeError(f"prerequisite test '{producer}' did not complete")
        return shared[key]

    def run_instantiation():
        shared["model"] = test_model_instantiation()

    def run_forward_pass():
        model = require("model", "Model Instantiation")
        shared["outputs"] = test_forward_pass(model)

    def run_backward_pass():
        model = require("model", "Model Instantiation")
        outputs = require("outputs", "Forward Pass")
        # test_backward_pass signals missing gradients through its return
        # value instead of raising, so turn False into a counted failure.
        if not test_backward_pass(model, outputs):
            raise AssertionError("at least one component received no gradient")

    tests = [
        ("Model Instantiation", run_instantiation),
        ("Forward Pass", run_forward_pass),
        ("Backward Pass", run_backward_pass),
        ("MLOFI Extractor", test_mlofi_extractor),
        ("Active Depth", test_active_depth),
        ("Sentiment Tokenizer", test_sentiment_tokenizer),
        ("Cross-Attention", test_cross_attention),
        ("MoE Layer", test_moe_layer),
        ("Hybrid Loss", test_hybrid_loss),
        ("Inference Mode", test_inference_mode),
        ("Training Step", test_training_step),
    ]

    passed = 0
    failed = 0
    for name, test_fn in tests:
        # Top-level boundary of the runner: catch broadly, report, keep going.
        try:
            test_fn()
        except Exception as e:
            failed += 1
            print(f"\n ✗ FAILED: {name}")
            print(f" Error: {e}")
            traceback.print_exc()
        else:
            passed += 1

    print("\n" + "=" * 60)
    print(f"RESULTS: {passed} passed, {failed} failed out of {len(tests)}")
    print("=" * 60)

    if failed > 0:
        sys.exit(1)
    print("\n✓ All tests passed! Fin-MoE Latent Encoder is working correctly.")
    sys.exit(0)
# Run the suite only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()