monajm36 commited on
Commit
9af99d7
·
unverified ·
1 Parent(s): a5114c2

Update __init__.py

Browse files
Files changed (1) hide show
  1. src/__init__.py +107 -14
src/__init__.py CHANGED
@@ -1,28 +1,52 @@
1
  """
2
- NLP OHCA Classifier
3
-
4
  A BERT-based classifier for detecting Out-of-Hospital Cardiac Arrest (OHCA)
5
- cases in medical discharge notes.
6
 
7
- This package contains two main modules:
 
 
 
 
 
8
 
9
- 1. ohca_training_pipeline: Complete training pipeline from annotation to model training
10
- 2. ohca_inference: Apply pre-trained models to new datasets
 
11
  """
12
 
13
- # Training pipeline imports
14
  from .ohca_training_pipeline import (
 
 
 
 
 
 
 
 
 
15
  create_training_sample,
16
  prepare_training_data,
17
  train_ohca_model,
18
  evaluate_model,
19
  complete_training_pipeline,
20
  complete_annotation_and_train,
 
 
21
  OHCATrainingDataset
22
  )
23
 
24
- # Inference imports
25
  from .ohca_inference import (
 
 
 
 
 
 
 
 
26
  load_ohca_model,
27
  run_inference,
28
  quick_inference,
@@ -30,15 +54,35 @@ from .ohca_inference import (
30
  test_model_on_sample,
31
  get_high_confidence_cases,
32
  analyze_predictions,
 
 
33
  OHCAInferenceDataset
34
  )
35
 
36
- __version__ = "1.0.0"
37
  __author__ = "Mona Moukaddem"
38
  __email__ = "your.email@example.com"
39
 
40
- # Training pipeline functions
41
- __training_functions__ = [
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  "create_training_sample",
43
  "prepare_training_data",
44
  "train_ohca_model",
@@ -48,8 +92,7 @@ __training_functions__ = [
48
  "OHCATrainingDataset"
49
  ]
50
 
51
- # Inference functions
52
- __inference_functions__ = [
53
  "load_ohca_model",
54
  "run_inference",
55
  "quick_inference",
@@ -60,4 +103,54 @@ __inference_functions__ = [
60
  "OHCAInferenceDataset"
61
  ]
62
 
63
- __all__ = __training_functions__ + __inference_functions__
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
+ NLP OHCA Classifier v3.0 - Improved Methodology
 
3
  A BERT-based classifier for detecting Out-of-Hospital Cardiac Arrest (OHCA)
4
+ cases in medical discharge notes using improved machine learning methodology.
5
 
6
+ Key Improvements in v3.0:
7
+ - Patient-level data splits to prevent data leakage
8
+ - Proper train/validation/test methodology
9
+ - Optimal threshold finding and usage
10
+ - Larger annotation samples for better performance
11
+ - Unbiased evaluation framework
12
 
13
+ This package contains two main modules:
14
+ 1. ohca_training_pipeline: Complete training pipeline with improved methodology
15
+ 2. ohca_inference: Apply pre-trained models with optimal threshold support
16
  """
17
 
18
+ # Training pipeline imports - v3.0 with improvements
19
  from .ohca_training_pipeline import (
20
+ # Improved functions
21
+ create_patient_level_splits,
22
+ complete_improved_training_pipeline,
23
+ complete_annotation_and_train_v3,
24
+ find_optimal_threshold,
25
+ evaluate_on_test_set,
26
+ save_model_with_metadata,
27
+
28
+ # Legacy functions (backward compatible)
29
  create_training_sample,
30
  prepare_training_data,
31
  train_ohca_model,
32
  evaluate_model,
33
  complete_training_pipeline,
34
  complete_annotation_and_train,
35
+
36
+ # Dataset class
37
  OHCATrainingDataset
38
  )
39
 
40
+ # Inference imports - v3.0 with optimal threshold support
41
  from .ohca_inference import (
42
+ # New v3.0 functions with optimal threshold support
43
+ load_ohca_model_with_metadata,
44
+ run_inference_with_optimal_threshold,
45
+ quick_inference_with_optimal_threshold,
46
+ process_large_dataset_with_optimal_threshold,
47
+ analyze_predictions_enhanced,
48
+
49
+ # Legacy functions (backward compatible)
50
  load_ohca_model,
51
  run_inference,
52
  quick_inference,
 
54
  test_model_on_sample,
55
  get_high_confidence_cases,
56
  analyze_predictions,
57
+
58
+ # Dataset class
59
  OHCAInferenceDataset
60
  )
61
 
62
+ __version__ = "3.0.0"
63
  __author__ = "Mona Moukaddem"
64
  __email__ = "your.email@example.com"
65
 
66
+ # v3.0 improved functions (recommended)
67
+ __improved_training_functions__ = [
68
+ "create_patient_level_splits",
69
+ "complete_improved_training_pipeline",
70
+ "complete_annotation_and_train_v3",
71
+ "find_optimal_threshold",
72
+ "evaluate_on_test_set",
73
+ "save_model_with_metadata"
74
+ ]
75
+
76
+ __improved_inference_functions__ = [
77
+ "load_ohca_model_with_metadata",
78
+ "run_inference_with_optimal_threshold",
79
+ "quick_inference_with_optimal_threshold",
80
+ "process_large_dataset_with_optimal_threshold",
81
+ "analyze_predictions_enhanced"
82
+ ]
83
+
84
+ # Legacy functions (maintained for backward compatibility)
85
+ __legacy_training_functions__ = [
86
  "create_training_sample",
87
  "prepare_training_data",
88
  "train_ohca_model",
 
92
  "OHCATrainingDataset"
93
  ]
94
 
95
+ __legacy_inference_functions__ = [
 
96
  "load_ohca_model",
97
  "run_inference",
98
  "quick_inference",
 
103
  "OHCAInferenceDataset"
104
  ]
105
 
106
+ # All available functions
107
+ __all__ = (
108
+ __improved_training_functions__ +
109
+ __improved_inference_functions__ +
110
+ __legacy_training_functions__ +
111
+ __legacy_inference_functions__
112
+ )
113
+
114
+ # Methodology information
115
+ __methodology_version__ = "3.0"
116
+ __improvements__ = [
117
+ "Patient-level data splits prevent data leakage",
118
+ "Proper train/validation/test methodology",
119
+ "Optimal threshold finding and consistent usage",
120
+ "Larger annotation samples (800 train + 200 val)",
121
+ "Unbiased evaluation on independent test set",
122
+ "Enhanced clinical decision support",
123
+ "Backward compatibility with legacy models"
124
+ ]
125
+
126
+ def get_version_info():
127
+ """Return detailed version and methodology information"""
128
+ return {
129
+ 'version': __version__,
130
+ 'methodology_version': __methodology_version__,
131
+ 'improvements': __improvements__,
132
+ 'author': __author__,
133
+ 'recommended_functions': {
134
+ 'training': 'complete_improved_training_pipeline',
135
+ 'inference': 'quick_inference_with_optimal_threshold'
136
+ }
137
+ }
138
+
139
+ def print_welcome_message():
140
+ """Print welcome message with key improvements"""
141
+ print("="*60)
142
+ print("NLP OHCA Classifier v3.0 - Improved Methodology")
143
+ print("="*60)
144
+ print("Key improvements addressing data scientist feedback:")
145
+ for improvement in __improvements__:
146
+ print(f"✅ {improvement}")
147
+ print()
148
+ print("Recommended functions:")
149
+ print("• Training: complete_improved_training_pipeline()")
150
+ print("• Inference: quick_inference_with_optimal_threshold()")
151
+ print()
152
+ print("Legacy functions maintained for backward compatibility.")
153
+ print("="*60)
154
+
155
+ # Print welcome message when package is imported
156
+ print_welcome_message()