File size: 3,816 Bytes
d29b763
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""Unit-style tests for multimodal DDI components.

Run:
  python src/validation/test_multimodal_components.py
"""
from __future__ import annotations

import tempfile
import unittest
from pathlib import Path

import numpy as np
import pandas as pd

from training.embeddings import _hash_inputs
from training.calibration import SafetyCalibrationConfig, expected_calibration_error, risk_adjusted_prediction
from training.ensemble import EnsemblePredictor, train_base_models
from training.molecular_features import MolecularFeatureExtractor, smiles_to_features
from preprocessing.unified_dataset_builder import dedupe_and_resolve


class TestMolecularFeatures(unittest.TestCase):
    """Shape checks for the SMILES-based molecular featurizers."""

    def test_smiles_to_features_shape(self):
        """One feature row and one metadata entry per SMILES, including invalid ones."""
        smiles = ['CCO', 'C1=CC=CC=C1', 'INVALID']
        features, metadata = smiles_to_features(smiles)
        n_rows = features.shape[0]
        n_cols = features.shape[1]
        self.assertEqual(n_rows, len(smiles))
        self.assertGreater(n_cols, 100)
        self.assertEqual(len(metadata), len(smiles))

    def test_pair_extractor_shape(self):
        """The pair extractor emits one feature row per SMILES pair."""
        extractor = MolecularFeatureExtractor()
        pairs = [('CCO', 'CCN'), ('C1=CC=CC=C1', 'INVALID')]
        matrix = extractor.fit_transform(pairs)
        self.assertEqual(matrix.shape[0], len(pairs))
        self.assertGreater(matrix.shape[1], 100)


class TestEmbeddingHelpers(unittest.TestCase):
    """Checks for the private ``_hash_inputs`` helper from ``training.embeddings``."""

    def test_hash_stability(self):
        """Equal input lists hash equally; reordering the list changes the hash."""
        names = ['aspirin', 'warfarin']
        first = _hash_inputs(list(names))
        repeat = _hash_inputs(list(names))
        swapped = _hash_inputs(list(reversed(names)))
        self.assertEqual(first, repeat)
        self.assertNotEqual(first, swapped)


class TestEnsemblePipeline(unittest.TestCase):
    """Smoke test: train base models into a temp dir, then predict from that dir."""

    def test_train_and_predict(self):
        """Prediction output exposes 'preds' and 'probs' with one entry per input row."""
        n_samples, n_features, n_classes = 120, 32, 4
        rng = np.random.default_rng(42)
        features = rng.normal(size=(n_samples, n_features)).astype(np.float32)
        labels = rng.integers(0, n_classes, size=(n_samples,), endpoint=False)

        with tempfile.TemporaryDirectory() as tmp:
            model_dir = Path(tmp)
            train_base_models(features, labels, model_dir)
            batch = features[:10]
            result = EnsemblePredictor(model_dir).predict(batch)
            for key in ('preds', 'probs'):
                self.assertIn(key, result)
            self.assertEqual(len(result['preds']), len(batch))
            self.assertEqual(result['probs'].shape[0], len(batch))


class TestCalibrationSafety(unittest.TestCase):
    """Checks for calibration metrics and the safety-escalation policy."""

    def test_ece_range(self):
        """ECE on a small batch (every argmax matches its label) lies in [0, 1]."""
        y_true = np.array([0, 1, 2, 3], dtype=np.int64)
        probs = np.array(
            [
                [0.8, 0.1, 0.05, 0.05],
                [0.2, 0.7, 0.05, 0.05],
                [0.1, 0.2, 0.6, 0.1],
                [0.1, 0.1, 0.15, 0.65],
            ],
            dtype=np.float32,
        )
        ece = expected_calibration_error(y_true, probs, n_bins=5)
        self.assertGreaterEqual(ece, 0.0)
        self.assertLessEqual(ece, 1.0)

    def test_ece_zero_for_confident_correct_predictions(self):
        """One-hot probabilities on the true class give zero calibration error.

        When every prediction is correct at confidence 1.0, accuracy equals
        confidence in every populated bin. ECE must therefore be 0 whatever
        the bin layout. The range check above would accept almost any
        implementation; this boundary case pins actual behaviour.
        """
        y_true = np.array([0, 1, 2, 3], dtype=np.int64)
        probs = np.eye(4, dtype=np.float32)
        ece = expected_calibration_error(y_true, probs, n_bins=5)
        self.assertAlmostEqual(float(ece), 0.0, places=6)

    def test_safety_escalation(self):
        """A near-threshold severe probability is flagged as escalated for safety."""
        cfg = SafetyCalibrationConfig(
            severe_class_index=3,
            severe_alert_threshold=0.45,
            entropy_high_threshold=0.7,
            top2_margin_low_threshold=0.2,
        )
        # Severe-class probability 0.43 sits just under severe_alert_threshold.
        probs = np.array([0.18, 0.19, 0.20, 0.43], dtype=np.float32)
        out = risk_adjusted_prediction(probs, cfg=cfg)
        # NOTE(review): index 3 is already the argmax of `probs`, so
        # pred_index == 3 alone does not show that escalation changed the
        # decision. The escalated_for_safety flag carries the real signal here.
        self.assertEqual(out['pred_index'], 3)
        self.assertTrue(out['escalated_for_safety'])


class TestUnifiedConflictResolution(unittest.TestCase):
    """Checks for severity-conflict flagging in ``dedupe_and_resolve``."""

    def test_conflict_field_present(self):
        """A reversed drug pair with disagreeing severities is flagged as a conflict."""
        columns = ['drug_a', 'drug_b', 'severity', 'source', 'support', 'evidence']
        rows = [
            ('A', 'B', 'minor', 'faers', 1, ''),
            ('B', 'A', 'major', 'drugbank', 1, ''),
        ]
        frame = pd.DataFrame(rows, columns=columns)
        resolved = dedupe_and_resolve(frame)
        self.assertIn('conflict', resolved.columns)
        first = resolved.iloc[0]
        self.assertEqual(int(first['conflict']), 1)
        allowed = {'major', 'moderate', 'minor', 'unknown'}
        self.assertIn(str(first['severity']), allowed)


if __name__ == "__main__":
    # Run directly as a script with a verbose runner so each test name is printed.
    unittest.main(verbosity=2)