toderian commited on
Commit
9c36fea
·
verified ·
1 Parent(s): 1562ea7

Add model.py

Browse files
Files changed (1) hide show
  1. model.py +135 -151
model.py CHANGED
@@ -1,98 +1,37 @@
1
  """
2
- Simplified Autism Spectrum Disorder (ASD) Detector Model
3
 
4
- 8-feature model capturing 84% of predictive power.
 
5
  """
6
 
7
  import torch
8
  import torch.nn as nn
9
- import pandas as pd
10
- import numpy as np
11
- from sklearn.preprocessing import StandardScaler, LabelEncoder
12
 
13
 
14
- # Original column names used in training
15
- SELECTED_FEATURES = [
16
- 'Developmental milestones- global delay (G), motor delay (M), cognitive delay (C)',
17
- 'IQ/DQ',
18
- 'ICD',
19
- 'Language disorder Y= present, N=absent',
20
- 'Language development: delay, normal=N, absent=A',
21
- 'Dysmorphysm y=present, no=absent',
22
- 'Behaviour disorder- agressivity, agitation, irascibility',
23
- 'Neurological Examination; N=normal, text = abnormal; free cell = examination not performed ???'
24
- ]
25
-
26
-
27
- class SimplePreprocessor:
28
- """Preprocessor for the 8 selected features."""
29
-
30
- def __init__(self):
31
- self.label_encoders = {}
32
- self.scaler = StandardScaler()
33
- self.numeric_cols = ['IQ/DQ']
34
- self.categorical_cols = [f for f in SELECTED_FEATURES if f != 'IQ/DQ']
35
-
36
- def fit(self, X):
37
- X = X.copy()
38
- X['IQ/DQ'] = pd.to_numeric(X['IQ/DQ'], errors='coerce')
39
-
40
- for col in self.categorical_cols:
41
- X[col] = X[col].fillna('_missing_').astype(str)
42
- all_values = list(X[col].unique()) + ['_missing_', '_unknown_']
43
- self.label_encoders[col] = LabelEncoder()
44
- self.label_encoders[col].fit(all_values)
45
-
46
- X_encoded = self._encode(X)
47
- self.scaler.fit(X_encoded)
48
- return self
49
-
50
- def _encode(self, X):
51
- X = X.copy()
52
- X['IQ/DQ'] = pd.to_numeric(X['IQ/DQ'], errors='coerce').fillna(70)
53
-
54
- for col in self.categorical_cols:
55
- X[col] = X[col].fillna('_missing_').astype(str)
56
- known_classes = set(self.label_encoders[col].classes_)
57
- X[col] = X[col].apply(lambda x: x if x in known_classes else '_unknown_')
58
- X[col] = self.label_encoders[col].transform(X[col])
59
-
60
- return X[SELECTED_FEATURES].values
61
-
62
- def transform(self, X):
63
- X_encoded = self._encode(X)
64
- return self.scaler.transform(X_encoded)
65
-
66
- def fit_transform(self, X):
67
- self.fit(X)
68
- return self.transform(X)
69
-
70
-
71
- class SimplifiedASDDetector(nn.Module):
72
  """
73
- Simplified neural network for ASD detection using 8 key features.
74
-
75
- Features:
76
- 1. developmental_milestones - N/G/M/C
77
- 2. iq_dq - numeric (0-150)
78
- 3. intellectual_disability - N/F70.0/F71/F72
79
- 4. language_disorder - N/Y
80
- 5. language_development - N/delay/A
81
- 6. dysmorphism - NO/Y
82
- 7. behaviour_disorder - N/Y
83
- 8. neurological_exam - N/abnormal text
84
-
85
- Args:
86
- input_size (int): Number of input features (8 after encoding)
87
- hidden_sizes (list): Hidden layer sizes. Default: [32, 16]
88
- dropout_rate (float): Dropout probability. Default: 0.3
89
  """
90
 
91
- def __init__(self, input_size, hidden_sizes=None, dropout_rate=0.3):
92
- super(SimplifiedASDDetector, self).__init__()
93
 
94
  if hidden_sizes is None:
95
- hidden_sizes = [32, 16]
96
 
97
  layers = []
98
  prev_size = input_size
@@ -100,94 +39,139 @@ class SimplifiedASDDetector(nn.Module):
100
  for hidden_size in hidden_sizes:
101
  layers.extend([
102
  nn.Linear(prev_size, hidden_size),
103
- nn.BatchNorm1d(hidden_size),
104
  nn.ReLU(),
105
- nn.Dropout(dropout_rate)
106
  ])
107
  prev_size = hidden_size
108
 
109
- layers.append(nn.Linear(prev_size, 1))
110
- layers.append(nn.Sigmoid())
111
 
112
- self.network = nn.Sequential(*layers)
113
  self.input_size = input_size
114
  self.hidden_sizes = hidden_sizes
115
- self.dropout_rate = dropout_rate
 
116
 
117
  def forward(self, x):
118
- """Forward pass returning probability of ASD."""
119
- return self.network(x)
120
-
121
- def predict(self, x, threshold=0.5):
122
- """Binary prediction (0=Healthy, 1=ASD)."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  self.eval()
124
  with torch.no_grad():
125
- probs = self.forward(x)
126
- return (probs > threshold).int()
 
 
 
 
 
 
 
127
 
128
 
129
  def load_model(model_path, device='cpu'):
130
- """Load trained model from .pth file."""
131
- checkpoint = torch.load(model_path, map_location=device, weights_only=False)
 
 
132
 
133
- model = SimplifiedASDDetector(
134
- input_size=checkpoint['input_size'],
135
- hidden_sizes=checkpoint['hidden_sizes'],
136
- dropout_rate=checkpoint['dropout_rate']
137
- )
138
 
139
- model.load_state_dict(checkpoint['model_state_dict'])
140
- model.to(device)
141
- model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
- return model
 
 
144
 
 
 
 
 
 
 
 
 
145
 
146
- # Feature information for reference
147
- FEATURES = {
148
- 'developmental_milestones': {
149
- 'description': 'Developmental milestones status',
150
- 'values': {'N': 'Normal', 'G': 'Global delay', 'M': 'Motor delay', 'C': 'Cognitive delay'}
151
- },
152
- 'iq_dq': {
153
- 'description': 'IQ or Developmental Quotient',
154
- 'values': 'numeric (typically 20-150, average ~100)'
155
- },
156
- 'intellectual_disability': {
157
- 'description': 'ICD code for intellectual disability',
158
- 'values': {'N': 'None', 'F70.0': 'Mild (IQ 50-69)', 'F71': 'Moderate (IQ 35-49)', 'F72': 'Severe (IQ 20-34)'}
159
- },
160
- 'language_disorder': {
161
- 'description': 'Presence of language disorder',
162
- 'values': {'N': 'No', 'Y': 'Yes'}
163
- },
164
- 'language_development': {
165
- 'description': 'Language development status',
166
- 'values': {'N': 'Normal', 'delay': 'Delayed', 'A': 'Absent'}
167
- },
168
- 'dysmorphism': {
169
- 'description': 'Physical dysmorphic features',
170
- 'values': {'NO': 'Absent', 'Y': 'Present'}
171
- },
172
- 'behaviour_disorder': {
173
- 'description': 'Behavioral issues (aggression, agitation)',
174
- 'values': {'N': 'No', 'Y': 'Yes'}
175
- },
176
- 'neurological_exam': {
177
- 'description': 'Neurological examination result',
178
- 'values': {'N': 'Normal', 'other': 'Abnormal (free text description)'}
179
- }
180
- }
181
 
182
 
183
  if __name__ == '__main__':
184
- print("Simplified ASD Detector - 8 Features")
185
- print("=" * 50)
186
- print("\nRequired inputs:")
187
- for i, (name, info) in enumerate(FEATURES.items(), 1):
188
- print(f"\n{i}. {name}")
189
- print(f" Description: {info['description']}")
190
- if isinstance(info['values'], dict):
191
- print(f" Values: {', '.join(f'{k}={v}' for k, v in info['values'].items())}")
192
- else:
193
- print(f" Values: {info['values']}")
 
 
 
 
1
  """
2
+ Autism Detector Model
3
 
4
+ A feedforward neural network for ASD risk classification
5
+ from structured clinical data.
6
  """
7
 
8
  import torch
9
  import torch.nn as nn
 
 
 
10
 
11
 
12
+ class AutismDetector(nn.Module):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  """
14
+ Binary classifier for autism spectrum disorder screening.
15
+
16
+ Input: 8 preprocessed clinical features
17
+ Output: 2 logits (Healthy, ASD)
18
+
19
+ Features (in order):
20
+ 1. developmental_milestones - N/G/M/C (encoded 0-3)
21
+ 2. iq_dq - numeric, normalized 0-1
22
+ 3. intellectual_disability - N/F70.0/F71/F72 (encoded 0-3)
23
+ 4. language_disorder - N/Y (encoded 0-1)
24
+ 5. language_development - N/delay/A (encoded 0-2)
25
+ 6. dysmorphism - NO/Y (encoded 0-1)
26
+ 7. behaviour_disorder - N/Y (encoded 0-1)
27
+ 8. neurological_exam - N/abnormal (encoded 0-1)
 
 
28
  """
29
 
30
+ def __init__(self, input_size=8, hidden_sizes=None, num_classes=2, dropout=0.3):
31
+ super().__init__()
32
 
33
  if hidden_sizes is None:
34
+ hidden_sizes = [64, 32]
35
 
36
  layers = []
37
  prev_size = input_size
 
39
  for hidden_size in hidden_sizes:
40
  layers.extend([
41
  nn.Linear(prev_size, hidden_size),
 
42
  nn.ReLU(),
43
+ nn.Dropout(dropout),
44
  ])
45
  prev_size = hidden_size
46
 
47
+ layers.append(nn.Linear(prev_size, num_classes))
48
+ self.classifier = nn.Sequential(*layers)
49
 
50
+ # Store config
51
  self.input_size = input_size
52
  self.hidden_sizes = hidden_sizes
53
+ self.num_classes = num_classes
54
+ self.dropout = dropout
55
 
56
  def forward(self, x):
57
+ """
58
+ Forward pass.
59
+
60
+ Parameters
61
+ ----------
62
+ x : torch.Tensor
63
+ Input tensor of shape (batch_size, 8)
64
+
65
+ Returns
66
+ -------
67
+ torch.Tensor
68
+ Output logits of shape (batch_size, num_classes)
69
+ """
70
+ return self.classifier(x)
71
+
72
+ def predict(self, x):
73
+ """
74
+ Make predictions with probabilities.
75
+
76
+ Parameters
77
+ ----------
78
+ x : torch.Tensor
79
+ Input tensor of shape (batch_size, 8)
80
+
81
+ Returns
82
+ -------
83
+ dict with 'prediction', 'probability', 'logits'
84
+ """
85
  self.eval()
86
  with torch.no_grad():
87
+ logits = self.forward(x)
88
+ probs = torch.softmax(logits, dim=-1)
89
+ pred_class = torch.argmax(probs, dim=-1)
90
+
91
+ return {
92
+ 'prediction': pred_class,
93
+ 'probabilities': probs,
94
+ 'logits': logits
95
+ }
96
 
97
 
98
  def load_model(model_path, device='cpu'):
99
+ """Load TorchScript model."""
100
+ model = torch.jit.load(model_path, map_location=device)
101
+ model.eval()
102
+ return model
103
 
 
 
 
 
 
104
 
105
+ def preprocess(data, config):
106
+ """
107
+ Preprocess input data using JSON config.
108
+
109
+ Parameters
110
+ ----------
111
+ data : dict
112
+ Input features as dictionary
113
+ config : dict
114
+ Preprocessor configuration from preprocessor_config.json
115
+
116
+ Returns
117
+ -------
118
+ torch.Tensor
119
+ Preprocessed features tensor of shape (1, 8)
120
+ """
121
+ features = []
122
 
123
+ for feature_name in config["feature_order"]:
124
+ if feature_name in config["categorical_features"]:
125
+ feat_config = config["categorical_features"][feature_name]
126
 
127
+ if feat_config["type"] == "text_binary":
128
+ # For neurological_exam: N -> 0, anything else -> 1
129
+ raw_value = str(data[feature_name]).strip().upper()
130
+ value = 0 if raw_value == feat_config["normal_value"] else 1
131
+ else:
132
+ # Standard categorical/binary mapping
133
+ raw_value = data[feature_name]
134
+ value = feat_config["mapping"].get(raw_value, 0)
135
 
136
+ elif feature_name in config["numeric_features"]:
137
+ feat_config = config["numeric_features"][feature_name]
138
+ raw = float(data[feature_name])
139
+ # Min-max normalization
140
+ value = (raw - feat_config["min"]) / (feat_config["max"] - feat_config["min"])
141
+ value = max(0, min(1, value)) # Clamp to [0, 1]
142
+
143
+ features.append(value)
144
+
145
+ return torch.tensor([features], dtype=torch.float32)
146
+
147
+
148
+ def get_risk_level(probability):
149
+ """
150
+ Get risk level from ASD probability.
151
+
152
+ Returns
153
+ -------
154
+ str: 'low', 'medium', or 'high'
155
+ """
156
+ if probability < 0.4:
157
+ return "low"
158
+ elif probability < 0.7:
159
+ return "medium"
160
+ else:
161
+ return "high"
 
 
 
 
 
 
 
 
 
162
 
163
 
164
  if __name__ == '__main__':
165
+ # Test model creation
166
+ model = AutismDetector()
167
+ print(f"Model architecture:\n{model}")
168
+
169
+ # Test forward pass
170
+ x = torch.randn(2, 8)
171
+ output = model(x)
172
+ print(f"\nInput shape: {x.shape}")
173
+ print(f"Output shape: {output.shape}")
174
+ print(f"Output (logits): {output}")
175
+
176
+ probs = torch.softmax(output, dim=-1)
177
+ print(f"Probabilities: {probs}")