File size: 11,165 Bytes
d29b763
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
#!/usr/bin/env python3
"""Comprehensive production integration test and validation suite.

Tests:
1. Backend API startup and health checks
2. Model inference on known and unknown pairs
3. Frontend API contract compliance
4. Healthcare safety features
5. Confidence calibration
6. Error handling
"""
from __future__ import annotations

import asyncio
import json
import logging
import subprocess
import sys
import time
from pathlib import Path
from typing import Any, Dict

# Root logging setup: INFO and above, each record rendered as
# "<timestamp> [<LEVEL>] <logger name>: <message>".
logging.basicConfig(
    format='%(asctime)s [%(levelname)s] %(name)s: %(message)s',
    level=logging.INFO,
)

# Module-wide logger used by every check in this script.
logger = logging.getLogger('medcare_ddi.integration_test')


class IntegrationTest:
    """Integration checks run over HTTP against the service at ``base_url``.

    Every check records its outcome in ``self.results`` (a mapping of
    check name -> bool) through :meth:`log_test`. :meth:`run_all` executes
    the whole suite and summarises those outcomes.

    The ``requests`` package is imported lazily inside the private HTTP
    helpers, so importing this module does not require it.
    """

    # Keys that every successful /predict response must contain.
    REQUIRED_PREDICT_FIELDS = (
        'drug_a',
        'drug_b',
        'severity',
        'confidence',
        'confidence_band',
        'source',
        'explanation',
        'clinical_advice',
        'latency_ms',
    )

    # Values accepted for the ``confidence_band`` response field.
    VALID_CONFIDENCE_BANDS = ('high', 'medium', 'low')

    def __init__(self):
        # Outcome of every logged check, keyed by check name.
        self.results = {}
        self.base_url = 'http://localhost:8000'

    def log_test(self, name: str, passed: bool, details: str = '') -> None:
        """Log one check outcome and store it in ``self.results``.

        Args:
            name: Key under which the outcome is stored (a later call with
                the same name overwrites the earlier outcome).
            passed: Whether the check succeeded.
            details: Optional extra text logged on a second, indented line.
        """
        status = '✓ PASS' if passed else '✗ FAIL'
        logger.info(f'{status} - {name}')
        if details:
            logger.info(f'       {details}')
        self.results[name] = passed

    def _get_health(self, timeout: float = 5):
        """GET ``{base_url}/health`` and return the ``requests`` response.

        Raises:
            ImportError: If ``requests`` is not installed.
            requests.RequestException: On connection or timeout errors.
        """
        import requests

        return requests.get(f'{self.base_url}/health', timeout=timeout)

    def _post_predict(self, payload: dict, timeout: float = 10):
        """POST ``payload`` as JSON to ``{base_url}/predict``.

        Raises:
            ImportError: If ``requests`` is not installed.
            requests.RequestException: On connection or timeout errors.
        """
        import requests

        return requests.post(
            f'{self.base_url}/predict',
            json=payload,
            timeout=timeout,
        )

    def test_health_endpoint(self) -> bool:
        """Check that /health answers 200 and reports a loaded model.

        Returns:
            True when every field check passes, False otherwise (including
            when the request itself fails).
        """
        try:
            response = self._get_health()

            if response.status_code != 200:
                self.log_test('Health Endpoint', False, f'Status {response.status_code}')
                return False

            data = response.json()
            pairs_loaded = data.get('pairs_loaded')
            checks = [
                ('status' in data, 'status field'),
                ('model_loaded' in data, 'model_loaded field'),
                ('pairs_loaded' in data, 'pairs_loaded field'),
                (data.get('model_loaded') is True, 'model_loaded is True'),
                (
                    # Guard the comparison so a null/string value is a
                    # failed check rather than a TypeError.
                    isinstance(pairs_loaded, (int, float)) and pairs_loaded > 0,
                    f'pairs_loaded > 0 (got {pairs_loaded})',
                ),
            ]

            failed = [label for ok, label in checks if not ok]
            if failed:
                # Name the checks that failed; listing only the ones that
                # passed hides the reason for the failure.
                details = 'Failed checks: ' + ', '.join(failed)
            else:
                details = ', '.join(label for _, label in checks)

            all_passed = not failed
            self.log_test('Health Endpoint', all_passed, details)
            return all_passed

        except Exception as e:
            self.log_test('Health Endpoint', False, str(e))
            return False

    def _check_known_pair(self, drug_a: str, drug_b: str, expected_severity: str) -> tuple:
        """Query /predict for one known pair and validate the response.

        Returns:
            A ``(passed, details)`` tuple suitable for :meth:`log_test`.
        """
        response = self._post_predict({'drug_a': drug_a, 'drug_b': drug_b})
        if response.status_code != 200:
            return False, f'Status {response.status_code}'

        data = response.json()

        # Check response schema
        missing_fields = [f for f in self.REQUIRED_PREDICT_FIELDS if f not in data]
        if missing_fields:
            return False, f'Missing fields: {missing_fields}'

        # Check values
        severity = data.get('severity')
        confidence = data.get('confidence')
        source = data.get('source')

        # bool is an int subclass; reject it along with None/str so the
        # ``:.2f`` format below cannot raise an opaque error.
        if isinstance(confidence, bool) or not isinstance(confidence, (int, float)):
            return False, f'Non-numeric confidence: {confidence!r}'

        summary = f'{severity} (conf={confidence:.2f}, src={source})'

        # The expected severity is the point of a "known pair" check;
        # compare case-insensitively so 'Major' and 'major' both match.
        if str(severity).strip().lower() != expected_severity.lower():
            return False, f'Expected severity {expected_severity!r}, got {summary}'

        return True, summary

    def test_known_interactions(self) -> bool:
        """Test predictions on known DDI pairs.

        Each pair must return HTTP 200, contain every key in
        ``REQUIRED_PREDICT_FIELDS``, carry a numeric confidence and report
        the expected severity.

        Returns:
            True only if every pair passes.
        """
        test_pairs = [
            ('Aspirin', 'Warfarin', 'major'),
            ('Metformin', 'Insulin', 'moderate'),
        ]

        all_passed = True
        for drug_a, drug_b, expected_severity in test_pairs:
            try:
                passed, details = self._check_known_pair(drug_a, drug_b, expected_severity)
            except Exception as e:
                passed, details = False, str(e)

            self.log_test(f'Known DDI: {drug_a} + {drug_b}', passed, details)
            all_passed = all_passed and passed

        return all_passed

    def test_unseen_pairs(self) -> bool:
        """Test /predict on drug names that are not real drugs.

        Only the HTTP status is checked; the severity and source returned
        are logged for information.

        Returns:
            True only if every pair returns HTTP 200.
        """
        test_pairs = [
            ('UnknownDrugA', 'UnknownDrugB'),
            ('TestDrug1', 'TestDrug2'),
        ]

        all_passed = True
        for drug_a, drug_b in test_pairs:
            try:
                response = self._post_predict({'drug_a': drug_a, 'drug_b': drug_b})
                if response.status_code != 200:
                    passed, details = False, f'Status {response.status_code}'
                else:
                    data = response.json()
                    severity = data.get('severity')
                    source = data.get('source')
                    passed, details = True, f'{severity} (source={source})'
            except Exception as e:
                passed, details = False, str(e)

            self.log_test(f'Unseen pair: {drug_a} + {drug_b}', passed, details)
            all_passed = all_passed and passed

        return all_passed

    def test_error_handling(self) -> bool:
        """Test that invalid /predict payloads are rejected.

        A payload counts as rejected when the response status is >= 400.

        Returns:
            True only if every invalid payload is rejected.
        """
        test_cases = [
            ({}, 'Missing both drugs'),
            ({'drug_a': ''}, 'Empty drug names'),
            ({'drug_a': None, 'drug_b': None}, 'None drugs'),
        ]

        all_passed = True
        for payload, desc in test_cases:
            try:
                response = self._post_predict(payload, timeout=5)
                if response.status_code >= 400:
                    passed, details = True, f'Status {response.status_code}'
                else:
                    passed, details = False, 'Should have failed'
            except Exception as e:
                passed, details = False, str(e)

            self.log_test(f'Error Handling: {desc}', passed, details)
            all_passed = all_passed and passed

        return all_passed

    def test_confidence_bands(self) -> bool:
        """Test that /predict returns a recognised ``confidence_band``.

        Several pairs are queried on a best-effort basis; the check passes
        when at least one response carries a band listed in
        ``VALID_CONFIDENCE_BANDS``.

        Returns:
            True if at least one valid band was observed.
        """
        try:
            response = self._get_health()
        except Exception as e:
            self.log_test('Confidence Bands', False, str(e))
            return False

        if response.status_code != 200:
            self.log_test('Confidence Bands', False, 'Could not get health info')
            return False

        # Make several predictions and check confidence_band values
        test_pairs = [('Aspirin', 'Warfarin'), ('Drug1', 'Drug2'), ('Drug3', 'Drug4')]

        bands_found = set()
        for drug_a, drug_b in test_pairs:
            try:
                response = self._post_predict({'drug_a': drug_a, 'drug_b': drug_b})
                if response.status_code != 200:
                    continue
                band = response.json().get('confidence_band')
            except Exception as e:
                # Still best-effort (one bad pair must not abort the check),
                # but report the problem instead of silently dropping it and
                # let KeyboardInterrupt/SystemExit propagate.
                logger.warning(f'Confidence band probe failed for {drug_a} + {drug_b}: {e}')
                continue

            if band in self.VALID_CONFIDENCE_BANDS:
                bands_found.add(band)
            else:
                logger.warning(f'Unexpected confidence_band {band!r} for {drug_a} + {drug_b}')

        valid_bands = bool(bands_found)
        # Sorted so the log line is stable from run to run.
        details = f'Bands found: {sorted(bands_found)}' if bands_found else 'No valid bands found'
        self.log_test('Confidence Bands', valid_bands, details)
        return valid_bands

    def run_all(self) -> bool:
        """Run every check, log a summary and return the overall verdict.

        A check that raises is logged with its traceback and recorded as a
        failure, so a crash can never yield a passing verdict.

        Returns:
            True only if at least one outcome was recorded and all recorded
            outcomes passed.
        """
        logger.info('')
        logger.info('╔' + '═'*68 + '╗')
        logger.info('║ MEDCARE-DDI INTEGRATION TEST SUITE' + ' '*33 + '║')
        logger.info('╚' + '═'*68 + '╝')
        logger.info('')

        tests = [
            self.test_health_endpoint,
            self.test_known_interactions,
            self.test_unseen_pairs,
            self.test_error_handling,
            self.test_confidence_bands,
        ]

        for test in tests:
            try:
                test()
            except Exception as e:
                test_label = getattr(test, '__name__', repr(test))
                logger.error(f'Test {test_label} crashed: {e}', exc_info=True)
                # Record the crash; otherwise it would be invisible to the
                # summary and the final verdict.
                self.log_test(f'{test_label} (crashed)', False, str(e))

        # Summary
        logger.info('')
        logger.info('='*70)
        logger.info('TEST SUMMARY')
        logger.info('='*70)

        passed_count = sum(1 for ok in self.results.values() if ok)
        total = len(self.results)
        pass_rate = (passed_count / total * 100) if total > 0 else 0

        logger.info(f'Passed: {passed_count}/{total} ({pass_rate:.0f}%)')

        for test_name, ok in self.results.items():
            status = '✓' if ok else '✗'
            logger.info(f'{status} {test_name}')

        logger.info('')
        # all() of an empty collection is True; zero recorded outcomes must
        # not be reported as deployable.
        all_passed = total > 0 and all(self.results.values())
        status = 'READY FOR DEPLOYMENT' if all_passed else 'NEEDS_ATTENTION'
        logger.info(f'Overall: {status}')
        logger.info('')

        return all_passed


def main():
    """Run the integration suite and report whether it passed.

    Returns:
        False if the ``requests`` package cannot be imported (an install
        hint is logged); otherwise the result of ``IntegrationTest.run_all``.
    """
    # The checks talk to the service over HTTP, so bail out early when the
    # HTTP client library is unavailable.
    try:
        import requests
    except ImportError:
        logger.error('requests module not found. Install with: pip install requests')
        return False
    else:
        return IntegrationTest().run_all()


if __name__ == '__main__':
    # Exit status 0 when the suite passed, 1 otherwise.
    sys.exit(0 if main() else 1)