toderian committed on
Commit
624473d
·
verified ·
1 Parent(s): 0d5e4d7

Add inference.py

Browse files
Files changed (1) hide show
  1. inference.py +151 -134
inference.py CHANGED
@@ -1,36 +1,66 @@
1
  """
2
- Inference script for Autism Spectrum Disorder (ASD) Detector
3
 
4
- Example usage of the trained model for making predictions.
5
  """
6
 
 
7
  import json
8
  import torch
9
  import joblib
10
  import pandas as pd
11
- import numpy as np
12
  from pathlib import Path
13
- from model import AutismDetectorNet, load_model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
 
16
  class ASDPredictor:
17
  """
18
- Wrapper class for easy ASD prediction.
19
 
20
  Example:
21
  >>> predictor = ASDPredictor('.')
22
- >>> result = predictor.predict(patient_data)
23
- >>> print(result)
 
 
 
 
 
 
 
 
 
24
  """
25
 
26
  def __init__(self, model_dir='.', device='cpu'):
27
- """
28
- Initialize the predictor.
29
-
30
- Args:
31
- model_dir (str): Directory containing model files
32
- device (str): Device for inference ('cpu' or 'cuda')
33
- """
34
  model_dir = Path(model_dir)
35
  self.device = device
36
 
@@ -44,145 +74,132 @@ class ASDPredictor:
44
  with open(model_dir / 'config.json', 'r') as f:
45
  self.config = json.load(f)
46
 
47
- # Load feature info
48
- with open(model_dir / 'feature_info.json', 'r') as f:
49
- self.feature_info = json.load(f)
50
-
51
- self.feature_columns = self.config['feature_columns']
52
-
53
- def predict(self, data, return_proba=False):
 
 
 
 
 
 
54
  """
55
- Make predictions on patient data.
56
 
57
  Args:
58
- data (pd.DataFrame or dict): Patient data with required features
59
- return_proba (bool): If True, return probabilities instead of labels
 
 
 
 
 
 
 
60
 
61
  Returns:
62
- dict: Prediction results including label and probability
 
 
 
 
 
63
  """
64
- # Convert dict to DataFrame if necessary
65
- if isinstance(data, dict):
66
- data = pd.DataFrame([data])
67
-
68
- # Ensure all required columns are present
69
- missing_cols = set(self.feature_columns) - set(data.columns)
70
- if missing_cols:
71
- raise ValueError(f"Missing required columns: {missing_cols}")
72
-
73
- # Select only required columns in correct order
74
- data = data[self.feature_columns]
75
 
76
  # Preprocess
77
- X = self.preprocessor.transform(data)
78
  X_tensor = torch.FloatTensor(X).to(self.device)
79
 
80
  # Predict
81
  self.model.eval()
82
  with torch.no_grad():
83
- probabilities = self.model(X_tensor).cpu().numpy().flatten()
84
-
85
- # Format results
86
- results = []
87
- for prob in probabilities:
88
- result = {
89
- 'prediction': 'ASD' if prob > 0.5 else 'Healthy',
90
- 'label': 1 if prob > 0.5 else 0,
91
- 'probability_asd': float(prob),
92
- 'probability_healthy': float(1 - prob),
93
- 'confidence': float(max(prob, 1 - prob))
94
- }
95
- results.append(result)
96
-
97
- return results if len(results) > 1 else results[0]
98
-
99
- def get_required_features(self):
100
- """Return list of required feature columns."""
101
- return self.feature_columns.copy()
102
-
103
- def get_feature_info(self):
104
- """Return detailed feature information."""
105
- return self.feature_info.copy()
106
-
107
-
108
- def create_sample_patient():
109
- """Create a sample patient record for testing."""
110
- return {
111
- 'Gender': 'M',
112
- 'Age': 48,
113
- 'Urban/Rural': 'U',
114
- 'Pregnancy natural /IVF': 'N',
115
- 'Single/twins': 'S',
116
- 'Pregnancy evolution, N=normal, AN=abnormal': 'N',
117
- 'Birth weeks': 39,
118
- 'Type of birth': 'N',
119
- 'Birth Weight g': 3200,
120
- 'Length at birth ': 50,
121
- 'Head circumference at birth ': 35,
122
- 'APGAR score': 9,
123
- 'Postnatal adaptation N=normal, AN=abnormal': 'N',
124
- 'Developmental milestones- global delay (G), motor delay (M), cognitive delay (C)': 'N',
125
- 'Other Chronic diseases': 'N',
126
- 'Infections': 0,
127
- 'allergies': 0,
128
- 'Family history psychiatric (P) or Neuro disease (Ne), No=absent': 'NO',
129
- 'Mother age (years)': 30,
130
- 'Father age (years)': 32,
131
- 'IQ/DQ': 100,
132
- 'ICD': 'N',
133
- 'Neurological Examination; N=normal, text = abnormal; free cell = examination not performed ???': 'N',
134
- 'Weight kg': 15,
135
- 'Height ': 100,
136
- 'head circumf ': 48,
137
- 'Dysmorphysm y=present, no=absent': 'NO',
138
- 'malformations Y= present, N=absent': 'N',
139
- 'Behaviour disorder- agressivity, agitation, irascibility': 'N',
140
- 'Language development: delay, normal=N, absent=A': 'N',
141
- 'Language disorder Y= present, N=absent': 'N',
142
- 'EEG, N=normal, F=focal discharges, G=bilateral discharges': 'N',
143
- 'MRI structural anomalies of the brain, N=absent, AN=present': 'N'
144
- }
145
 
146
 
147
  def main():
148
- """Example usage of the ASD predictor."""
149
  print("=" * 60)
150
- print("ASD Detector - Inference Example")
151
  print("=" * 60)
152
 
153
- # Initialize predictor
154
- predictor = ASDPredictor(model_dir='.')
155
-
156
- # Create sample patient
157
- print("\nSample Patient (Healthy profile):")
158
- patient = create_sample_patient()
159
- for key, value in list(patient.items())[:10]:
160
- print(f" {key}: {value}")
161
- print(" ...")
162
-
163
- # Make prediction
164
- print("\nPrediction:")
165
- result = predictor.predict(patient)
166
- print(f" Label: {result['prediction']}")
167
- print(f" Probability (ASD): {result['probability_asd']:.4f}")
168
- print(f" Probability (Healthy): {result['probability_healthy']:.4f}")
169
- print(f" Confidence: {result['confidence']:.4f}")
170
-
171
- # Test with ASD-like profile
172
- print("\n" + "-" * 60)
173
- print("\nSample Patient (ASD-like profile):")
174
- patient_asd = create_sample_patient()
175
- patient_asd['Developmental milestones- global delay (G), motor delay (M), cognitive delay (C)'] = 'G'
176
- patient_asd['Language development: delay, normal=N, absent=A'] = 'delay'
177
- patient_asd['IQ/DQ'] = 45
178
- patient_asd['Behaviour disorder- agressivity, agitation, irascibility'] = 'Y'
179
-
180
- result_asd = predictor.predict(patient_asd)
181
- print(f" Changed features: Developmental=G, Language=delay, IQ=45, Behaviour=Y")
182
- print(f"\nPrediction:")
183
- print(f" Label: {result_asd['prediction']}")
184
- print(f" Probability (ASD): {result_asd['probability_asd']:.4f}")
185
- print(f" Confidence: {result_asd['confidence']:.4f}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
 
187
 
188
  if __name__ == '__main__':
 
1
  """
2
+ Inference script for Simplified ASD Detector (8 features)
3
 
4
+ Example usage for making predictions with the trained model.
5
  """
6
 
7
+ import sys
8
  import json
9
  import torch
10
  import joblib
11
  import pandas as pd
 
12
  from pathlib import Path
13
+ from model import SimplifiedASDDetector, load_model, FEATURES, SimplePreprocessor
14
+
15
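+ # Fix for unpickling the preprocessor: pickle resolves classes by module name, and the saved preprocessor presumably references __main__.SimplePreprocessor, so expose the class there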
16
+ sys.modules['__main__'].SimplePreprocessor = SimplePreprocessor
17
+
18
+
19
+ # Original column names (as in the training data)
20
+ ORIGINAL_COLUMN_NAMES = [
21
+ 'Developmental milestones- global delay (G), motor delay (M), cognitive delay (C)',
22
+ 'IQ/DQ',
23
+ 'ICD',
24
+ 'Language disorder Y= present, N=absent',
25
+ 'Language development: delay, normal=N, absent=A',
26
+ 'Dysmorphysm y=present, no=absent',
27
+ 'Behaviour disorder- agressivity, agitation, irascibility',
28
+ 'Neurological Examination; N=normal, text = abnormal; free cell = examination not performed ???'
29
+ ]
30
+
31
+ # Simplified names for user-friendly input
32
+ SIMPLE_NAMES = [
33
+ 'developmental_milestones',
34
+ 'iq_dq',
35
+ 'intellectual_disability',
36
+ 'language_disorder',
37
+ 'language_development',
38
+ 'dysmorphism',
39
+ 'behaviour_disorder',
40
+ 'neurological_exam'
41
+ ]
42
 
43
 
44
  class ASDPredictor:
45
  """
46
+ Easy-to-use predictor for ASD detection.
47
 
48
  Example:
49
  >>> predictor = ASDPredictor('.')
50
+ >>> result = predictor.predict({
51
+ ... 'developmental_milestones': 'N',
52
+ ... 'iq_dq': 100,
53
+ ... 'intellectual_disability': 'N',
54
+ ... 'language_disorder': 'N',
55
+ ... 'language_development': 'N',
56
+ ... 'dysmorphism': 'NO',
57
+ ... 'behaviour_disorder': 'N',
58
+ ... 'neurological_exam': 'N'
59
+ ... })
60
+ >>> print(result['prediction']) # 'Healthy' or 'ASD'
61
  """
62
 
63
  def __init__(self, model_dir='.', device='cpu'):
 
 
 
 
 
 
 
64
  model_dir = Path(model_dir)
65
  self.device = device
66
 
 
74
  with open(model_dir / 'config.json', 'r') as f:
75
  self.config = json.load(f)
76
 
77
+ def _convert_simple_to_original(self, data):
78
+ """Convert simplified feature names to original column names."""
79
+ if isinstance(data, dict):
80
+ converted = {}
81
+ for simple, original in zip(SIMPLE_NAMES, ORIGINAL_COLUMN_NAMES):
82
+ if simple in data:
83
+ converted[original] = data[simple]
84
+ elif original in data:
85
+ converted[original] = data[original]
86
+ return converted
87
+ return data
88
+
89
+ def predict(self, data):
90
  """
91
+ Make prediction on patient data.
92
 
93
  Args:
94
+ data (dict): Patient features using simplified names:
95
+ - developmental_milestones: N/G/M/C
96
+ - iq_dq: numeric (e.g., 100)
97
+ - intellectual_disability: N/F70.0/F71/F72
98
+ - language_disorder: N/Y
99
+ - language_development: N/delay/A
100
+ - dysmorphism: NO/Y
101
+ - behaviour_disorder: N/Y
102
+ - neurological_exam: N or abnormal description
103
 
104
  Returns:
105
+ dict: {
106
+ 'prediction': 'Healthy' or 'ASD',
107
+ 'probability_asd': float,
108
+ 'probability_healthy': float,
109
+ 'confidence': float
110
+ }
111
  """
112
+ # Convert to original column names
113
+ converted = self._convert_simple_to_original(data)
114
+ df = pd.DataFrame([converted])
 
 
 
 
 
 
 
 
115
 
116
  # Preprocess
117
+ X = self.preprocessor.transform(df)
118
  X_tensor = torch.FloatTensor(X).to(self.device)
119
 
120
  # Predict
121
  self.model.eval()
122
  with torch.no_grad():
123
+ prob_asd = self.model(X_tensor).cpu().item()
124
+
125
+ return {
126
+ 'prediction': 'ASD' if prob_asd > 0.5 else 'Healthy',
127
+ 'label': 1 if prob_asd > 0.5 else 0,
128
+ 'probability_asd': prob_asd,
129
+ 'probability_healthy': 1 - prob_asd,
130
+ 'confidence': max(prob_asd, 1 - prob_asd)
131
+ }
132
+
133
+ @staticmethod
134
+ def get_feature_info():
135
+ """Return information about required features."""
136
+ return FEATURES
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
 
139
  def main():
140
+ """Example usage."""
141
  print("=" * 60)
142
+ print("ASD Detector - Simplified 8-Feature Model")
143
  print("=" * 60)
144
 
145
+ predictor = ASDPredictor('.')
146
+
147
+ # Example 1: Healthy child profile
148
+ print("\n--- Example 1: Healthy Child ---")
149
+ healthy_child = {
150
+ 'developmental_milestones': 'N', # Normal
151
+ 'iq_dq': 105, # Normal IQ
152
+ 'intellectual_disability': 'N', # None
153
+ 'language_disorder': 'N', # No
154
+ 'language_development': 'N', # Normal
155
+ 'dysmorphism': 'NO', # Absent
156
+ 'behaviour_disorder': 'N', # No
157
+ 'neurological_exam': 'N' # Normal
158
+ }
159
+
160
+ print("Input:")
161
+ for k, v in healthy_child.items():
162
+ print(f" {k}: {v}")
163
+
164
+ result = predictor.predict(healthy_child)
165
+ print(f"\nResult: {result['prediction']}")
166
+ print(f" Probability ASD: {result['probability_asd']:.2%}")
167
+ print(f" Confidence: {result['confidence']:.2%}")
168
+
169
+ # Example 2: Child with developmental concerns
170
+ print("\n--- Example 2: Child with Developmental Concerns ---")
171
+ concerning_child = {
172
+ 'developmental_milestones': 'G', # Global delay
173
+ 'iq_dq': 55, # Below average
174
+ 'intellectual_disability': 'F70.0', # Mild
175
+ 'language_disorder': 'Y', # Yes
176
+ 'language_development': 'delay', # Delayed
177
+ 'dysmorphism': 'NO', # Absent
178
+ 'behaviour_disorder': 'Y', # Yes
179
+ 'neurological_exam': 'N' # Normal
180
+ }
181
+
182
+ print("Input:")
183
+ for k, v in concerning_child.items():
184
+ print(f" {k}: {v}")
185
+
186
+ result = predictor.predict(concerning_child)
187
+ print(f"\nResult: {result['prediction']}")
188
+ print(f" Probability ASD: {result['probability_asd']:.2%}")
189
+ print(f" Confidence: {result['confidence']:.2%}")
190
+
191
+ # Print feature reference
192
+ print("\n" + "=" * 60)
193
+ print("FEATURE REFERENCE")
194
+ print("=" * 60)
195
+ for name, info in FEATURES.items():
196
+ print(f"\n{name}:")
197
+ print(f" {info['description']}")
198
+ if isinstance(info['values'], dict):
199
+ for k, v in info['values'].items():
200
+ print(f" '{k}' = {v}")
201
+ else:
202
+ print(f" {info['values']}")
203
 
204
 
205
  if __name__ == '__main__':