PrashantRGore committed on
Commit
0f21e9b
Β·
1 Parent(s): 1f12230

Deploy drug-causality-bert v1 with BioBERT model and caching optimizations

Browse files
app/.ipynb_checkpoints/requirements-checkpoint.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ torch>=2.1.0
2
+ transformers>=4.35.0
3
+ pandas
4
+ numpy
5
+ scikit-learn
6
+ nltk>=3.7
7
+ PyPDF2>=3.0.1
8
+ streamlit>=1.22.0
9
+ safetensors>=0.4.0
app/.ipynb_checkpoints/streamlit_app-checkpoint.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import tempfile
3
+ import os
4
+ import sys
5
+ from pathlib import Path
6
+ import nltk
7
+
8
+ nltk.download('punkt')
9
+
10
+ # Add parent directory to Python path for imports
11
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
12
+
13
+ from src.inference import CausalityClassifier, extract_text_from_pdf, classify_causality, process_pdf_file, process_multiple_pdfs
14
+
15
+ # App Configuration
16
+ st.set_page_config(
17
+ page_title="Drug Causality Classifier",
18
+ page_icon="πŸ’Š",
19
+ layout="wide",
20
+ initial_sidebar_state="expanded"
21
+ )
22
+
23
+ # Main Title
24
+ st.title("πŸ’Š Drug Causality Classifier")
25
+ st.caption("BioBERT Model | F1 Score: 97.59% | Sensitivity: 98.68% | Specificity: 96.50%")
26
+
27
+ # Load model (cached)
28
+ @st.cache_resource
29
+ def load_model():
30
+ try:
31
+ return CausalityClassifier("models/production_model_final")
32
+ except Exception as e:
33
+ st.error(f"Failed to load model: {e}")
34
+ return None
35
+
36
+ classifier = load_model()
37
+
38
+ # Sidebar Configuration
39
+ st.sidebar.header("βš™οΈ Configuration")
40
+ threshold = st.sidebar.slider(
41
+ "Classification Threshold",
42
+ min_value=0.0,
43
+ max_value=1.0,
44
+ value=0.5,
45
+ step=0.05,
46
+ help="Higher threshold = stricter causality detection"
47
+ )
48
+
49
+ st.sidebar.info(
50
+ "**Threshold Guide:**\n"
51
+ "- 0.3-0.4: High sensitivity (catch all events)\n"
52
+ "- 0.5: Balanced performance\n"
53
+ "- 0.7-0.8: High precision (reduce false alarms)"
54
+ )
55
+
56
+ # Main Content
57
+ tab1, tab2, tab3 = st.tabs(["πŸ“ Single Text", "πŸ“„ PDF Analysis", "πŸ“ Batch Processing"])
58
+
59
+ # TAB 1: Single Text Classification
60
+ with tab1:
61
+ st.header("πŸ“ Single Statement Classification")
62
+ st.write("Enter medical text to classify drug-adverse event causality:")
63
+
64
+ text_input = st.text_area(
65
+ "Medical Text:",
66
+ height=150,
67
+ placeholder="e.g., Patient developed severe nausea and vomiting 2 hours after taking Drug X. Clinical assessment confirmed drug-related causality."
68
+ )
69
+
70
+ col1, col2 = st.columns([2, 1])
71
+
72
+ with col1:
73
+ if st.button("πŸ” Classify Text", type="primary", use_container_width=True):
74
+ if text_input and classifier:
75
+ with st.spinner("Analyzing text..."):
76
+ result = classifier.predict(text_input, threshold)
77
+
78
+ # Display Results
79
+ st.subheader("πŸ“Š Results")
80
+
81
+ result_col1, result_col2 = st.columns(2)
82
+
83
+ with result_col1:
84
+ classification = result['prediction'].upper()
85
+ color = "green" if result['prediction'] == 'related' else "red"
86
+ st.markdown(f"**Classification:** :{color}[{classification}]")
87
+
88
+ with result_col2:
89
+ confidence_pct = result['confidence'] * 100
90
+ st.metric("Confidence", f"{confidence_pct:.1f}%")
91
+
92
+ # Probability Distribution
93
+ st.subheader("πŸ“ˆ Probability Distribution")
94
+ probs = result['probabilities']
95
+
96
+ # Progress bars
97
+ st.write("**Related (Drug-Caused):**")
98
+ st.progress(probs['related'], text=f"{probs['related']:.2%}")
99
+
100
+ st.write("**Not Related:**")
101
+ st.progress(probs['not_related'], text=f"{probs['not_related']:.2%}")
102
+
103
+ # Raw JSON Output
104
+ with st.expander("πŸ” View Raw Results"):
105
+ st.json(result)
106
+
107
+ elif not classifier:
108
+ st.error("Model not loaded properly.")
109
+ else:
110
+ st.warning("Please enter text to classify.")
111
+
112
+ with col2:
113
+ st.info(
114
+ "**Example Inputs:**\n\n"
115
+ "**Related:** _Patient developed rash after taking aspirin. Symptoms resolved after discontinuation._\n\n"
116
+ "**Not Related:** _Patient has a history of diabetes and hypertension. Takes metformin daily._"
117
+ )
118
+
119
+ # TAB 2: PDF Analysis
120
+ with tab2:
121
+ st.header("πŸ“„ PDF Document Analysis")
122
+ st.write("Upload a PDF document for comprehensive drug-adverse event analysis:")
123
+
124
+ pdf_file = st.file_uploader(
125
+ "Choose a PDF file",
126
+ type=["pdf"],
127
+ help="Upload medical documents, case reports, or clinical notes"
128
+ )
129
+
130
+ if pdf_file and classifier:
131
+ # Save uploaded file temporarily
132
+ temp_dir = tempfile.gettempdir()
133
+ temp_path = os.path.join(temp_dir, pdf_file.name)
134
+
135
+ with open(temp_path, "wb") as tmp_f:
136
+ tmp_f.write(pdf_file.getbuffer())
137
+
138
+ # Analysis Button
139
+ if st.button("πŸ” Analyze PDF", type="primary", use_container_width=True):
140
+ with st.spinner(f"Processing {pdf_file.name}..."):
141
+ try:
142
+ # Extract and classify
143
+ pdf_text = extract_text_from_pdf(temp_path)
144
+ results = classify_causality(pdf_text, threshold=threshold)
145
+
146
+ # Display Summary
147
+ st.subheader("πŸ“Š Analysis Summary")
148
+
149
+ summary_col1, summary_col2, summary_col3 = st.columns(3)
150
+
151
+ with summary_col1:
152
+ classification = results['final_classification'].upper()
153
+ color = "green" if results['final_classification'] == 'related' else "red"
154
+ st.markdown(f"**Overall:** :{color}[{classification}]")
155
+
156
+ with summary_col2:
157
+ confidence_pct = results['confidence_score'] * 100
158
+ st.metric("Confidence", f"{confidence_pct:.1f}%")
159
+
160
+ with summary_col3:
161
+ st.metric("Total Sentences", results['total_sentences'])
162
+
163
+ # Sentence Breakdown
164
+ st.subheader("πŸ” Sentence Analysis")
165
+
166
+ breakdown_col1, breakdown_col2 = st.columns(2)
167
+
168
+ with breakdown_col1:
169
+ st.metric("Related Sentences", results['related_sentences'])
170
+
171
+ with breakdown_col2:
172
+ st.metric("Not Related", results['not_related_sentences'])
173
+
174
+ # Top Related Sentences
175
+ if results['related_sentences'] > 0:
176
+ st.subheader("🎯 Top Related Sentences")
177
+
178
+ for i, sent_detail in enumerate(results.get('top_related_sentences', []), 1):
179
+ confidence = sent_detail['probability_related']
180
+ confidence_color = "green" if confidence > 0.7 else "orange" if confidence > 0.5 else "red"
181
+
182
+ st.markdown(f"**{i}.** ({confidence:.1%} confidence)")
183
+ st.markdown(f":{confidence_color}[{sent_detail['sentence']}]")
184
+ st.write("")
185
+
186
+ # Download Button
187
+ st.subheader("πŸ’Ύ Download Report")
188
+
189
+ import json
190
+ report_json = json.dumps(results, indent=2)
191
+
192
+ st.download_button(
193
+ label="πŸ“₯ Download JSON Report",
194
+ data=report_json,
195
+ file_name=f"{pdf_file.name}_causality_report.json",
196
+ mime="application/json"
197
+ )
198
+
199
+ # Raw Results Expander
200
+ with st.expander("πŸ” View Full Results"):
201
+ st.json(results)
202
+
203
+ except Exception as e:
204
+ st.error(f"Error processing PDF: {str(e)}")
205
+ st.info("Please ensure the PDF contains readable text and try again.")
206
+
207
+ # Clean up temp file
208
+ finally:
209
+ try:
210
+ os.remove(temp_path)
211
+ except:
212
+ pass
213
+
214
+ # TAB 3: Batch Processing
215
+ with tab3:
216
+ st.header("πŸ“ Batch PDF Processing")
217
+ st.write("Upload multiple PDF files for batch causality analysis:")
218
+
219
+ batch_files = st.file_uploader(
220
+ "Choose PDF files",
221
+ type=["pdf"],
222
+ accept_multiple_files=True,
223
+ help="Upload multiple medical documents for batch analysis"
224
+ )
225
+
226
+ if batch_files and classifier:
227
+ st.write(f"**Selected files:** {len(batch_files)} PDFs")
228
+
229
+ for i, file in enumerate(batch_files, 1):
230
+ st.write(f"{i}. {file.name}")
231
+
232
+ if st.button("πŸ” Process All PDFs", type="primary", use_container_width=True):
233
+ # Create temporary paths for all files
234
+ batch_temp_paths = []
235
+ temp_dir = tempfile.gettempdir()
236
+
237
+ try:
238
+ # Save all files temporarily
239
+ for batch_file in batch_files:
240
+ temp_path = os.path.join(temp_dir, batch_file.name)
241
+ with open(temp_path, "wb") as tmp_f:
242
+ tmp_f.write(batch_file.getbuffer())
243
+ batch_temp_paths.append(temp_path)
244
+
245
+ # Process all files
246
+ with st.spinner(f"Processing {len(batch_files)} files..."):
247
+ batch_results = process_multiple_pdfs(batch_temp_paths, threshold=threshold)
248
+
249
+ # Display Batch Summary
250
+ st.subheader("πŸ“Š Batch Analysis Summary")
251
+
252
+ # Overall stats
253
+ total_files = len(batch_results)
254
+ successful = len([r for r in batch_results if 'error' not in r])
255
+ related_count = len([r for r in batch_results if r.get('final_classification') == 'related'])
256
+
257
+ stat_col1, stat_col2, stat_col3 = st.columns(3)
258
+
259
+ with stat_col1:
260
+ st.metric("Total Files", total_files)
261
+
262
+ with stat_col2:
263
+ st.metric("Successfully Processed", successful)
264
+
265
+ with stat_col3:
266
+ st.metric("Drug-Related Files", related_count)
267
+
268
+ # Individual Results
269
+ st.subheader("πŸ“„ Individual Results")
270
+
271
+ for i, res in enumerate(batch_results, 1):
272
+ if 'error' in res:
273
+ st.error(f"**{i}. {res['pdf_file']}:** Error - {res['error']}")
274
+ else:
275
+ classification = res['final_classification'].upper()
276
+ confidence = res.get('confidence_score', 0) * 100
277
+ color = "green" if res['final_classification'] == 'related' else "red"
278
+
279
+ st.markdown(f"**{i}. {res['pdf_file']}:** :{color}[{classification}] ({confidence:.1f}% confidence)")
280
+
281
+ # Download Batch Summary
282
+ st.subheader("πŸ’Ύ Download Batch Report")
283
+
284
+ import json
285
+ batch_report = {
286
+ 'summary': {
287
+ 'total_files': total_files,
288
+ 'successful': successful,
289
+ 'related_count': related_count,
290
+ 'threshold_used': threshold
291
+ },
292
+ 'individual_results': batch_results
293
+ }
294
+
295
+ batch_json = json.dumps(batch_report, indent=2)
296
+
297
+ st.download_button(
298
+ label="πŸ“₯ Download Batch Summary",
299
+ data=batch_json,
300
+ file_name="batch_causality_summary.json",
301
+ mime="application/json"
302
+ )
303
+
304
+ # Raw Results Expander
305
+ with st.expander("πŸ” View Full Batch Results"):
306
+ st.json(batch_results)
307
+
308
+ except Exception as e:
309
+ st.error(f"Batch processing error: {str(e)}")
310
+
311
+ finally:
312
+ # Clean up all temp files
313
+ for temp_path in batch_temp_paths:
314
+ try:
315
+ os.remove(temp_path)
316
+ except:
317
+ pass
318
+
319
+ # Footer
320
+ st.markdown("---")
321
+ st.markdown(
322
+ "**Built with BioBERT for Pharmacovigilance** | "
323
+ "Developed for clinical decision support and regulatory compliance"
324
+ )
325
+
326
+ # Sidebar additional info
327
+ st.sidebar.markdown("---")
328
+ st.sidebar.markdown("### πŸ“ˆ Model Performance")
329
+ st.sidebar.markdown(
330
+ "- **F1 Score:** 97.59%\n"
331
+ "- **Accuracy:** 97.59%\n"
332
+ "- **Sensitivity:** 98.68%\n"
333
+ "- **Specificity:** 96.50%"
334
+ )
335
+
336
+ st.sidebar.markdown("### πŸ₯ Clinical Use")
337
+ st.sidebar.markdown(
338
+ "This tool assists in:\n"
339
+ "- Adverse event detection\n"
340
+ "- Pharmacovigilance screening\n"
341
+ "- Clinical report analysis\n"
342
+ "- Regulatory compliance"
343
+ )
344
+
app/requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ torch>=2.1.0
2
+ transformers>=4.35.0
3
+ pandas
4
+ numpy
5
+ scikit-learn
6
+ nltk>=3.7
7
+ PyPDF2>=3.0.1
8
+ streamlit>=1.22.0
9
+ safetensors>=0.4.0
app/streamlit_app.py ADDED
@@ -0,0 +1,352 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import tempfile
3
+ import os
4
+ import sys
5
+ import pandas as pd
6
+ import json
7
+ from pathlib import Path
8
+ import nltk
9
+
10
+ nltk.download('punkt')
11
+
12
+ # Add parent directory to path
13
+ sys.path.insert(0, str(Path(__file__).parent.parent))
14
+
15
+ # NOW THIS IMPORT WILL WORK!
16
+ from src.inference import (
17
+ CausalityClassifier,
18
+ extract_text_from_pdf,
19
+ classify_causality,
20
+ process_pdf_file,
21
+ process_multiple_pdfs
22
+ )
23
+
24
+ # SINGLE load_model function with caching
25
+ @st.cache_resource
26
+ def load_model():
27
+ """Load CausalityClassifier model once and reuse across sessions"""
28
+ try:
29
+ return CausalityClassifier("models/production_model_final")
30
+ except Exception as e:
31
+ st.error(f"Failed to load model: {e}")
32
+ return None
33
+
34
+ # App Configuration
35
+ st.set_page_config(
36
+ page_title="Drug Causality Classifier",
37
+ page_icon="πŸ’Š",
38
+ layout="wide",
39
+ initial_sidebar_state="expanded"
40
+ )
41
+
42
+ # Main Title
43
+ st.title("πŸ’Š Drug Causality Classifier")
44
+ st.caption("BioBERT Model | F1 Score: 97.59% | Sensitivity: 98.68% | Specificity: 96.50%")
45
+
46
+ # Load model (cached)
47
+ classifier = load_model()
48
+
49
+ # Sidebar Configuration
50
+ st.sidebar.header("βš™οΈ Configuration")
51
+ threshold = st.sidebar.slider(
52
+ "Classification Threshold",
53
+ min_value=0.0,
54
+ max_value=1.0,
55
+ value=0.5,
56
+ step=0.05,
57
+ help="Higher threshold = stricter causality detection"
58
+ )
59
+
60
+ st.sidebar.info(
61
+ "**Threshold Guide:**\n"
62
+ "- 0.3-0.4: High sensitivity (catch all events)\n"
63
+ "- 0.5: Balanced performance\n"
64
+ "- 0.7-0.8: High precision (reduce false alarms)"
65
+ )
66
+
67
+ # Main Content
68
+ tab1, tab2, tab3 = st.tabs(["πŸ“ Single Text", "πŸ“„ PDF Analysis", "πŸ“ Batch Processing"])
69
+
70
+ # TAB 1: Single Text Classification
71
+ with tab1:
72
+ st.header("πŸ“ Single Statement Classification")
73
+ st.write("Enter medical text to classify drug-adverse event causality:")
74
+
75
+ text_input = st.text_area(
76
+ "Medical Text:",
77
+ height=150,
78
+ placeholder="e.g., Patient developed severe nausea and vomiting 2 hours after taking Drug X. Clinical assessment confirmed drug-related causality."
79
+ )
80
+
81
+ col1, col2 = st.columns([2, 1])
82
+
83
+ with col1:
84
+ if st.button("πŸ” Classify Text", type="primary", use_container_width=True):
85
+ if text_input and classifier:
86
+ with st.spinner("Analyzing text..."):
87
+ result = classifier.predict(text_input, threshold)
88
+
89
+ # Display Results
90
+ st.subheader("πŸ“Š Results")
91
+
92
+ result_col1, result_col2 = st.columns(2)
93
+
94
+ with result_col1:
95
+ classification = result['prediction'].upper()
96
+ color = "green" if result['prediction'] == 'related' else "red"
97
+ st.markdown(f"**Classification:** :{color}[{classification}]")
98
+
99
+ with result_col2:
100
+ confidence_pct = result['confidence'] * 100
101
+ st.metric("Confidence", f"{confidence_pct:.1f}%")
102
+
103
+ # Probability Distribution
104
+ st.subheader("πŸ“ˆ Probability Distribution")
105
+ probs = result['probabilities']
106
+
107
+ # Progress bars
108
+ st.write("**Related (Drug-Caused):**")
109
+ st.progress(probs['related'], text=f"{probs['related']:.2%}")
110
+
111
+ st.write("**Not Related:**")
112
+ st.progress(probs['not_related'], text=f"{probs['not_related']:.2%}")
113
+
114
+ # Raw JSON Output
115
+ with st.expander("πŸ” View Raw Results"):
116
+ st.json(result)
117
+
118
+ elif not classifier:
119
+ st.error("Model not loaded properly.")
120
+ else:
121
+ st.warning("Please enter text to classify.")
122
+
123
+ with col2:
124
+ st.info(
125
+ "**Example Inputs:**\n\n"
126
+ "**Related:** _Patient developed rash after taking aspirin. Symptoms resolved after discontinuation._\n\n"
127
+ "**Not Related:** _Patient has a history of diabetes and hypertension. Takes metformin daily._"
128
+ )
129
+
130
+ # TAB 2: PDF Analysis
131
+ with tab2:
132
+ st.header("πŸ“„ PDF Document Analysis")
133
+ st.write("Upload a PDF document for comprehensive drug-adverse event analysis:")
134
+
135
+ pdf_file = st.file_uploader(
136
+ "Choose a PDF file",
137
+ type=["pdf"],
138
+ help="Upload medical documents, case reports, or clinical notes"
139
+ )
140
+
141
+ if pdf_file and classifier:
142
+ # Save uploaded file temporarily
143
+ temp_dir = tempfile.gettempdir()
144
+ temp_path = os.path.join(temp_dir, pdf_file.name)
145
+
146
+ with open(temp_path, "wb") as tmp_f:
147
+ tmp_f.write(pdf_file.getbuffer())
148
+
149
+ # Analysis Button
150
+ if st.button("πŸ” Analyze PDF", type="primary", use_container_width=True):
151
+ with st.spinner(f"Processing {pdf_file.name}..."):
152
+ try:
153
+ # Extract and classify
154
+ pdf_text = extract_text_from_pdf(temp_path)
155
+ results = classify_causality(pdf_text, threshold=threshold)
156
+
157
+ # Display Summary
158
+ st.subheader("πŸ“Š Analysis Summary")
159
+
160
+ summary_col1, summary_col2, summary_col3 = st.columns(3)
161
+
162
+ with summary_col1:
163
+ classification = results['final_classification'].upper()
164
+ color = "green" if results['final_classification'] == 'related' else "red"
165
+ st.markdown(f"**Overall:** :{color}[{classification}]")
166
+
167
+ with summary_col2:
168
+ confidence_pct = results['confidence_score'] * 100
169
+ st.metric("Confidence", f"{confidence_pct:.1f}%")
170
+
171
+ with summary_col3:
172
+ st.metric("Total Sentences", results['total_sentences'])
173
+
174
+ # Sentence Breakdown
175
+ st.subheader("πŸ” Sentence Analysis")
176
+
177
+ breakdown_col1, breakdown_col2 = st.columns(2)
178
+
179
+ with breakdown_col1:
180
+ st.metric("Related Sentences", results['related_sentences'])
181
+
182
+ with breakdown_col2:
183
+ st.metric("Not Related", results['not_related_sentences'])
184
+
185
+ # Top Related Sentences
186
+ if results['related_sentences'] > 0:
187
+ st.subheader("🎯 Top Related Sentences")
188
+
189
+ for i, sent_detail in enumerate(results.get('top_related_sentences', []), 1):
190
+ confidence = sent_detail['probability_related']
191
+ confidence_color = "green" if confidence > 0.7 else "orange" if confidence > 0.5 else "red"
192
+
193
+ st.markdown(f"**{i}.** ({confidence:.1%} confidence)")
194
+ st.markdown(f":{confidence_color}[{sent_detail['sentence']}]")
195
+ st.write("")
196
+
197
+ # Download Button
198
+ st.subheader("πŸ’Ύ Download Report")
199
+
200
+ report_json = json.dumps(results, indent=2)
201
+
202
+ st.download_button(
203
+ label="πŸ“₯ Download JSON Report",
204
+ data=report_json,
205
+ file_name=f"{pdf_file.name}_causality_report.json",
206
+ mime="application/json"
207
+ )
208
+
209
+ # Raw Results Expander
210
+ with st.expander("πŸ” View Full Results"):
211
+ st.json(results)
212
+
213
+ except Exception as e:
214
+ st.error(f"Error processing PDF: {str(e)}")
215
+ st.info("Please ensure the PDF contains readable text and try again.")
216
+
217
+ # Clean up temp file
218
+ finally:
219
+ try:
220
+ os.remove(temp_path)
221
+ except:
222
+ pass
223
+
224
+ # TAB 3: Batch Processing
225
+ with tab3:
226
+ st.header("πŸ“ Batch PDF Processing")
227
+ st.write("Upload multiple PDF files for batch causality analysis:")
228
+
229
+ batch_files = st.file_uploader(
230
+ "Choose PDF files",
231
+ type=["pdf"],
232
+ accept_multiple_files=True,
233
+ help="Upload multiple medical documents for batch analysis"
234
+ )
235
+
236
+ if batch_files and classifier:
237
+ st.write(f"**Selected files:** {len(batch_files)} PDFs")
238
+
239
+ for i, file in enumerate(batch_files, 1):
240
+ st.write(f"{i}. {file.name}")
241
+
242
+ if st.button("πŸ” Process All PDFs", type="primary", use_container_width=True):
243
+ # Create temporary paths for all files
244
+ batch_temp_paths = []
245
+ temp_dir = tempfile.gettempdir()
246
+
247
+ try:
248
+ # Save all files temporarily
249
+ for batch_file in batch_files:
250
+ temp_path = os.path.join(temp_dir, batch_file.name)
251
+ with open(temp_path, "wb") as tmp_f:
252
+ tmp_f.write(batch_file.getbuffer())
253
+ batch_temp_paths.append(temp_path)
254
+
255
+ # Process all files
256
+ with st.spinner(f"Processing {len(batch_files)} files..."):
257
+ batch_results = process_multiple_pdfs(batch_temp_paths, threshold=threshold)
258
+
259
+ # Display Batch Summary
260
+ st.subheader("πŸ“Š Batch Analysis Summary")
261
+
262
+ # Overall stats
263
+ total_files = len(batch_results)
264
+ successful = len([r for r in batch_results if 'error' not in r])
265
+ related_count = len([r for r in batch_results if r.get('final_classification') == 'related'])
266
+
267
+ stat_col1, stat_col2, stat_col3 = st.columns(3)
268
+
269
+ with stat_col1:
270
+ st.metric("Total Files", total_files)
271
+
272
+ with stat_col2:
273
+ st.metric("Successfully Processed", successful)
274
+
275
+ with stat_col3:
276
+ st.metric("Drug-Related Files", related_count)
277
+
278
+ # Individual Results
279
+ st.subheader("πŸ“„ Individual Results")
280
+
281
+ for i, res in enumerate(batch_results, 1):
282
+ if 'error' in res:
283
+ st.error(f"**{i}. {res['pdf_file']}:** Error - {res['error']}")
284
+ else:
285
+ classification = res['final_classification'].upper()
286
+ confidence = res.get('confidence_score', 0) * 100
287
+ color = "green" if res['final_classification'] == 'related' else "red"
288
+
289
+ st.markdown(f"**{i}. {res['pdf_file']}:** :{color}[{classification}] ({confidence:.1f}% confidence)")
290
+
291
+ # Download Batch Summary
292
+ st.subheader("πŸ’Ύ Download Batch Report")
293
+
294
+ batch_report = {
295
+ 'summary': {
296
+ 'total_files': total_files,
297
+ 'successful': successful,
298
+ 'related_count': related_count,
299
+ 'threshold_used': threshold
300
+ },
301
+ 'individual_results': batch_results
302
+ }
303
+
304
+ batch_json = json.dumps(batch_report, indent=2)
305
+
306
+ st.download_button(
307
+ label="πŸ“₯ Download Batch Summary",
308
+ data=batch_json,
309
+ file_name="batch_causality_summary.json",
310
+ mime="application/json"
311
+ )
312
+
313
+ # Raw Results Expander
314
+ with st.expander("πŸ” View Full Batch Results"):
315
+ st.json(batch_results)
316
+
317
+ except Exception as e:
318
+ st.error(f"Batch processing error: {str(e)}")
319
+
320
+ finally:
321
+ # Clean up all temp files
322
+ for temp_path in batch_temp_paths:
323
+ try:
324
+ os.remove(temp_path)
325
+ except:
326
+ pass
327
+
328
+ # Footer
329
+ st.markdown("---")
330
+ st.markdown(
331
+ "**Built with BioBERT for Pharmacovigilance** | "
332
+ "Developed for clinical decision support and regulatory compliance"
333
+ )
334
+
335
+ # Sidebar additional info
336
+ st.sidebar.markdown("---")
337
+ st.sidebar.markdown("### πŸ“ˆ Model Performance")
338
+ st.sidebar.markdown(
339
+ "- **F1 Score:** 97.59%\n"
340
+ "- **Accuracy:** 97.59%\n"
341
+ "- **Sensitivity:** 98.68%\n"
342
+ "- **Specificity:** 96.50%"
343
+ )
344
+
345
+ st.sidebar.markdown("### πŸ₯ Clinical Use")
346
+ st.sidebar.markdown(
347
+ "This tool assists in:\n"
348
+ "- Adverse event detection\n"
349
+ "- Pharmacovigilance screening\n"
350
+ "- Clinical report analysis\n"
351
+ "- Regulatory compliance"
352
+ )
models/production_model_final/config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "BertForSequenceClassification"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "classifier_dropout": null,
7
+ "dtype": "float32",
8
+ "hidden_act": "gelu",
9
+ "hidden_dropout_prob": 0.1,
10
+ "hidden_size": 768,
11
+ "initializer_range": 0.02,
12
+ "intermediate_size": 3072,
13
+ "layer_norm_eps": 1e-12,
14
+ "max_position_embeddings": 512,
15
+ "model_type": "bert",
16
+ "num_attention_heads": 12,
17
+ "num_hidden_layers": 12,
18
+ "pad_token_id": 0,
19
+ "position_embedding_type": "absolute",
20
+ "problem_type": "single_label_classification",
21
+ "transformers_version": "4.57.1",
22
+ "type_vocab_size": 2,
23
+ "use_cache": true,
24
+ "vocab_size": 30522
25
+ }
models/production_model_final/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7f73202c120a52a4288c045e5713eeecbfe7b3431b5e15dafa3ded35b8ba18e4
3
+ size 437958648
models/production_model_final/special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "cls_token": "[CLS]",
3
+ "mask_token": "[MASK]",
4
+ "pad_token": "[PAD]",
5
+ "sep_token": "[SEP]",
6
+ "unk_token": "[UNK]"
7
+ }
models/production_model_final/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
models/production_model_final/tokenizer_config.json ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "[PAD]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "[UNK]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "[CLS]",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "3": {
28
+ "content": "[SEP]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "4": {
36
+ "content": "[MASK]",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "clean_up_tokenization_spaces": true,
45
+ "cls_token": "[CLS]",
46
+ "do_basic_tokenize": true,
47
+ "do_lower_case": true,
48
+ "extra_special_tokens": {},
49
+ "mask_token": "[MASK]",
50
+ "model_max_length": 1000000000000000019884624838656,
51
+ "never_split": null,
52
+ "pad_token": "[PAD]",
53
+ "sep_token": "[SEP]",
54
+ "strip_accents": null,
55
+ "tokenize_chinese_chars": true,
56
+ "tokenizer_class": "BertTokenizer",
57
+ "unk_token": "[UNK]"
58
+ }
models/production_model_final/training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ca6a57bb3665602515473f6c3e6aa96cb5505d7b8642beb6c8604c4a00aec451
3
+ size 5777
models/production_model_final/training_config.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "best_params": {
3
+ "learning_rate": 3.758180918998249e-05,
4
+ "num_train_epochs": 1,
5
+ "batch_size": 4,
6
+ "gradient_accumulation_steps": 4
7
+ },
8
+ "final_results": {
9
+ "accuracy": 0.9758909853249476,
10
+ "f1": 0.9758881040529953,
11
+ "precision": 0.9761185621296976,
12
+ "recall": 0.9758909853249476
13
+ },
14
+ "optuna_source": "Trial 1",
15
+ "training_date": "2025-10-25T16:06:34.368403"
16
+ }
models/production_model_final/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt CHANGED
@@ -1,3 +1,11 @@
1
- altair
 
2
  pandas
3
- streamlit
 
 
 
 
 
 
 
 
1
+ torch>=2.1.0
2
+ transformers>=4.35.0
3
  pandas
4
+ numpy
5
+ scikit-learn
6
+ nltk>=3.7
7
+ PyPDF2>=3.0.1
8
+ streamlit>=1.22.0
9
+ safetensors>=0.4.0
10
+ boto3
src/__init__.py ADDED
File without changes
src/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (147 Bytes). View file
 
src/__pycache__/inference.cpython-313.pyc ADDED
Binary file (20.3 kB). View file
 
src/inference.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ο»Ώimport torch
2
+ import numpy as np
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
+ from pathlib import Path
5
+ import PyPDF2
6
+ import json
7
+ from datetime import datetime
8
+ from typing import Union, List, Dict
9
+ import re
10
+
11
+ # NLTK with robust error handling
12
+ import nltk
13
+ import ssl
14
+
15
+ # SSL fix for NLTK
16
+ try:
17
+ _create_unverified_https_context = ssl._create_unverified_context
18
+ except AttributeError:
19
+ pass
20
+ else:
21
+ ssl._create_default_https_context = _create_unverified_https_context
22
+
23
+ # Enhanced NLTK data download with retry
24
+ def download_nltk_data_robust():
25
+ """Download NLTK data with multiple attempts and fallbacks"""
26
+ import os
27
+
28
+ # Set NLTK data path explicitly
29
+ nltk_data_dir = '/home/appuser/nltk_data'
30
+ if not os.path.exists(nltk_data_dir):
31
+ try:
32
+ os.makedirs(nltk_data_dir, exist_ok=True)
33
+ except:
34
+ pass
35
+
36
+ if nltk_data_dir not in nltk.data.path:
37
+ nltk.data.path.insert(0, nltk_data_dir)
38
+
39
+ packages = ['punkt', 'punkt_tab']
40
+ for package in packages:
41
+ for attempt in range(3): # Try 3 times
42
+ try:
43
+ nltk.data.find(f'tokenizers/{package}')
44
+ print(f"βœ“ {package} already available")
45
+ break
46
+ except LookupError:
47
+ try:
48
+ print(f"Downloading {package} (attempt {attempt + 1})...")
49
+ nltk.download(package, download_dir=nltk_data_dir, quiet=False)
50
+ print(f"βœ“ {package} downloaded successfully")
51
+ break
52
+ except Exception as e:
53
+ print(f"Warning: Could not download {package}: {e}")
54
+ if attempt == 2:
55
+ print(f"Failed to download {package} after 3 attempts")
56
+
57
+ # Download on import
58
+ download_nltk_data_robust()
59
+
60
+ # Fallback sentence tokenizer using regex
61
+ def simple_sentence_tokenize(text):
62
+ """Simple regex-based sentence tokenizer as fallback"""
63
+ # Split on common sentence boundaries
64
+ sentences = re.split(r'(?<=[.!?])\s+', text)
65
+ return [s.strip() for s in sentences if s.strip()]
66
+
67
+ # Safe sentence tokenization with fallback
68
+ def safe_sent_tokenize(text):
69
+ """Tokenize with NLTK, fallback to regex if NLTK fails"""
70
+ try:
71
+ from nltk.tokenize import sent_tokenize
72
+ return sent_tokenize(text)
73
+ except Exception as e:
74
+ print(f"NLTK tokenization failed ({e}), using fallback...")
75
+ return simple_sentence_tokenize(text)
76
+
77
+ class CausalityClassifier:
78
+ def __init__(self, model_path='./models/production_model_final', threshold=0.5):
79
+ self.model_path = Path(model_path)
80
+ self.threshold = threshold
81
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
82
+ self.model = AutoModelForSequenceClassification.from_pretrained(self.model_path)
83
+ self.model.eval()
84
+
85
+ def predict(self, text, return_probs=False):
86
+ inputs = self.tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=96)
87
+ with torch.no_grad():
88
+ outputs = self.model(**inputs)
89
+ probs = torch.softmax(outputs.logits, dim=1).numpy()[0]
90
+ pred = 1 if probs[1] > self.threshold else 0
91
+ result = {
92
+ 'prediction': 'related' if pred == 1 else 'not related',
93
+ 'confidence': float(probs[pred]),
94
+ 'label': int(pred)
95
+ }
96
+ if return_probs:
97
+ result['probabilities'] = {
98
+ 'not_related': float(probs[0]),
99
+ 'related': float(probs[1])
100
+ }
101
+ return result
102
+
103
+ def extract_text_from_pdf(pdf_path):
104
+ with open(pdf_path, 'rb') as file:
105
+ pdf_reader = PyPDF2.PdfReader(file)
106
+ text = ""
107
+ for page in pdf_reader.pages:
108
+ text += page.extract_text()
109
+ return text
110
+
111
+ def classify_causality(pdf_text, model_path='./models/production_model_final', threshold=0.5, verbose=False):
112
+ classifier = CausalityClassifier(model_path, threshold)
113
+
114
+ # Use safe tokenization with fallback
115
+ sentences = safe_sent_tokenize(pdf_text)
116
+
117
+ if verbose:
118
+ print(f"Tokenized {len(sentences)} sentences")
119
+
120
+ related_count = 0
121
+ sentence_details = []
122
+
123
+ for sent in sentences:
124
+ if not sent.strip():
125
+ continue
126
+
127
+ result = classifier.predict(sent, return_probs=True)
128
+ if result['label'] == 1:
129
+ related_count += 1
130
+ sentence_details.append({
131
+ 'sentence': sent[:100],
132
+ 'probability_related': result['probabilities']['related'],
133
+ 'confidence': result['confidence']
134
+ })
135
+
136
+ sentence_details.sort(key=lambda x: x['probability_related'], reverse=True)
137
+
138
+ return {
139
+ 'final_classification': 'related' if related_count > 0 else 'not related',
140
+ 'confidence_score': related_count / len(sentences) if sentences else 0,
141
+ 'related_sentences': related_count,
142
+ 'total_sentences': len(sentences),
143
+ 'top_related_sentences': sentence_details[:5],
144
+ 'threshold_used': threshold
145
+ }
146
+
147
+ def process_pdf_file(pdf_path, model_path='./models/production_model_final', threshold=0.5, save_report=False, output_dir='./results'):
148
+ pdf_text = extract_text_from_pdf(pdf_path)
149
+ results = classify_causality(pdf_text, model_path, threshold)
150
+ results['pdf_file'] = str(Path(pdf_path).name)
151
+ if save_report:
152
+ Path(output_dir).mkdir(parents=True, exist_ok=True)
153
+ with open(Path(output_dir) / f"{Path(pdf_path).stem}_report.json", 'w') as f:
154
+ json.dump(results, f, indent=2)
155
+ return results
156
+
157
+ def process_multiple_pdfs(pdf_paths, model_path='./models/production_model_final', threshold=0.5, save_reports=False, output_dir='./results'):
158
+ all_results = []
159
+ for pdf_path in pdf_paths:
160
+ try:
161
+ results = process_pdf_file(pdf_path, model_path, threshold, save_reports, output_dir)
162
+ all_results.append(results)
163
+ except Exception as e:
164
+ all_results.append({
165
+ 'pdf_file': str(Path(pdf_path).name),
166
+ 'error': str(e),
167
+ 'final_classification': 'error'
168
+ })
169
+ return all_results
streamlit_app.py ADDED
@@ -0,0 +1,352 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Standard-library and third-party imports for the Streamlit UI.
import streamlit as st
import tempfile
import os
import sys
import pandas as pd
import json
from pathlib import Path
import nltk

# Sentence-tokenizer data consumed downstream by classify_causality.
# NOTE(review): newer NLTK releases (>= 3.8.2) may additionally require the
# 'punkt_tab' resource — confirm against the deployed nltk version.
nltk.download('punkt')

# Make the repository root importable so `src.inference` resolves when the
# app is launched from inside the app/ directory.
sys.path.insert(0, str(Path(__file__).parent.parent))

# Project inference API: model wrapper plus PDF/text classification helpers.
from src.inference import (
    CausalityClassifier,
    extract_text_from_pdf,
    classify_causality,
    process_pdf_file,
    process_multiple_pdfs
)
23
+
24
# Model loader — cached so the BioBERT weights are read from disk only once
# per server process and shared across user sessions.
@st.cache_resource
def load_model():
    """Instantiate the CausalityClassifier once; return None on failure."""
    model_dir = "models/production_model_final"
    try:
        classifier = CausalityClassifier(model_dir)
    except Exception as exc:
        # Surface the failure in the UI; callers must handle a None model.
        st.error(f"Failed to load model: {exc}")
        return None
    return classifier
33
+
34
# Page-level configuration (set_page_config must be the first Streamlit call).
_PAGE_SETTINGS = {
    "page_title": "Drug Causality Classifier",
    "page_icon": "πŸ’Š",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_SETTINGS)

# Page header.
st.title("πŸ’Š Drug Causality Classifier")
st.caption("BioBERT Model | F1 Score: 97.59% | Sensitivity: 98.68% | Specificity: 96.50%")

# Cached model handle shared by all tabs (None if loading failed).
classifier = load_model()
48
+
49
# Sidebar: classification-threshold control plus usage guidance.
sidebar = st.sidebar
sidebar.header("βš™οΈ Configuration")

# Probability cutoff applied by every classification workflow below.
threshold = sidebar.slider(
    "Classification Threshold",
    min_value=0.0,
    max_value=1.0,
    value=0.5,
    step=0.05,
    help="Higher threshold = stricter causality detection",
)

sidebar.info(
    "**Threshold Guide:**\n"
    "- 0.3-0.4: High sensitivity (catch all events)\n"
    "- 0.5: Balanced performance\n"
    "- 0.7-0.8: High precision (reduce false alarms)"
)
66
+
67
# Main content area: three independent workflows.
tab1, tab2, tab3 = st.tabs(["πŸ“ Single Text", "πŸ“„ PDF Analysis", "πŸ“ Batch Processing"])

# TAB 1: classify a single free-text medical statement.
with tab1:
    st.header("πŸ“ Single Statement Classification")
    st.write("Enter medical text to classify drug-adverse event causality:")

    text_input = st.text_area(
        "Medical Text:",
        height=150,
        placeholder="e.g., Patient developed severe nausea and vomiting 2 hours after taking Drug X. Clinical assessment confirmed drug-related causality."
    )

    input_col, example_col = st.columns([2, 1])

    with input_col:
        if st.button("πŸ” Classify Text", type="primary", use_container_width=True):
            # Guard clauses: missing model first, then empty input.
            if not classifier:
                st.error("Model not loaded properly.")
            elif not text_input:
                st.warning("Please enter text to classify.")
            else:
                with st.spinner("Analyzing text..."):
                    # NOTE(review): the inference module calls
                    # predict(sent, return_probs=True) and reads a 'label' key,
                    # while here the second positional argument is the threshold
                    # and a 'prediction' key is read — confirm predict()
                    # supports both calling conventions.
                    result = classifier.predict(text_input, threshold)

                # Headline verdict and confidence.
                st.subheader("πŸ“Š Results")
                result_col1, result_col2 = st.columns(2)

                with result_col1:
                    classification = result['prediction'].upper()
                    color = "green" if result['prediction'] == 'related' else "red"
                    st.markdown(f"**Classification:** :{color}[{classification}]")

                with result_col2:
                    st.metric("Confidence", f"{result['confidence'] * 100:.1f}%")

                # Class probabilities rendered as progress bars.
                st.subheader("πŸ“ˆ Probability Distribution")
                probs = result['probabilities']

                st.write("**Related (Drug-Caused):**")
                st.progress(probs['related'], text=f"{probs['related']:.2%}")

                st.write("**Not Related:**")
                st.progress(probs['not_related'], text=f"{probs['not_related']:.2%}")

                # Machine-readable output for debugging.
                with st.expander("πŸ” View Raw Results"):
                    st.json(result)

    with example_col:
        st.info(
            "**Example Inputs:**\n\n"
            "**Related:** _Patient developed rash after taking aspirin. Symptoms resolved after discontinuation._\n\n"
            "**Not Related:** _Patient has a history of diabetes and hypertension. Takes metformin daily._"
        )
129
+
130
# TAB 2: classify every sentence of one uploaded PDF and summarize the result.
with tab2:
    st.header("πŸ“„ PDF Document Analysis")
    st.write("Upload a PDF document for comprehensive drug-adverse event analysis:")

    pdf_file = st.file_uploader(
        "Choose a PDF file",
        type=["pdf"],
        help="Upload medical documents, case reports, or clinical notes"
    )

    if pdf_file and classifier:
        # The PDF helpers operate on filesystem paths, so the upload is first
        # written to the system temp directory.
        temp_dir = tempfile.gettempdir()
        temp_path = os.path.join(temp_dir, pdf_file.name)

        with open(temp_path, "wb") as tmp_f:
            tmp_f.write(pdf_file.getbuffer())

        if st.button("πŸ” Analyze PDF", type="primary", use_container_width=True):
            with st.spinner(f"Processing {pdf_file.name}..."):
                try:
                    # Extract text and classify sentence-by-sentence.
                    # NOTE(review): classify_causality takes a model_path and
                    # appears to load its own model, bypassing the cached
                    # `classifier` instance — confirm whether it can reuse it.
                    pdf_text = extract_text_from_pdf(temp_path)
                    results = classify_causality(pdf_text, threshold=threshold)

                    # Document-level summary.
                    st.subheader("πŸ“Š Analysis Summary")
                    summary_col1, summary_col2, summary_col3 = st.columns(3)

                    with summary_col1:
                        classification = results['final_classification'].upper()
                        color = "green" if results['final_classification'] == 'related' else "red"
                        st.markdown(f"**Overall:** :{color}[{classification}]")

                    with summary_col2:
                        confidence_pct = results['confidence_score'] * 100
                        st.metric("Confidence", f"{confidence_pct:.1f}%")

                    with summary_col3:
                        st.metric("Total Sentences", results['total_sentences'])

                    # Per-sentence breakdown.
                    st.subheader("πŸ” Sentence Analysis")
                    breakdown_col1, breakdown_col2 = st.columns(2)

                    with breakdown_col1:
                        st.metric("Related Sentences", results['related_sentences'])

                    with breakdown_col2:
                        # BUG FIX: the classify_causality result dict has no
                        # 'not_related_sentences' key, so the previous lookup
                        # raised KeyError on every successful analysis (masked
                        # by the broad except below). Derive the count instead.
                        not_related = results['total_sentences'] - results['related_sentences']
                        st.metric("Not Related", not_related)

                    # Highest-confidence related sentences, if any were found.
                    if results['related_sentences'] > 0:
                        st.subheader("🎯 Top Related Sentences")

                        for i, sent_detail in enumerate(results.get('top_related_sentences', []), 1):
                            confidence = sent_detail['probability_related']
                            confidence_color = "green" if confidence > 0.7 else "orange" if confidence > 0.5 else "red"

                            st.markdown(f"**{i}.** ({confidence:.1%} confidence)")
                            st.markdown(f":{confidence_color}[{sent_detail['sentence']}]")
                            st.write("")

                    # Downloadable JSON report of the full result set.
                    st.subheader("πŸ’Ύ Download Report")
                    report_json = json.dumps(results, indent=2)
                    st.download_button(
                        label="πŸ“₯ Download JSON Report",
                        data=report_json,
                        file_name=f"{pdf_file.name}_causality_report.json",
                        mime="application/json"
                    )

                    with st.expander("πŸ” View Full Results"):
                        st.json(results)

                except Exception as e:
                    st.error(f"Error processing PDF: {str(e)}")
                    st.info("Please ensure the PDF contains readable text and try again.")

                finally:
                    # Best-effort cleanup of the temporary upload; only
                    # filesystem errors are expected (was a bare except).
                    try:
                        os.remove(temp_path)
                    except OSError:
                        pass
223
+
224
# TAB 3: batch causality analysis over several uploaded PDFs.
with tab3:
    st.header("πŸ“ Batch PDF Processing")
    st.write("Upload multiple PDF files for batch causality analysis:")

    batch_files = st.file_uploader(
        "Choose PDF files",
        type=["pdf"],
        accept_multiple_files=True,
        help="Upload multiple medical documents for batch analysis"
    )

    if batch_files and classifier:
        # Echo the selection so the user can verify it before processing.
        st.write(f"**Selected files:** {len(batch_files)} PDFs")

        for i, file in enumerate(batch_files, 1):
            st.write(f"{i}. {file.name}")

        if st.button("πŸ” Process All PDFs", type="primary", use_container_width=True):
            # The PDF helpers operate on filesystem paths; persist every
            # upload to the system temp directory first.
            batch_temp_paths = []
            temp_dir = tempfile.gettempdir()

            try:
                for batch_file in batch_files:
                    temp_path = os.path.join(temp_dir, batch_file.name)
                    with open(temp_path, "wb") as tmp_f:
                        tmp_f.write(batch_file.getbuffer())
                    batch_temp_paths.append(temp_path)

                # process_multiple_pdfs records per-file errors in its result
                # list rather than raising, so one bad PDF cannot abort the batch.
                with st.spinner(f"Processing {len(batch_files)} files..."):
                    batch_results = process_multiple_pdfs(batch_temp_paths, threshold=threshold)

                # Aggregate statistics across the batch.
                st.subheader("πŸ“Š Batch Analysis Summary")

                total_files = len(batch_results)
                successful = sum(1 for r in batch_results if 'error' not in r)
                related_count = sum(1 for r in batch_results if r.get('final_classification') == 'related')

                stat_col1, stat_col2, stat_col3 = st.columns(3)

                with stat_col1:
                    st.metric("Total Files", total_files)

                with stat_col2:
                    st.metric("Successfully Processed", successful)

                with stat_col3:
                    st.metric("Drug-Related Files", related_count)

                # Per-file verdicts; failures are rendered inline.
                st.subheader("πŸ“„ Individual Results")

                for i, res in enumerate(batch_results, 1):
                    if 'error' in res:
                        st.error(f"**{i}. {res['pdf_file']}:** Error - {res['error']}")
                    else:
                        classification = res['final_classification'].upper()
                        confidence = res.get('confidence_score', 0) * 100
                        color = "green" if res['final_classification'] == 'related' else "red"

                        st.markdown(f"**{i}. {res['pdf_file']}:** :{color}[{classification}] ({confidence:.1f}% confidence)")

                # Downloadable machine-readable summary of the whole batch.
                st.subheader("πŸ’Ύ Download Batch Report")

                batch_report = {
                    'summary': {
                        'total_files': total_files,
                        'successful': successful,
                        'related_count': related_count,
                        'threshold_used': threshold
                    },
                    'individual_results': batch_results
                }

                batch_json = json.dumps(batch_report, indent=2)

                st.download_button(
                    label="πŸ“₯ Download Batch Summary",
                    data=batch_json,
                    file_name="batch_causality_summary.json",
                    mime="application/json"
                )

                with st.expander("πŸ” View Full Batch Results"):
                    st.json(batch_results)

            except Exception as e:
                st.error(f"Batch processing error: {str(e)}")

            finally:
                # BUG FIX: cleanup previously used a bare `except:`, which also
                # swallows KeyboardInterrupt/SystemExit; only filesystem errors
                # are expected when removing temp files.
                for temp_path in batch_temp_paths:
                    try:
                        os.remove(temp_path)
                    except OSError:
                        pass
327
+
328
# Page footer.
st.markdown("---")
st.markdown(
    "**Built with BioBERT for Pharmacovigilance** | "
    "Developed for clinical decision support and regulatory compliance"
)

# Supplementary sidebar sections: model metrics and intended clinical uses,
# rendered in order.
_SIDEBAR_SECTIONS = (
    "---",
    "### πŸ“ˆ Model Performance",
    "- **F1 Score:** 97.59%\n"
    "- **Accuracy:** 97.59%\n"
    "- **Sensitivity:** 98.68%\n"
    "- **Specificity:** 96.50%",
    "### πŸ₯ Clinical Use",
    "This tool assists in:\n"
    "- Adverse event detection\n"
    "- Pharmacovigilance screening\n"
    "- Clinical report analysis\n"
    "- Regulatory compliance",
)
for _section in _SIDEBAR_SECTIONS:
    st.sidebar.markdown(_section)