Spaces:
Running
Running
File size: 3,816 Bytes
"""Unit-style tests for multimodal DDI components.
Run:
python src/validation/test_multimodal_components.py
"""
from __future__ import annotations
import tempfile
import unittest
from pathlib import Path
import numpy as np
import pandas as pd
from training.embeddings import _hash_inputs
from training.calibration import SafetyCalibrationConfig, expected_calibration_error, risk_adjusted_prediction
from training.ensemble import EnsemblePredictor, train_base_models
from training.molecular_features import MolecularFeatureExtractor, smiles_to_features
from preprocessing.unified_dataset_builder import dedupe_and_resolve
class TestMolecularFeatures(unittest.TestCase):
    """Shape checks for the SMILES featurisation helpers."""

    def test_smiles_to_features_shape(self):
        """One feature row and one metadata entry come back per input string.

        The last input ('INVALID') does not look like a parseable SMILES
        string; the test expects it to be kept as a row rather than dropped.
        """
        smiles = ['CCO', 'C1=CC=CC=C1', 'INVALID']
        features, meta = smiles_to_features(smiles)

        self.assertEqual(features.shape[0], len(smiles))
        # Only a lower bound on feature width is checked, not an exact size.
        self.assertGreater(features.shape[1], 100)
        self.assertEqual(len(meta), len(smiles))

    def test_pair_extractor_shape(self):
        """The pair extractor yields one row per (drug, drug) SMILES pair."""
        pairs = [('CCO', 'CCN'), ('C1=CC=CC=C1', 'INVALID')]
        extractor = MolecularFeatureExtractor()
        pair_features = extractor.fit_transform(pairs)

        self.assertEqual(pair_features.shape[0], len(pairs))
        self.assertGreater(pair_features.shape[1], 100)
class TestEmbeddingHelpers(unittest.TestCase):
    """Checks on the input-hashing helper used by the embedding code."""

    def test_hash_stability(self):
        """Equal inputs hash equally; reordering the inputs changes the hash."""
        names = ['aspirin', 'warfarin']

        # Each call receives its own list object, mirroring independent callers.
        baseline = _hash_inputs(list(names))
        repeated = _hash_inputs(list(names))
        reordered = _hash_inputs(names[::-1])

        self.assertEqual(baseline, repeated)
        self.assertNotEqual(baseline, reordered)
class TestEnsemblePipeline(unittest.TestCase):
    """End-to-end smoke test: train base models, reload them, predict."""

    def test_train_and_predict(self):
        """Train on seeded random data and check the prediction payload.

        Only the presence of the 'preds' and 'probs' entries and their
        leading dimension are verified; prediction quality is not.
        """
        n_samples, n_features, n_classes = 120, 32, 4
        n_query = 10

        # Fixed seed so the synthetic data is identical on every run.
        generator = np.random.default_rng(42)
        features = generator.normal(size=(n_samples, n_features)).astype(np.float32)
        labels = generator.integers(0, n_classes, size=(n_samples,), endpoint=False)

        with tempfile.TemporaryDirectory() as tmp_dir:
            model_dir = Path(tmp_dir)
            train_base_models(features, labels, model_dir)

            # The predictor is built from the same directory that
            # training was just given.
            ensemble = EnsemblePredictor(model_dir)
            result = ensemble.predict(features[:n_query])

            for key in ('preds', 'probs'):
                self.assertIn(key, result)
            self.assertEqual(len(result['preds']), n_query)
            self.assertEqual(result['probs'].shape[0], n_query)
class TestCalibrationSafety(unittest.TestCase):
    """Tests for the calibration metric and the safety-escalation rule."""

    def test_ece_range(self):
        """Expected calibration error for a small 4-class example is in [0, 1]."""
        # One sample per class; each row puts most of its mass on the true class.
        rows = [
            [0.8, 0.1, 0.05, 0.05],
            [0.2, 0.7, 0.05, 0.05],
            [0.1, 0.2, 0.6, 0.1],
            [0.1, 0.1, 0.15, 0.65],
        ]
        labels = np.array([0, 1, 2, 3], dtype=np.int64)
        probabilities = np.array(rows, dtype=np.float32)

        score = expected_calibration_error(labels, probabilities, n_bins=5)

        self.assertGreaterEqual(score, 0.0)
        self.assertLessEqual(score, 1.0)

    def test_safety_escalation(self):
        """A near-uniform distribution topped by the severe class is escalated.

        The severe-class probability (0.43) is below the configured alert
        threshold (0.45), yet the test still expects the escalation flag.
        NOTE(review): presumably the entropy or top-2 margin rule triggers it;
        confirm against risk_adjusted_prediction.
        """
        settings = {
            'severe_class_index': 3,
            'severe_alert_threshold': 0.45,
            'entropy_high_threshold': 0.7,
            'top2_margin_low_threshold': 0.2,
        }
        config = SafetyCalibrationConfig(**settings)
        distribution = np.array([0.18, 0.19, 0.20, 0.43], dtype=np.float32)

        decision = risk_adjusted_prediction(distribution, cfg=config)

        self.assertEqual(decision['pred_index'], 3)
        self.assertTrue(decision['escalated_for_safety'])
class TestUnifiedConflictResolution(unittest.TestCase):
    """Tests for de-duplication and conflict handling in the unified dataset."""

    def test_conflict_field_present(self):
        """Two sources disagreeing on one drug pair produce a flagged conflict.

        The two records name the same pair in opposite order (A/B and B/A)
        with different severities.
        """
        faers_record = {
            'drug_a': 'A', 'drug_b': 'B', 'severity': 'minor',
            'source': 'faers', 'support': 1, 'evidence': '',
        }
        drugbank_record = {
            'drug_a': 'B', 'drug_b': 'A', 'severity': 'major',
            'source': 'drugbank', 'support': 1, 'evidence': '',
        }
        frame = pd.DataFrame([faers_record, drugbank_record])

        resolved = dedupe_and_resolve(frame)

        self.assertIn('conflict', resolved.columns)
        first_row = resolved.iloc[0]
        self.assertEqual(int(first_row['conflict']), 1)
        # The resolved label must be one of the known severity levels.
        allowed_severities = {'major', 'moderate', 'minor', 'unknown'}
        self.assertIn(str(first_row['severity']), allowed_severities)
# Allow running this file directly as a script; verbosity=2 lists each test by name.
if __name__ == "__main__":
    unittest.main(verbosity=2)
|