Spaces:
Running
Running
"""Unit-style tests for multimodal DDI components.

Run:
    python src/validation/test_multimodal_components.py
"""
| from __future__ import annotations | |
| import tempfile | |
| import unittest | |
| from pathlib import Path | |
| import numpy as np | |
| import pandas as pd | |
| from training.embeddings import _hash_inputs | |
| from training.calibration import SafetyCalibrationConfig, expected_calibration_error, risk_adjusted_prediction | |
| from training.ensemble import EnsemblePredictor, train_base_models | |
| from training.molecular_features import MolecularFeatureExtractor, smiles_to_features | |
| from preprocessing.unified_dataset_builder import dedupe_and_resolve | |
class TestMolecularFeatures(unittest.TestCase):
    """Shape checks for the SMILES featurisation helpers."""

    def test_smiles_to_features_shape(self):
        # Three inputs, one of them the placeholder string 'INVALID'; every
        # input is expected to produce a feature row and a metadata entry.
        smiles = ['CCO', 'C1=CC=CC=C1', 'INVALID']
        features, metadata = smiles_to_features(smiles)

        self.assertEqual(features.shape[0], len(smiles))
        self.assertGreater(features.shape[1], 100)
        self.assertEqual(len(metadata), len(smiles))

    def test_pair_extractor_shape(self):
        # One output row per drug pair, even when one member of a pair is
        # the 'INVALID' placeholder.
        pairs = [('CCO', 'CCN'), ('C1=CC=CC=C1', 'INVALID')]
        pair_features = MolecularFeatureExtractor().fit_transform(pairs)

        self.assertEqual(pair_features.shape[0], len(pairs))
        self.assertGreater(pair_features.shape[1], 100)
class TestEmbeddingHelpers(unittest.TestCase):
    """Checks on the input-hashing helper used by the embedding code."""

    def test_hash_stability(self):
        forward = ['aspirin', 'warfarin']

        # Each call receives its own fresh list, so the helper never sees
        # the same list object twice.
        digest_first = _hash_inputs(list(forward))
        digest_again = _hash_inputs(list(forward))
        digest_reversed = _hash_inputs(list(reversed(forward)))

        # Identical input must hash identically; the same names in the
        # opposite order must hash differently.
        self.assertEqual(digest_first, digest_again)
        self.assertNotEqual(digest_first, digest_reversed)
class TestEnsemblePipeline(unittest.TestCase):
    """Smoke test: train the base models, reload them, and predict."""

    def test_train_and_predict(self):
        n_samples, n_features, n_classes, n_query = 120, 32, 4, 10

        # Fixed seed keeps the synthetic data reproducible. Features are
        # drawn before labels, so the generator is consumed in that order.
        generator = np.random.default_rng(42)
        features = generator.normal(size=(n_samples, n_features)).astype(np.float32)
        labels = generator.integers(0, n_classes, size=(n_samples,), endpoint=False)

        with tempfile.TemporaryDirectory() as scratch:
            model_dir = Path(scratch)
            train_base_models(features, labels, model_dir)
            result = EnsemblePredictor(model_dir).predict(features[:n_query])

            # The prediction payload must expose both keys ...
            for key in ('preds', 'probs'):
                self.assertIn(key, result)
            # ... with one entry / one row per query sample.
            self.assertEqual(len(result['preds']), n_query)
            self.assertEqual(result['probs'].shape[0], n_query)
class TestCalibrationSafety(unittest.TestCase):
    """Tests for the calibration metric and the safety-aware decision rule."""

    def test_ece_range(self):
        # Four samples, one per class; each row's largest probability sits
        # on the row's true label.
        labels = np.arange(4, dtype=np.int64)
        rows = [
            [0.8, 0.1, 0.05, 0.05],
            [0.2, 0.7, 0.05, 0.05],
            [0.1, 0.2, 0.6, 0.1],
            [0.1, 0.1, 0.15, 0.65],
        ]
        probabilities = np.array(rows, dtype=np.float32)

        ece = expected_calibration_error(labels, probabilities, n_bins=5)

        # The metric must land in the closed interval [0, 1].
        self.assertGreaterEqual(ece, 0.0)
        self.assertLessEqual(ece, 1.0)

    def test_safety_escalation(self):
        settings = dict(
            severe_class_index=3,
            severe_alert_threshold=0.45,
            entropy_high_threshold=0.7,
            top2_margin_low_threshold=0.2,
        )
        config = SafetyCalibrationConfig(**settings)

        # Class 3 is the arg-max at 0.43, just under the 0.45 alert
        # threshold; the escalation flag is still expected to be set.
        distribution = np.array([0.18, 0.19, 0.20, 0.43], dtype=np.float32)
        decision = risk_adjusted_prediction(distribution, cfg=config)

        self.assertEqual(decision['pred_index'], 3)
        self.assertTrue(decision['escalated_for_safety'])
class TestUnifiedConflictResolution(unittest.TestCase):
    """Checks conflict flagging when sources disagree about a drug pair."""

    def test_conflict_field_present(self):
        # The same two drugs are reported by two sources, in opposite order
        # and with different severities.
        shared = {'support': 1, 'evidence': ''}
        records = [
            {'drug_a': 'A', 'drug_b': 'B', 'severity': 'minor', 'source': 'faers', **shared},
            {'drug_a': 'B', 'drug_b': 'A', 'severity': 'major', 'source': 'drugbank', **shared},
        ]
        frame = pd.DataFrame(records)

        resolved = dedupe_and_resolve(frame)

        self.assertIn('conflict', resolved.columns)
        first_row = resolved.iloc[0]
        self.assertEqual(int(first_row['conflict']), 1)
        allowed_severities = {'major', 'moderate', 'minor', 'unknown'}
        self.assertIn(str(first_row['severity']), allowed_severities)
if __name__ == '__main__':
    # Script entry point: run every TestCase defined in this module with
    # per-test (verbose) reporting.
    unittest.main(verbosity=2)