# autism-detector / inference.py
# Added by toderian in commit 624473d ("Add inference.py", verified)
"""
Inference script for Simplified ASD Detector (8 features)
Example usage for making predictions with the trained model.
"""
import json
import sys
import warnings
from pathlib import Path

import joblib
import pandas as pd
import torch

from model import SimplifiedASDDetector, load_model, FEATURES, SimplePreprocessor
# The pickled preprocessor was saved from a different module and presumably
# references its class as `__main__.SimplePreprocessor`. Expose the class from
# `model` under that name so joblib.load() in ASDPredictor can resolve it.
# This must run before any preprocessor is unpickled.
setattr(sys.modules['__main__'], 'SimplePreprocessor', SimplePreprocessor)
# Column headers exactly as they appear in the training data, in feature order.
# predict() uses them as DataFrame column names, and the fitted preprocessor
# presumably matches columns by these names. Keep them character-for-character
# identical to the training data, including quirks such as 'Dysmorphysm',
# 'agressivity' and the trailing '???'. Do not "correct" them.
ORIGINAL_COLUMN_NAMES = [
    'Developmental milestones- global delay (G), motor delay (M), cognitive delay (C)',
    'IQ/DQ',
    'ICD',
    'Language disorder Y= present, N=absent',
    'Language development: delay, normal=N, absent=A',
    'Dysmorphysm y=present, no=absent',
    'Behaviour disorder- agressivity, agitation, irascibility',
    'Neurological Examination; N=normal, text = abnormal; free cell = examination not performed ???',
]

# User-friendly input aliases, position-aligned with ORIGINAL_COLUMN_NAMES:
# SIMPLE_NAMES[i] is the alias for ORIGINAL_COLUMN_NAMES[i]. Both lists must
# keep the same length and order when a feature is added or removed.
SIMPLE_NAMES = [
    'developmental_milestones',  # Developmental milestones
    'iq_dq',                     # IQ/DQ
    'intellectual_disability',   # ICD
    'language_disorder',         # Language disorder
    'language_development',      # Language development
    'dysmorphism',               # Dysmorphysm
    'behaviour_disorder',        # Behaviour disorder
    'neurological_exam',         # Neurological Examination
]
class ASDPredictor:
    """
    Easy-to-use predictor for ASD detection.

    Loads three artifacts from ``model_dir``:
        - ``autism_detector.pth``: model weights (via ``model.load_model``)
        - ``preprocessor.joblib``: fitted preprocessor
        - ``config.json``: model configuration (stored on ``self.config``)

    Example:
        >>> predictor = ASDPredictor('.')
        >>> result = predictor.predict({
        ...     'developmental_milestones': 'N',
        ...     'iq_dq': 100,
        ...     'intellectual_disability': 'N',
        ...     'language_disorder': 'N',
        ...     'language_development': 'N',
        ...     'dysmorphism': 'NO',
        ...     'behaviour_disorder': 'N',
        ...     'neurological_exam': 'N'
        ... })
        >>> print(result['prediction'])  # 'Healthy' or 'ASD'
    """

    def __init__(self, model_dir='.', device='cpu'):
        """
        Load the model, preprocessor and config.

        Args:
            model_dir (str | Path): Directory containing the three artifacts.
            device (str): Torch device used for inference (e.g. 'cpu').
        """
        model_dir = Path(model_dir)
        self.device = device

        # Model weights; load_model is defined in model.py.
        self.model = load_model(model_dir / 'autism_detector.pth', device)

        # SECURITY: joblib.load unpickles the file, which can execute arbitrary
        # code. Only load preprocessor artifacts from a trusted source.
        self.preprocessor = joblib.load(model_dir / 'preprocessor.joblib')

        with open(model_dir / 'config.json', 'r', encoding='utf-8') as f:
            self.config = json.load(f)

    def _convert_simple_to_original(self, data):
        """
        Map simplified feature names to the original training column names.

        For every feature, the simplified key takes precedence when both it
        and the original column name are present. Features missing from
        ``data`` are left out of the result. Non-dict input is returned
        unchanged.

        Keys that match neither naming scheme are dropped. A ``UserWarning``
        is emitted for them so that a typo (e.g. ``'iq'`` instead of
        ``'iq_dq'``) does not silently remove a feature from the prediction.
        """
        if not isinstance(data, dict):
            return data

        converted = {}
        for simple, original in zip(SIMPLE_NAMES, ORIGINAL_COLUMN_NAMES):
            if simple in data:
                converted[original] = data[simple]
            elif original in data:
                converted[original] = data[original]

        unknown = set(data) - set(SIMPLE_NAMES) - set(ORIGINAL_COLUMN_NAMES)
        if unknown:
            # stacklevel=3 points the warning at the caller of predict().
            warnings.warn(
                f"Ignoring unrecognized feature keys: {sorted(unknown, key=str)}",
                stacklevel=3,
            )
        return converted

    def predict(self, data, *, threshold=0.5):
        """
        Make a prediction for one patient.

        Args:
            data (dict): Patient features keyed by simplified names (the
                original training column names are accepted too):
                - developmental_milestones: N/G/M/C
                - iq_dq: numeric (e.g., 100)
                - intellectual_disability: N/F70.0/F71/F72
                - language_disorder: N/Y
                - language_development: N/delay/A
                - dysmorphism: NO/Y
                - behaviour_disorder: N/Y
                - neurological_exam: N or abnormal description
            threshold (float): The patient is classified as ASD when the
                model's ASD probability is strictly greater than this value.
                Defaults to 0.5.

        Returns:
            dict: {
                'prediction': 'ASD' or 'Healthy',
                'label': 1 for ASD, 0 for Healthy,
                'probability_asd': float,
                'probability_healthy': float (1 - probability_asd),
                'confidence': float (the larger of the two probabilities)
            }
        """
        converted = self._convert_simple_to_original(data)
        df = pd.DataFrame([converted])

        X = self.preprocessor.transform(df)
        # as_tensor replaces the legacy FloatTensor constructor and builds the
        # tensor directly on the target device.
        X_tensor = torch.as_tensor(X, dtype=torch.float32, device=self.device)

        self.model.eval()
        with torch.no_grad():
            # .item() requires the model output for this one-row batch to be a
            # single element.
            prob_asd = self.model(X_tensor).cpu().item()

        is_asd = prob_asd > threshold
        return {
            'prediction': 'ASD' if is_asd else 'Healthy',
            'label': 1 if is_asd else 0,
            'probability_asd': prob_asd,
            'probability_healthy': 1 - prob_asd,
            'confidence': max(prob_asd, 1 - prob_asd)
        }

    @staticmethod
    def get_feature_info():
        """Return the FEATURES reference mapping from model.py."""
        return FEATURES
def _print_example(predictor, title, profile):
    """Print one example profile and the predictor's result for it."""
    print(f"\n--- {title} ---")
    print("Input:")
    for key, value in profile.items():
        print(f" {key}: {value}")
    result = predictor.predict(profile)
    print(f"\nResult: {result['prediction']}")
    print(f" Probability ASD: {result['probability_asd']:.2%}")
    print(f" Confidence: {result['confidence']:.2%}")


def main():
    """Run the predictor on two example profiles and print a feature reference."""
    print("=" * 60)
    print("ASD Detector - Simplified 8-Feature Model")
    print("=" * 60)

    # Load the artifacts from this script's directory rather than the current
    # working directory, so the demo also works when launched from elsewhere.
    # When run as a script, Python likewise puts this directory on sys.path
    # for `from model import ...`.
    predictor = ASDPredictor(Path(__file__).resolve().parent)

    # Example 1: Healthy child profile
    healthy_child = {
        'developmental_milestones': 'N',  # Normal
        'iq_dq': 105,  # Normal IQ
        'intellectual_disability': 'N',  # None
        'language_disorder': 'N',  # No
        'language_development': 'N',  # Normal
        'dysmorphism': 'NO',  # Absent
        'behaviour_disorder': 'N',  # No
        'neurological_exam': 'N'  # Normal
    }
    # Example 2: Child with developmental concerns
    concerning_child = {
        'developmental_milestones': 'G',  # Global delay
        'iq_dq': 55,  # Below average
        'intellectual_disability': 'F70.0',  # Mild
        'language_disorder': 'Y',  # Yes
        'language_development': 'delay',  # Delayed
        'dysmorphism': 'NO',  # Absent
        'behaviour_disorder': 'Y',  # Yes
        'neurological_exam': 'N'  # Normal
    }
    examples = [
        ("Example 1: Healthy Child", healthy_child),
        ("Example 2: Child with Developmental Concerns", concerning_child),
    ]
    for title, profile in examples:
        _print_example(predictor, title, profile)

    # Print feature reference
    print("\n" + "=" * 60)
    print("FEATURE REFERENCE")
    print("=" * 60)
    for name, info in FEATURES.items():
        print(f"\n{name}:")
        print(f" {info['description']}")
        if isinstance(info['values'], dict):
            for k, v in info['values'].items():
                print(f" '{k}' = {v}")
        else:
            print(f" {info['values']}")


if __name__ == '__main__':
    main()