File size: 6,657 Bytes
f61ae9f 624473d f61ae9f 624473d f61ae9f 624473d f61ae9f 624473d f61ae9f 624473d f61ae9f 624473d f61ae9f 624473d f61ae9f 624473d f61ae9f 624473d f61ae9f 624473d f61ae9f 624473d f61ae9f 624473d f61ae9f 624473d f61ae9f 624473d f61ae9f 624473d f61ae9f 624473d f61ae9f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 |
"""
Inference script for Simplified ASD Detector (8 features)
Example usage for making predictions with the trained model.
"""
import sys
import json
import torch
import joblib
import pandas as pd
from pathlib import Path
from model import SimplifiedASDDetector, load_model, FEATURES, SimplePreprocessor
# Unpickling fix: the preprocessor was saved from a different module
# (presumably a training script run as ``__main__``). Pickle therefore
# resolves the class by that module name. Registering the class on whatever
# module is currently ``__main__`` lets ``joblib.load`` find it.
setattr(sys.modules['__main__'], 'SimplePreprocessor', SimplePreprocessor)
# Original column names (as in the training data). These strings, typos and
# all (e.g. "Dysmorphysm", "agressivity", "???"), are the keys the converted
# DataFrame is built with. Presumably the fitted preprocessor expects exactly
# these names, so do not "correct" them.
ORIGINAL_COLUMN_NAMES = [
    'Developmental milestones- global delay (G), motor delay (M), cognitive delay (C)',
    'IQ/DQ',
    'ICD',
    'Language disorder Y= present, N=absent',
    'Language development: delay, normal=N, absent=A',
    'Dysmorphysm y=present, no=absent',
    'Behaviour disorder- agressivity, agitation, irascibility',
    'Neurological Examination; N=normal, text = abnormal; free cell = examination not performed ???',
]

# Simplified names for user-friendly input. Position i is the alias of
# ORIGINAL_COLUMN_NAMES[i]; the two lists are paired positionally.
SIMPLE_NAMES = [
    'developmental_milestones',
    'iq_dq',
    'intellectual_disability',
    'language_disorder',
    'language_development',
    'dysmorphism',
    'behaviour_disorder',
    'neurological_exam',
]

# The lists are consumed in lockstep with zip(), which silently drops the
# trailing entries of the longer list. Fail loudly at import time instead of
# quietly losing a feature if one list is edited without the other.
if len(SIMPLE_NAMES) != len(ORIGINAL_COLUMN_NAMES):
    raise RuntimeError(
        f"SIMPLE_NAMES ({len(SIMPLE_NAMES)}) and ORIGINAL_COLUMN_NAMES "
        f"({len(ORIGINAL_COLUMN_NAMES)}) must have the same length"
    )
class ASDPredictor:
    """
    Easy-to-use predictor for ASD detection.

    On construction, loads from ``model_dir``:
      - ``autism_detector.pth``: model weights, via ``load_model``
      - ``preprocessor.joblib``: fitted feature preprocessor
      - ``config.json``: kept as ``self.config`` (not used by this class)

    Example:
        >>> predictor = ASDPredictor('.')
        >>> result = predictor.predict({
        ...     'developmental_milestones': 'N',
        ...     'iq_dq': 100,
        ...     'intellectual_disability': 'N',
        ...     'language_disorder': 'N',
        ...     'language_development': 'N',
        ...     'dysmorphism': 'NO',
        ...     'behaviour_disorder': 'N',
        ...     'neurological_exam': 'N'
        ... })
        >>> print(result['prediction'])  # 'Healthy' or 'ASD'
    """

    def __init__(self, model_dir='.', device='cpu'):
        """Load the model, preprocessor and config from ``model_dir``.

        Args:
            model_dir (str | Path): Directory holding the saved artifacts.
            device (str | torch.device): Device passed to ``load_model`` and
                used for the input tensors in ``predict``.

        Note:
            ``joblib.load`` is pickle-based and can execute arbitrary code
            while loading. Only point ``model_dir`` at trusted files.
        """
        model_dir = Path(model_dir)
        self.device = device
        # Load model
        self.model = load_model(model_dir / 'autism_detector.pth', device)
        # Load preprocessor.
        # SECURITY: pickle-based load; never use with untrusted files.
        self.preprocessor = joblib.load(model_dir / 'preprocessor.joblib')
        # Load config. JSON is UTF-8 by spec, so don't depend on the locale.
        with open(model_dir / 'config.json', 'r', encoding='utf-8') as f:
            self.config = json.load(f)

    def _convert_simple_to_original(self, data):
        """Convert simplified feature names to original column names.

        For a dict input, each feature is looked up first under its simplified
        name and then under its original column name. The simplified key wins
        when both are present. Keys matching neither name are dropped, and
        features absent under both names are omitted from the result. Any
        non-dict input is returned unchanged.
        """
        if not isinstance(data, dict):
            return data
        converted = {}
        for simple, original in zip(SIMPLE_NAMES, ORIGINAL_COLUMN_NAMES):
            if simple in data:
                converted[original] = data[simple]
            elif original in data:
                converted[original] = data[original]
        return converted

    def predict(self, data, threshold=0.5):
        """
        Make a prediction for a single patient.

        Args:
            data (dict): Patient features using simplified names (original
                column names are also accepted):
                - developmental_milestones: N/G/M/C
                - iq_dq: numeric (e.g., 100)
                - intellectual_disability: N/F70.0/F71/F72
                - language_disorder: N/Y
                - language_development: N/delay/A
                - dysmorphism: NO/Y
                - behaviour_disorder: N/Y
                - neurological_exam: N or abnormal description
                Non-dict inputs are passed to ``pd.DataFrame([data])`` as-is.
            threshold (float): Decision threshold in [0, 1]. A patient is
                labelled 'ASD' when the ASD probability is strictly greater
                than it. Defaults to 0.5.

        Returns:
            dict: {
                'prediction': 'Healthy' or 'ASD',
                'label': 1 for 'ASD', 0 for 'Healthy',
                'probability_asd': float,
                'probability_healthy': float,
                'confidence': float (the larger of the two probabilities)
            }

        Raises:
            ValueError: If ``threshold`` is not within [0, 1].
        """
        if not 0.0 <= threshold <= 1.0:
            raise ValueError(f"threshold must be in [0, 1], got {threshold!r}")

        # Convert to original column names; one row per call.
        converted = self._convert_simple_to_original(data)
        df = pd.DataFrame([converted])

        # Preprocess
        X = self.preprocessor.transform(df)
        X_tensor = torch.FloatTensor(X).to(self.device)

        # Predict. .item() requires a single-element output, i.e. one patient.
        # Assumes the model head already emits a probability (e.g. a sigmoid);
        # confirm in model.py.
        self.model.eval()
        with torch.no_grad():
            prob_asd = self.model(X_tensor).cpu().item()

        is_asd = prob_asd > threshold
        return {
            'prediction': 'ASD' if is_asd else 'Healthy',
            'label': 1 if is_asd else 0,
            'probability_asd': prob_asd,
            'probability_healthy': 1 - prob_asd,
            'confidence': max(prob_asd, 1 - prob_asd)
        }

    @staticmethod
    def get_feature_info():
        """Return information about required features (``model.FEATURES``)."""
        return FEATURES
def _print_banner(title, newline_before=False):
    """Print ``title`` framed by two 60-character ``=`` rules."""
    print(("\n" if newline_before else "") + "=" * 60)
    print(title)
    print("=" * 60)


def _run_example(predictor, title, patient):
    """Print ``patient``'s inputs, predict with ``predictor`` and print the outcome.

    Returns the result dict from ``predictor.predict`` so callers can reuse it.
    """
    print(f"\n--- {title} ---")
    print("Input:")
    for k, v in patient.items():
        print(f" {k}: {v}")
    result = predictor.predict(patient)
    print(f"\nResult: {result['prediction']}")
    print(f" Probability ASD: {result['probability_asd']:.2%}")
    print(f" Confidence: {result['confidence']:.2%}")
    return result


def _print_feature_reference(features):
    """Print each feature's description and its allowed values.

    ``features`` maps a feature name to a dict with 'description' and
    'values' keys. 'values' is either a dict (code -> meaning) or any other
    printable object.
    """
    _print_banner("FEATURE REFERENCE", newline_before=True)
    for name, info in features.items():
        print(f"\n{name}:")
        print(f" {info['description']}")
        values = info['values']
        if isinstance(values, dict):
            for k, v in values.items():
                print(f" '{k}' = {v}")
        else:
            print(f" {values}")


def main():
    """Example usage: run two sample predictions, then print the feature reference."""
    _print_banner("ASD Detector - Simplified 8-Feature Model")
    predictor = ASDPredictor('.')

    # Example 1: Healthy child profile
    healthy_child = {
        'developmental_milestones': 'N',  # Normal
        'iq_dq': 105,  # Normal IQ
        'intellectual_disability': 'N',  # None
        'language_disorder': 'N',  # No
        'language_development': 'N',  # Normal
        'dysmorphism': 'NO',  # Absent
        'behaviour_disorder': 'N',  # No
        'neurological_exam': 'N'  # Normal
    }
    _run_example(predictor, "Example 1: Healthy Child", healthy_child)

    # Example 2: Child with developmental concerns
    concerning_child = {
        'developmental_milestones': 'G',  # Global delay
        'iq_dq': 55,  # Below average
        'intellectual_disability': 'F70.0',  # Mild
        'language_disorder': 'Y',  # Yes
        'language_development': 'delay',  # Delayed
        'dysmorphism': 'NO',  # Absent
        'behaviour_disorder': 'Y',  # Yes
        'neurological_exam': 'N'  # Normal
    }
    _run_example(predictor, "Example 2: Child with Developmental Concerns",
                 concerning_child)

    # Print feature reference
    _print_feature_reference(FEATURES)


if __name__ == '__main__':
    main()
|