|
|
""" |
|
|
Inference script for Simplified ASD Detector (8 features) |
|
|
|
|
|
Example usage for making predictions with the trained model. |
|
|
""" |
|
|
|
|
|
import sys |
|
|
import json |
|
|
import torch |
|
|
import joblib |
|
|
import pandas as pd |
|
|
from pathlib import Path |
|
|
from model import SimplifiedASDDetector, load_model, FEATURES, SimplePreprocessor |
|
|
|
|
|
|
|
|
# joblib (pickle) resolves a class by the module path recorded when it was
# saved. The preprocessor was presumably pickled from a training script run as
# ``__main__``, so publish SimplePreprocessor on whatever module is currently
# ``__main__`` so that joblib.load can find it.
setattr(sys.modules['__main__'], 'SimplePreprocessor', SimplePreprocessor)
|
|
|
|
|
|
|
|
|
|
|
# Single source of truth for the feature-name mapping: each entry pairs the
# friendly key accepted by ASDPredictor.predict with the corresponding raw
# column header. The headers appear to be copied verbatim from the training
# dataset (note 'Dysmorphysm', 'agressivity' and the trailing '???'); presumably
# the saved preprocessor looks columns up by these exact strings, so do NOT
# "fix" their spelling.
_FEATURE_NAME_PAIRS = (
    ('developmental_milestones',
     'Developmental milestones- global delay (G), motor delay (M), cognitive delay (C)'),
    ('iq_dq',
     'IQ/DQ'),
    ('intellectual_disability',
     'ICD'),
    ('language_disorder',
     'Language disorder Y= present, N=absent'),
    ('language_development',
     'Language development: delay, normal=N, absent=A'),
    ('dysmorphism',
     'Dysmorphysm y=present, no=absent'),
    ('behaviour_disorder',
     'Behaviour disorder- agressivity, agitation, irascibility'),
    ('neurological_exam',
     'Neurological Examination; N=normal, text = abnormal; free cell = examination not performed ???'),
)

# Raw column headers, in the same order as SIMPLE_NAMES. Deriving both lists
# from the pairs above guarantees they can never drift out of alignment.
ORIGINAL_COLUMN_NAMES = [original for _, original in _FEATURE_NAME_PAIRS]

# Simplified, code-friendly names; SIMPLE_NAMES[i] maps to
# ORIGINAL_COLUMN_NAMES[i].
SIMPLE_NAMES = [simple for simple, _ in _FEATURE_NAME_PAIRS]
|
|
|
|
|
|
|
|
class ASDPredictor:
    """
    Easy-to-use predictor for ASD detection.

    Loads three artifacts from ``model_dir``:

    * ``autism_detector.pth`` -- network weights, loaded via ``load_model``
    * ``preprocessor.joblib`` -- fitted feature preprocessor
    * ``config.json``         -- configuration, exposed as ``self.config``

    Example:
        >>> predictor = ASDPredictor('.')
        >>> result = predictor.predict({
        ...     'developmental_milestones': 'N',
        ...     'iq_dq': 100,
        ...     'intellectual_disability': 'N',
        ...     'language_disorder': 'N',
        ...     'language_development': 'N',
        ...     'dysmorphism': 'NO',
        ...     'behaviour_disorder': 'N',
        ...     'neurological_exam': 'N'
        ... })
        >>> print(result['prediction'])  # 'Healthy' or 'ASD'
    """

    def __init__(self, model_dir='.', device='cpu'):
        """
        Load the model, preprocessor and config from ``model_dir``.

        Args:
            model_dir (str | Path): Directory containing the three artifacts.
            device (str | torch.device): Passed to ``load_model`` and used as
                the device for inference input tensors.
        """
        model_dir = Path(model_dir)
        self.device = device

        # Trained network.
        self.model = load_model(model_dir / 'autism_detector.pth', device)

        # Fitted preprocessor. joblib is pickle-based, so only load artifacts
        # from a trusted source. Unpickling presumably depends on the
        # ``__main__.SimplePreprocessor`` alias registered at import time.
        self.preprocessor = joblib.load(model_dir / 'preprocessor.joblib')

        # Model configuration; kept for callers, not used by predict().
        with open(model_dir / 'config.json', 'r', encoding='utf-8') as f:
            self.config = json.load(f)

    def _convert_simple_to_original(self, data):
        """
        Rename simplified feature keys to the original column names.

        For a dict, each feature is looked up first under its simplified name
        and then under its original column name; the simplified name wins if
        both are present. Features present under neither name are omitted
        (how the preprocessor treats a missing column is not visible here).
        Any non-dict input is returned unchanged.
        """
        if isinstance(data, dict):
            converted = {}
            for simple, original in zip(SIMPLE_NAMES, ORIGINAL_COLUMN_NAMES):
                if simple in data:
                    converted[original] = data[simple]
                elif original in data:
                    converted[original] = data[original]
            return converted
        return data

    def predict(self, data, threshold=0.5):
        """
        Make a prediction for a single patient.

        Args:
            data (dict): Patient features using simplified names (original
                column names are also accepted):
                - developmental_milestones: N/G/M/C
                - iq_dq: numeric (e.g., 100)
                - intellectual_disability: N/F70.0/F71/F72
                - language_disorder: N/Y
                - language_development: N/delay/A
                - dysmorphism: NO/Y
                - behaviour_disorder: N/Y
                - neurological_exam: N or abnormal description
                Non-dict input is wrapped as ``pd.DataFrame([data])`` as is.
            threshold (float): The patient is classified as ASD when the
                model's ASD probability is strictly greater than this value.
                Defaults to 0.5.

        Returns:
            dict: {
                'prediction': 'Healthy' or 'ASD',
                'label': 1 for ASD, 0 for Healthy,
                'probability_asd': float,
                'probability_healthy': float,  # 1 - probability_asd
                'confidence': float  # max of the two probabilities
            }

        Raises:
            ValueError: If ``threshold`` is not within [0, 1].
        """
        # The negated range check also rejects NaN.
        if not 0.0 <= threshold <= 1.0:
            raise ValueError(f"threshold must be in [0, 1], got {threshold!r}")

        converted = self._convert_simple_to_original(data)
        df = pd.DataFrame([converted])

        # Preprocess, then build a float32 tensor directly on the target device.
        X = self.preprocessor.transform(df)
        X_tensor = torch.as_tensor(X, dtype=torch.float32, device=self.device)

        self.model.eval()
        with torch.no_grad():
            # .item() requires the output to hold exactly one value, i.e. a
            # single sample. The output is treated as P(ASD) -- presumably
            # the model ends in a sigmoid; confirm in model.py.
            prob_asd = self.model(X_tensor).cpu().item()

        is_asd = prob_asd > threshold
        return {
            'prediction': 'ASD' if is_asd else 'Healthy',
            'label': 1 if is_asd else 0,
            'probability_asd': prob_asd,
            'probability_healthy': 1 - prob_asd,
            'confidence': max(prob_asd, 1 - prob_asd)
        }

    @staticmethod
    def get_feature_info():
        """Return the ``FEATURES`` description object imported from ``model``."""
        return FEATURES
|
|
|
|
|
|
|
|
def _print_example(predictor, title, patient):
    """
    Print one example patient, run it through ``predictor`` and print the result.

    Args:
        predictor (ASDPredictor): Loaded predictor used for the prediction.
        title (str): Heading shown between ``---`` markers.
        patient (dict): Feature values keyed by simplified feature name.
    """
    print(f"\n--- {title} ---")
    print("Input:")
    for key, value in patient.items():
        print(f" {key}: {value}")

    result = predictor.predict(patient)
    print(f"\nResult: {result['prediction']}")
    print(f" Probability ASD: {result['probability_asd']:.2%}")
    print(f" Confidence: {result['confidence']:.2%}")


def main():
    """Example usage: run two sample patients and print the feature reference."""
    print("=" * 60)
    print("ASD Detector - Simplified 8-Feature Model")
    print("=" * 60)

    # NOTE(review): '.' resolves against the current working directory, so the
    # script must be run from the directory holding the model artifacts.
    predictor = ASDPredictor('.')

    examples = [
        ("Example 1: Healthy Child", {
            'developmental_milestones': 'N',
            'iq_dq': 105,
            'intellectual_disability': 'N',
            'language_disorder': 'N',
            'language_development': 'N',
            'dysmorphism': 'NO',
            'behaviour_disorder': 'N',
            'neurological_exam': 'N'
        }),
        ("Example 2: Child with Developmental Concerns", {
            'developmental_milestones': 'G',
            'iq_dq': 55,
            'intellectual_disability': 'F70.0',
            'language_disorder': 'Y',
            'language_development': 'delay',
            'dysmorphism': 'NO',
            'behaviour_disorder': 'Y',
            'neurological_exam': 'N'
        }),
    ]
    for title, patient in examples:
        _print_example(predictor, title, patient)

    # Reference of every feature: its description and allowed values.
    print("\n" + "=" * 60)
    print("FEATURE REFERENCE")
    print("=" * 60)
    for name, info in FEATURES.items():
        print(f"\n{name}:")
        print(f" {info['description']}")
        values = info['values']
        if isinstance(values, dict):
            # Categorical feature: print each code with its meaning.
            for code, meaning in values.items():
                print(f" '{code}' = {meaning}")
        else:
            print(f" {values}")


if __name__ == '__main__':
    main()
|
|
|