toderian commited on
Commit
f61ae9f
·
verified ·
1 Parent(s): a517fc1

Add inference.py

Browse files
Files changed (1) hide show
  1. inference.py +189 -0
inference.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Inference script for Autism Spectrum Disorder (ASD) Detector
3
+
4
+ Example usage of the trained model for making predictions.
5
+ """
6
+
7
+ import json
8
+ import torch
9
+ import joblib
10
+ import pandas as pd
11
+ import numpy as np
12
+ from pathlib import Path
13
+ from model import AutismDetectorNet, load_model
14
+
15
+
16
+ class ASDPredictor:
17
+ """
18
+ Wrapper class for easy ASD prediction.
19
+
20
+ Example:
21
+ >>> predictor = ASDPredictor('.')
22
+ >>> result = predictor.predict(patient_data)
23
+ >>> print(result)
24
+ """
25
+
26
+ def __init__(self, model_dir='.', device='cpu'):
27
+ """
28
+ Initialize the predictor.
29
+
30
+ Args:
31
+ model_dir (str): Directory containing model files
32
+ device (str): Device for inference ('cpu' or 'cuda')
33
+ """
34
+ model_dir = Path(model_dir)
35
+ self.device = device
36
+
37
+ # Load model
38
+ self.model = load_model(model_dir / 'autism_detector.pth', device)
39
+
40
+ # Load preprocessor
41
+ self.preprocessor = joblib.load(model_dir / 'preprocessor.joblib')
42
+
43
+ # Load config
44
+ with open(model_dir / 'config.json', 'r') as f:
45
+ self.config = json.load(f)
46
+
47
+ # Load feature info
48
+ with open(model_dir / 'feature_info.json', 'r') as f:
49
+ self.feature_info = json.load(f)
50
+
51
+ self.feature_columns = self.config['feature_columns']
52
+
53
+ def predict(self, data, return_proba=False):
54
+ """
55
+ Make predictions on patient data.
56
+
57
+ Args:
58
+ data (pd.DataFrame or dict): Patient data with required features
59
+ return_proba (bool): If True, return probabilities instead of labels
60
+
61
+ Returns:
62
+ dict: Prediction results including label and probability
63
+ """
64
+ # Convert dict to DataFrame if necessary
65
+ if isinstance(data, dict):
66
+ data = pd.DataFrame([data])
67
+
68
+ # Ensure all required columns are present
69
+ missing_cols = set(self.feature_columns) - set(data.columns)
70
+ if missing_cols:
71
+ raise ValueError(f"Missing required columns: {missing_cols}")
72
+
73
+ # Select only required columns in correct order
74
+ data = data[self.feature_columns]
75
+
76
+ # Preprocess
77
+ X = self.preprocessor.transform(data)
78
+ X_tensor = torch.FloatTensor(X).to(self.device)
79
+
80
+ # Predict
81
+ self.model.eval()
82
+ with torch.no_grad():
83
+ probabilities = self.model(X_tensor).cpu().numpy().flatten()
84
+
85
+ # Format results
86
+ results = []
87
+ for prob in probabilities:
88
+ result = {
89
+ 'prediction': 'ASD' if prob > 0.5 else 'Healthy',
90
+ 'label': 1 if prob > 0.5 else 0,
91
+ 'probability_asd': float(prob),
92
+ 'probability_healthy': float(1 - prob),
93
+ 'confidence': float(max(prob, 1 - prob))
94
+ }
95
+ results.append(result)
96
+
97
+ return results if len(results) > 1 else results[0]
98
+
99
+ def get_required_features(self):
100
+ """Return list of required feature columns."""
101
+ return self.feature_columns.copy()
102
+
103
+ def get_feature_info(self):
104
+ """Return detailed feature information."""
105
+ return self.feature_info.copy()
106
+
107
+
108
+ def create_sample_patient():
109
+ """Create a sample patient record for testing."""
110
+ return {
111
+ 'Gender': 'M',
112
+ 'Age': 48,
113
+ 'Urban/Rural': 'U',
114
+ 'Pregnancy natural /IVF': 'N',
115
+ 'Single/twins': 'S',
116
+ 'Pregnancy evolution, N=normal, AN=abnormal': 'N',
117
+ 'Birth weeks': 39,
118
+ 'Type of birth': 'N',
119
+ 'Birth Weight g': 3200,
120
+ 'Length at birth ': 50,
121
+ 'Head circumference at birth ': 35,
122
+ 'APGAR score': 9,
123
+ 'Postnatal adaptation N=normal, AN=abnormal': 'N',
124
+ 'Developmental milestones- global delay (G), motor delay (M), cognitive delay (C)': 'N',
125
+ 'Other Chronic diseases': 'N',
126
+ 'Infections': 0,
127
+ 'allergies': 0,
128
+ 'Family history psychiatric (P) or Neuro disease (Ne), No=absent': 'NO',
129
+ 'Mother age (years)': 30,
130
+ 'Father age (years)': 32,
131
+ 'IQ/DQ': 100,
132
+ 'ICD': 'N',
133
+ 'Neurological Examination; N=normal, text = abnormal; free cell = examination not performed ???': 'N',
134
+ 'Weight kg': 15,
135
+ 'Height ': 100,
136
+ 'head circumf ': 48,
137
+ 'Dysmorphysm y=present, no=absent': 'NO',
138
+ 'malformations Y= present, N=absent': 'N',
139
+ 'Behaviour disorder- agressivity, agitation, irascibility': 'N',
140
+ 'Language development: delay, normal=N, absent=A': 'N',
141
+ 'Language disorder Y= present, N=absent': 'N',
142
+ 'EEG, N=normal, F=focal discharges, G=bilateral discharges': 'N',
143
+ 'MRI structural anomalies of the brain, N=absent, AN=present': 'N'
144
+ }
145
+
146
+
147
+ def main():
148
+ """Example usage of the ASD predictor."""
149
+ print("=" * 60)
150
+ print("ASD Detector - Inference Example")
151
+ print("=" * 60)
152
+
153
+ # Initialize predictor
154
+ predictor = ASDPredictor(model_dir='.')
155
+
156
+ # Create sample patient
157
+ print("\nSample Patient (Healthy profile):")
158
+ patient = create_sample_patient()
159
+ for key, value in list(patient.items())[:10]:
160
+ print(f" {key}: {value}")
161
+ print(" ...")
162
+
163
+ # Make prediction
164
+ print("\nPrediction:")
165
+ result = predictor.predict(patient)
166
+ print(f" Label: {result['prediction']}")
167
+ print(f" Probability (ASD): {result['probability_asd']:.4f}")
168
+ print(f" Probability (Healthy): {result['probability_healthy']:.4f}")
169
+ print(f" Confidence: {result['confidence']:.4f}")
170
+
171
+ # Test with ASD-like profile
172
+ print("\n" + "-" * 60)
173
+ print("\nSample Patient (ASD-like profile):")
174
+ patient_asd = create_sample_patient()
175
+ patient_asd['Developmental milestones- global delay (G), motor delay (M), cognitive delay (C)'] = 'G'
176
+ patient_asd['Language development: delay, normal=N, absent=A'] = 'delay'
177
+ patient_asd['IQ/DQ'] = 45
178
+ patient_asd['Behaviour disorder- agressivity, agitation, irascibility'] = 'Y'
179
+
180
+ result_asd = predictor.predict(patient_asd)
181
+ print(f" Changed features: Developmental=G, Language=delay, IQ=45, Behaviour=Y")
182
+ print(f"\nPrediction:")
183
+ print(f" Label: {result_asd['prediction']}")
184
+ print(f" Probability (ASD): {result_asd['probability_asd']:.4f}")
185
+ print(f" Confidence: {result_asd['confidence']:.4f}")
186
+
187
+
188
+ if __name__ == '__main__':
189
+ main()