snikhilesh committed on
Commit
504e9f1
·
verified ·
1 Parent(s): 614f2ea

Deploy generate_test_data.py to backend/ directory

Browse files
Files changed (1) hide show
  1. backend/generate_test_data.py +300 -0
backend/generate_test_data.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Synthetic Medical Test Data Generator
3
+ Creates realistic medical test cases for validation without real PHI
4
+ """
5
+
6
+ import json
7
+ import random
8
+ from datetime import datetime, timedelta
9
+ from typing import Dict, List, Any
10
+
11
class MedicalTestDataGenerator:
    """Generate synthetic medical test data for validation.

    All random values come from a private ``random.Random`` instance seeded
    in ``__init__``. Two generators built with the same seed therefore yield
    the same cases, and neither disturbs (nor is disturbed by) other users of
    the module-level ``random`` functions.
    """

    # Inclusive heart-rate range (bpm) sampled for each ECG pathology.
    _HEART_RATE_RANGES = {
        "normal": (60, 100),
        "atrial_fibrillation": (80, 150),
        "ventricular_tachycardia": (150, 250),
        "heart_block": (30, 60),
        "st_elevation": (60, 100),
        "st_depression": (60, 100),
        "qt_prolongation": (60, 90),
        "bundle_branch_block": (60, 100),
    }
    _DEFAULT_HEART_RATE_RANGE = (60, 100)

    # Canned report text per radiology pathology.
    _RADIOLOGY_FINDINGS = {
        "normal": "No acute findings",
        "pneumonia": "Focal consolidation in right lower lobe",
        "fracture": "Transverse fracture of distal radius",
        "tumor": "3.2 cm mass in left upper lobe",
        "organomegaly": "Hepatomegaly with liver span 18 cm",
    }

    _CLINICAL_SIGNIFICANCE = {
        "normal": "None",
        "atrial_fibrillation": "High - stroke risk",
        "ventricular_tachycardia": "Critical - life-threatening",
        "heart_block": "High - may require pacemaker",
        "st_elevation": "Critical - acute MI",
        "st_depression": "High - ischemia",
        "qt_prolongation": "Moderate - arrhythmia risk",
        "bundle_branch_block": "Moderate - conduction disorder",
        "pneumonia": "High - infectious process",
        "fracture": "Moderate - structural injury",
        "tumor": "High - potential malignancy",
        "organomegaly": "Moderate - systemic disease",
    }

    _ANATOMICAL_LOCATIONS = {
        "pneumonia": "Right lower lobe",
        "fracture": "Distal radius",
        "tumor": "Left upper lobe",
        "organomegaly": "Liver",
    }

    # Pathologies grouped by the confidence band a model is expected to reach.
    _HIGH_CONFIDENCE = ("normal", "st_elevation", "ventricular_tachycardia", "fracture")
    _MEDIUM_CONFIDENCE = ("qt_prolongation", "heart_block", "pneumonia", "tumor")

    # Share of the dataset per pathology, as integer percentages summing to
    # 100. Integers keep the allocation free of float rounding surprises.
    _ECG_DISTRIBUTION = (
        ("normal", 20),
        ("atrial_fibrillation", 16),
        ("ventricular_tachycardia", 12),
        ("heart_block", 10),
        ("st_elevation", 14),
        ("st_depression", 12),
        ("qt_prolongation", 8),
        ("bundle_branch_block", 8),
    )
    _RADIOLOGY_DISTRIBUTION = (
        ("normal", 25),
        ("pneumonia", 30),
        ("fracture", 20),
        ("tumor", 15),
        ("organomegaly", 10),
    )

    _SEXES = ("M", "F")
    _SEVERITIES = ("mild", "moderate", "severe")
    _AXES = ("normal", "left", "right")
    _MODALITIES = ("Chest X-ray", "CT", "MRI", "Ultrasound")
    _IMAGING_TYPES = ("Chest X-ray", "CT Chest", "MRI Brain", "Ultrasound Abdomen")

    def __init__(self, seed=42):
        """Create a generator whose output is fully determined by *seed*."""
        # A private RNG instead of random.seed(): seeding the global RNG would
        # let unrelated code (or a second generator) perturb this stream.
        self._rng = random.Random(seed)

    def generate_ecg_test_case(self, case_id: int, pathology: str) -> Dict[str, Any]:
        """Generate a synthetic ECG test case.

        Args:
            case_id: Numeric identifier, rendered as ``ECG_`` plus four digits.
            pathology: Pathology label. Labels without a dedicated entry use
                the default heart-rate range and the non-pathological
                interval ranges.

        Returns:
            A dict with patient demographics, ECG measurements, ground-truth
            labels, the expected model confidence and a review flag.
        """
        rng = self._rng

        # NOTE: the order of the random draws in this method determines the
        # output for a given seed; keep it stable.
        low, high = self._HEART_RATE_RANGES.get(pathology, self._DEFAULT_HEART_RATE_RANGE)
        heart_rate = rng.randint(low, high)

        # Each interval is widened only for the pathology it characterises.
        pr_interval = rng.randint(200, 350) if pathology == "heart_block" else rng.randint(120, 200)
        qrs_duration = rng.randint(120, 160) if pathology == "bundle_branch_block" else rng.randint(80, 100)
        qt_interval = rng.randint(450, 550) if pathology == "qt_prolongation" else rng.randint(350, 450)

        # Rate-corrected QT: QT divided by the square root of the RR interval,
        # where RR (seconds) is 60 / heart rate.
        qtc = qt_interval / (60 / heart_rate) ** 0.5

        return {
            "case_id": f"ECG_{case_id:04d}",
            "modality": "ECG",
            "patient_age": rng.randint(30, 80),
            "patient_sex": rng.choice(self._SEXES),
            "pathology": pathology,
            "measurements": {
                "heart_rate": heart_rate,
                "pr_interval_ms": pr_interval,
                "qrs_duration_ms": qrs_duration,
                "qt_interval_ms": qt_interval,
                "qtc_ms": round(qtc, 1),
                "axis": rng.choice(self._AXES),
            },
            "ground_truth": {
                "diagnosis": pathology,
                "severity": rng.choice(self._SEVERITIES),
                "clinical_significance": self._get_clinical_significance(pathology),
                "requires_immediate_action": pathology in ("ventricular_tachycardia", "st_elevation"),
            },
            "confidence_expected": self._get_expected_confidence(pathology),
            "review_required": pathology in ("heart_block", "qt_prolongation"),
        }

    def generate_radiology_test_case(self, case_id: int, pathology: str, modality: str) -> Dict[str, Any]:
        """Generate a synthetic radiology test case.

        Args:
            case_id: Numeric identifier, rendered as ``RAD_`` plus four digits.
            pathology: Pathology label; unknown labels get
                ``"Unknown findings"`` and ``"N/A"`` as location.
            modality: Imaging modality, stored verbatim in the case.

        Returns:
            A dict with patient demographics, the findings text, ground-truth
            labels, the expected model confidence and a review flag.
        """
        rng = self._rng
        return {
            "case_id": f"RAD_{case_id:04d}",
            "modality": modality,
            # NOTE(review): imaging_type is drawn independently of `modality`
            # and `pathology`, so a case can pair e.g. modality "CT" with
            # "MRI Brain". Kept as-is to preserve seeded output; confirm
            # whether consumers rely on either field.
            "imaging_type": rng.choice(self._IMAGING_TYPES),
            "patient_age": rng.randint(20, 85),
            "patient_sex": rng.choice(self._SEXES),
            "pathology": pathology,
            "findings": self._RADIOLOGY_FINDINGS.get(pathology, "Unknown findings"),
            "ground_truth": {
                "primary_diagnosis": pathology,
                "anatomical_location": self._get_anatomical_location(pathology),
                "severity": rng.choice(self._SEVERITIES),
                "clinical_significance": self._get_clinical_significance(pathology),
                "requires_follow_up": pathology != "normal",
            },
            "confidence_expected": self._get_expected_confidence(pathology),
            "review_required": pathology in ("tumor", "fracture"),
        }

    def _get_clinical_significance(self, pathology: str) -> str:
        """Return the significance label for *pathology*, or ``"Unknown"``."""
        return self._CLINICAL_SIGNIFICANCE.get(pathology, "Unknown")

    def _get_anatomical_location(self, pathology: str) -> str:
        """Return the anatomical location for *pathology*, or ``"N/A"``."""
        return self._ANATOMICAL_LOCATIONS.get(pathology, "N/A")

    def _get_expected_confidence(self, pathology: str) -> float:
        """Draw the expected confidence score for validation.

        The score is sampled uniformly from a band chosen by pathology:
        0.85-0.95 (high), 0.65-0.85 (medium) or 0.50-0.70 (everything else).
        """
        if pathology in self._HIGH_CONFIDENCE:
            return self._rng.uniform(0.85, 0.95)
        if pathology in self._MEDIUM_CONFIDENCE:
            return self._rng.uniform(0.65, 0.85)
        return self._rng.uniform(0.50, 0.70)

    @staticmethod
    def _allocate_counts(total, percentages):
        """Split *total* cases across pathologies without losing any.

        Uses largest-remainder allocation: every pathology gets the floor of
        its exact quota, then the cases still missing go one each to the
        pathologies with the largest fractional remainder (ties resolved in
        listed order). Plain truncation would make the dataset smaller than
        requested whenever a quota is not a whole number.

        Args:
            total: Requested number of cases; values below zero count as zero.
            percentages: Sequence of ``(pathology, integer_percent)`` pairs.

        Returns:
            List of ``(pathology, count)`` pairs in the input order.
        """
        total = max(int(total), 0)
        counts = [total * pct // 100 for _, pct in percentages]
        leftover = total - sum(counts)
        # sorted() is stable, so equal remainders keep their listed order.
        by_remainder = sorted(
            range(len(percentages)),
            key=lambda i: -(total * percentages[i][1] % 100),
        )
        for i in by_remainder[:leftover]:
            counts[i] += 1
        return [(name, count) for (name, _), count in zip(percentages, counts)]

    def generate_test_dataset(self, num_ecg=500, num_radiology=200) -> Dict[str, List[Dict]]:
        """Generate complete test dataset.

        Args:
            num_ecg: Number of ECG cases to generate.
            num_radiology: Number of radiology cases to generate.

        Returns:
            A dict with ``"ecg_cases"``, ``"radiology_cases"`` and a
            ``"metadata"`` entry holding the generation timestamp, the total
            case count and the per-pathology counts of both groups.
        """
        print("Generating synthetic medical test dataset...")
        print(f"ECG cases: {num_ecg}")
        print(f"Radiology cases: {num_radiology}")

        ecg_pathologies = self._allocate_counts(num_ecg, self._ECG_DISTRIBUTION)
        rad_pathologies = self._allocate_counts(num_radiology, self._RADIOLOGY_DISTRIBUTION)

        # Case ids run 1..N within each group, in pathology order.
        ecg_cases = []
        case_id = 1
        for pathology, count in ecg_pathologies:
            for _ in range(count):
                ecg_cases.append(self.generate_ecg_test_case(case_id, pathology))
                case_id += 1

        rad_cases = []
        case_id = 1
        for pathology, count in rad_pathologies:
            for _ in range(count):
                modality = self._rng.choice(self._MODALITIES)
                rad_cases.append(self.generate_radiology_test_case(case_id, pathology, modality))
                case_id += 1

        total_cases = len(ecg_cases) + len(rad_cases)
        print("\nGenerated:")
        print(f" ECG cases: {len(ecg_cases)}")
        print(f" Radiology cases: {len(rad_cases)}")
        print(f" Total: {total_cases}")

        return {
            "ecg_cases": ecg_cases,
            "radiology_cases": rad_cases,
            "metadata": {
                "generated_date": datetime.now().isoformat(),
                "total_cases": total_cases,
                "ecg_distribution": dict(ecg_pathologies),
                "radiology_distribution": dict(rad_pathologies),
            },
        }
191
+
192
class ValidationMetricsCalculator:
    """Calculate clinical validation metrics"""

    def calculate_metrics(self, predictions: List[Dict], ground_truth: List[Dict]) -> Dict[str, Any]:
        """Calculate sensitivity, specificity, precision, recall and F1.

        Predictions and ground truth are paired by position. A ground-truth
        case is *positive* when its ``"pathology"`` is anything other than
        ``"normal"``. Each pair is then counted as:

        * true positive  - positive case, ``"diagnosis"`` equals the pathology
        * false negative - positive case, ``"diagnosis"`` differs
        * true negative  - normal case, ``"diagnosis"`` is ``"normal"``
        * false positive - normal case, ``"diagnosis"`` is anything else

        Args:
            predictions: Dicts read through their ``"diagnosis"`` key.
            ground_truth: Dicts read through their ``"pathology"`` key.

        Returns:
            A dict with the ``"confusion_matrix"`` counts, the ``"metrics"``
            rounded to four decimals (0.0 when a denominator is zero) and
            ``"total_cases"`` (the length of *predictions*).

        Note:
            If the two lists differ in length, the surplus entries of the
            longer one are not evaluated.
        """
        tp = fp = tn = fn = 0

        for pred, truth in zip(predictions, ground_truth):
            actual = truth.get("pathology")
            is_correct = pred.get("diagnosis") == actual

            # Branch on the truth label first: a correct call on a normal
            # case is a true negative, not a false positive.
            if actual != "normal":
                if is_correct:
                    tp += 1
                else:
                    fn += 1
            elif is_correct:
                tn += 1
            else:
                fp += 1

        # Guard every ratio against an empty denominator.
        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = sensitivity
        f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0

        return {
            "confusion_matrix": {
                "true_positives": tp,
                "false_positives": fp,
                "true_negatives": tn,
                "false_negatives": fn,
            },
            "metrics": {
                "sensitivity": round(sensitivity, 4),
                "specificity": round(specificity, 4),
                "precision": round(precision, 4),
                "recall": round(recall, 4),
                "f1_score": round(f1_score, 4),
            },
            "total_cases": len(predictions),
        }
237
+
238
def _write_json(payload, path):
    """Write *payload* to *path* as JSON indented by two spaces."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)


def main(output_dir: str = "/workspace/medical-ai-platform/test_data"):
    """Generate test dataset and save to files.

    Builds the default dataset (500 ECG and 200 radiology cases, seed 42) and
    writes four JSON files into *output_dir*: the complete dataset, the ECG
    cases, the radiology cases and a summary. Progress and statistics are
    printed to stdout.

    Args:
        output_dir: Directory that receives the files; created if missing.
    """
    import os

    print("=" * 60)
    print("SYNTHETIC MEDICAL TEST DATA GENERATION")
    print("=" * 60)
    print(f"Started: {datetime.now().isoformat()}\n")

    generator = MedicalTestDataGenerator(seed=42)

    # Generate full dataset
    dataset = generator.generate_test_dataset(num_ecg=500, num_radiology=200)

    os.makedirs(output_dir, exist_ok=True)

    # Save complete dataset
    _write_json(dataset, f"{output_dir}/complete_test_dataset.json")
    print(f"\nSaved complete dataset to: {output_dir}/complete_test_dataset.json")

    # Save ECG cases separately
    _write_json(dataset["ecg_cases"], f"{output_dir}/ecg_test_cases.json")
    print(f"Saved ECG cases to: {output_dir}/ecg_test_cases.json")

    # Save radiology cases separately
    _write_json(dataset["radiology_cases"], f"{output_dir}/radiology_test_cases.json")
    print(f"Saved radiology cases to: {output_dir}/radiology_test_cases.json")

    # Summary statistics: the metadata plus per-group case counts.
    metadata = dataset["metadata"]
    summary = {
        "total_cases": metadata["total_cases"],
        "ecg_cases": len(dataset["ecg_cases"]),
        "radiology_cases": len(dataset["radiology_cases"]),
        "ecg_distribution": metadata["ecg_distribution"],
        "radiology_distribution": metadata["radiology_distribution"],
        "generated_date": metadata["generated_date"],
    }

    _write_json(summary, f"{output_dir}/dataset_summary.json")
    print(f"Saved summary to: {output_dir}/dataset_summary.json")

    print("\n" + "=" * 60)
    print("DATA GENERATION COMPLETE")
    print("=" * 60)
    print("\nDataset Statistics:")
    print(f" Total Cases: {summary['total_cases']}")
    print(f" ECG Cases: {summary['ecg_cases']}")
    print(f" Radiology Cases: {summary['radiology_cases']}")
    print("\nECG Pathology Distribution:")
    for pathology, count in summary["ecg_distribution"].items():
        print(f" {pathology}: {count} cases")
    print("\nRadiology Pathology Distribution:")
    for pathology, count in summary["radiology_distribution"].items():
        print(f" {pathology}: {count} cases")


if __name__ == "__main__":
    main()