Spaces:
Runtime error
Runtime error
Commit
Β·
0f21e9b
1
Parent(s):
1f12230
Deploy drug-causality-bert v1 with BioBERT model and caching optimizations
Browse files- app/.ipynb_checkpoints/requirements-checkpoint.txt +9 -0
- app/.ipynb_checkpoints/streamlit_app-checkpoint.py +344 -0
- app/requirements.txt +9 -0
- app/streamlit_app.py +352 -0
- models/production_model_final/config.json +25 -0
- models/production_model_final/model.safetensors +3 -0
- models/production_model_final/special_tokens_map.json +7 -0
- models/production_model_final/tokenizer.json +0 -0
- models/production_model_final/tokenizer_config.json +58 -0
- models/production_model_final/training_args.bin +3 -0
- models/production_model_final/training_config.json +16 -0
- models/production_model_final/vocab.txt +0 -0
- requirements.txt +10 -2
- src/__init__.py +0 -0
- src/__pycache__/__init__.cpython-313.pyc +0 -0
- src/__pycache__/inference.cpython-313.pyc +0 -0
- src/inference.py +169 -0
- streamlit_app.py +352 -0
app/.ipynb_checkpoints/requirements-checkpoint.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.1.0
|
| 2 |
+
transformers>=4.35.0
|
| 3 |
+
pandas
|
| 4 |
+
numpy
|
| 5 |
+
scikit-learn
|
| 6 |
+
nltk>=3.7
|
| 7 |
+
PyPDF2>=3.0.1
|
| 8 |
+
streamlit>=1.22.0
|
| 9 |
+
safetensors>=0.4.0
|
app/.ipynb_checkpoints/streamlit_app-checkpoint.py
ADDED
|
@@ -0,0 +1,344 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import tempfile
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
import nltk
|
| 7 |
+
|
| 8 |
+
nltk.download('punkt')
|
| 9 |
+
|
| 10 |
+
# Add parent directory to Python path for imports
|
| 11 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
| 12 |
+
|
| 13 |
+
from src.inference import CausalityClassifier, extract_text_from_pdf, classify_causality, process_pdf_file, process_multiple_pdfs
|
| 14 |
+
|
| 15 |
+
# App Configuration
|
| 16 |
+
st.set_page_config(
|
| 17 |
+
page_title="Drug Causality Classifier",
|
| 18 |
+
page_icon="π",
|
| 19 |
+
layout="wide",
|
| 20 |
+
initial_sidebar_state="expanded"
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
# Main Title
|
| 24 |
+
st.title("π Drug Causality Classifier")
|
| 25 |
+
st.caption("BioBERT Model | F1 Score: 97.59% | Sensitivity: 98.68% | Specificity: 96.50%")
|
| 26 |
+
|
| 27 |
+
# Load model (cached)
|
| 28 |
+
@st.cache_resource
|
| 29 |
+
def load_model():
|
| 30 |
+
try:
|
| 31 |
+
return CausalityClassifier("models/production_model_final")
|
| 32 |
+
except Exception as e:
|
| 33 |
+
st.error(f"Failed to load model: {e}")
|
| 34 |
+
return None
|
| 35 |
+
|
| 36 |
+
classifier = load_model()
|
| 37 |
+
|
| 38 |
+
# Sidebar Configuration
|
| 39 |
+
st.sidebar.header("βοΈ Configuration")
|
| 40 |
+
threshold = st.sidebar.slider(
|
| 41 |
+
"Classification Threshold",
|
| 42 |
+
min_value=0.0,
|
| 43 |
+
max_value=1.0,
|
| 44 |
+
value=0.5,
|
| 45 |
+
step=0.05,
|
| 46 |
+
help="Higher threshold = stricter causality detection"
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
st.sidebar.info(
|
| 50 |
+
"**Threshold Guide:**\n"
|
| 51 |
+
"- 0.3-0.4: High sensitivity (catch all events)\n"
|
| 52 |
+
"- 0.5: Balanced performance\n"
|
| 53 |
+
"- 0.7-0.8: High precision (reduce false alarms)"
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
# Main Content
|
| 57 |
+
tab1, tab2, tab3 = st.tabs(["π Single Text", "π PDF Analysis", "π Batch Processing"])
|
| 58 |
+
|
| 59 |
+
# TAB 1: Single Text Classification
|
| 60 |
+
with tab1:
|
| 61 |
+
st.header("π Single Statement Classification")
|
| 62 |
+
st.write("Enter medical text to classify drug-adverse event causality:")
|
| 63 |
+
|
| 64 |
+
text_input = st.text_area(
|
| 65 |
+
"Medical Text:",
|
| 66 |
+
height=150,
|
| 67 |
+
placeholder="e.g., Patient developed severe nausea and vomiting 2 hours after taking Drug X. Clinical assessment confirmed drug-related causality."
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
col1, col2 = st.columns([2, 1])
|
| 71 |
+
|
| 72 |
+
with col1:
|
| 73 |
+
if st.button("π Classify Text", type="primary", use_container_width=True):
|
| 74 |
+
if text_input and classifier:
|
| 75 |
+
with st.spinner("Analyzing text..."):
|
| 76 |
+
result = classifier.predict(text_input, threshold)
|
| 77 |
+
|
| 78 |
+
# Display Results
|
| 79 |
+
st.subheader("π Results")
|
| 80 |
+
|
| 81 |
+
result_col1, result_col2 = st.columns(2)
|
| 82 |
+
|
| 83 |
+
with result_col1:
|
| 84 |
+
classification = result['prediction'].upper()
|
| 85 |
+
color = "green" if result['prediction'] == 'related' else "red"
|
| 86 |
+
st.markdown(f"**Classification:** :{color}[{classification}]")
|
| 87 |
+
|
| 88 |
+
with result_col2:
|
| 89 |
+
confidence_pct = result['confidence'] * 100
|
| 90 |
+
st.metric("Confidence", f"{confidence_pct:.1f}%")
|
| 91 |
+
|
| 92 |
+
# Probability Distribution
|
| 93 |
+
st.subheader("π Probability Distribution")
|
| 94 |
+
probs = result['probabilities']
|
| 95 |
+
|
| 96 |
+
# Progress bars
|
| 97 |
+
st.write("**Related (Drug-Caused):**")
|
| 98 |
+
st.progress(probs['related'], text=f"{probs['related']:.2%}")
|
| 99 |
+
|
| 100 |
+
st.write("**Not Related:**")
|
| 101 |
+
st.progress(probs['not_related'], text=f"{probs['not_related']:.2%}")
|
| 102 |
+
|
| 103 |
+
# Raw JSON Output
|
| 104 |
+
with st.expander("π View Raw Results"):
|
| 105 |
+
st.json(result)
|
| 106 |
+
|
| 107 |
+
elif not classifier:
|
| 108 |
+
st.error("Model not loaded properly.")
|
| 109 |
+
else:
|
| 110 |
+
st.warning("Please enter text to classify.")
|
| 111 |
+
|
| 112 |
+
with col2:
|
| 113 |
+
st.info(
|
| 114 |
+
"**Example Inputs:**\n\n"
|
| 115 |
+
"**Related:** _Patient developed rash after taking aspirin. Symptoms resolved after discontinuation._\n\n"
|
| 116 |
+
"**Not Related:** _Patient has a history of diabetes and hypertension. Takes metformin daily._"
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
# TAB 2: PDF Analysis
|
| 120 |
+
with tab2:
|
| 121 |
+
st.header("π PDF Document Analysis")
|
| 122 |
+
st.write("Upload a PDF document for comprehensive drug-adverse event analysis:")
|
| 123 |
+
|
| 124 |
+
pdf_file = st.file_uploader(
|
| 125 |
+
"Choose a PDF file",
|
| 126 |
+
type=["pdf"],
|
| 127 |
+
help="Upload medical documents, case reports, or clinical notes"
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
if pdf_file and classifier:
|
| 131 |
+
# Save uploaded file temporarily
|
| 132 |
+
temp_dir = tempfile.gettempdir()
|
| 133 |
+
temp_path = os.path.join(temp_dir, pdf_file.name)
|
| 134 |
+
|
| 135 |
+
with open(temp_path, "wb") as tmp_f:
|
| 136 |
+
tmp_f.write(pdf_file.getbuffer())
|
| 137 |
+
|
| 138 |
+
# Analysis Button
|
| 139 |
+
if st.button("π Analyze PDF", type="primary", use_container_width=True):
|
| 140 |
+
with st.spinner(f"Processing {pdf_file.name}..."):
|
| 141 |
+
try:
|
| 142 |
+
# Extract and classify
|
| 143 |
+
pdf_text = extract_text_from_pdf(temp_path)
|
| 144 |
+
results = classify_causality(pdf_text, threshold=threshold)
|
| 145 |
+
|
| 146 |
+
# Display Summary
|
| 147 |
+
st.subheader("π Analysis Summary")
|
| 148 |
+
|
| 149 |
+
summary_col1, summary_col2, summary_col3 = st.columns(3)
|
| 150 |
+
|
| 151 |
+
with summary_col1:
|
| 152 |
+
classification = results['final_classification'].upper()
|
| 153 |
+
color = "green" if results['final_classification'] == 'related' else "red"
|
| 154 |
+
st.markdown(f"**Overall:** :{color}[{classification}]")
|
| 155 |
+
|
| 156 |
+
with summary_col2:
|
| 157 |
+
confidence_pct = results['confidence_score'] * 100
|
| 158 |
+
st.metric("Confidence", f"{confidence_pct:.1f}%")
|
| 159 |
+
|
| 160 |
+
with summary_col3:
|
| 161 |
+
st.metric("Total Sentences", results['total_sentences'])
|
| 162 |
+
|
| 163 |
+
# Sentence Breakdown
|
| 164 |
+
st.subheader("π Sentence Analysis")
|
| 165 |
+
|
| 166 |
+
breakdown_col1, breakdown_col2 = st.columns(2)
|
| 167 |
+
|
| 168 |
+
with breakdown_col1:
|
| 169 |
+
st.metric("Related Sentences", results['related_sentences'])
|
| 170 |
+
|
| 171 |
+
with breakdown_col2:
|
| 172 |
+
st.metric("Not Related", results['not_related_sentences'])
|
| 173 |
+
|
| 174 |
+
# Top Related Sentences
|
| 175 |
+
if results['related_sentences'] > 0:
|
| 176 |
+
st.subheader("π― Top Related Sentences")
|
| 177 |
+
|
| 178 |
+
for i, sent_detail in enumerate(results.get('top_related_sentences', []), 1):
|
| 179 |
+
confidence = sent_detail['probability_related']
|
| 180 |
+
confidence_color = "green" if confidence > 0.7 else "orange" if confidence > 0.5 else "red"
|
| 181 |
+
|
| 182 |
+
st.markdown(f"**{i}.** ({confidence:.1%} confidence)")
|
| 183 |
+
st.markdown(f":{confidence_color}[{sent_detail['sentence']}]")
|
| 184 |
+
st.write("")
|
| 185 |
+
|
| 186 |
+
# Download Button
|
| 187 |
+
st.subheader("πΎ Download Report")
|
| 188 |
+
|
| 189 |
+
import json
|
| 190 |
+
report_json = json.dumps(results, indent=2)
|
| 191 |
+
|
| 192 |
+
st.download_button(
|
| 193 |
+
label="π₯ Download JSON Report",
|
| 194 |
+
data=report_json,
|
| 195 |
+
file_name=f"{pdf_file.name}_causality_report.json",
|
| 196 |
+
mime="application/json"
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
# Raw Results Expander
|
| 200 |
+
with st.expander("π View Full Results"):
|
| 201 |
+
st.json(results)
|
| 202 |
+
|
| 203 |
+
except Exception as e:
|
| 204 |
+
st.error(f"Error processing PDF: {str(e)}")
|
| 205 |
+
st.info("Please ensure the PDF contains readable text and try again.")
|
| 206 |
+
|
| 207 |
+
# Clean up temp file
|
| 208 |
+
finally:
|
| 209 |
+
try:
|
| 210 |
+
os.remove(temp_path)
|
| 211 |
+
except:
|
| 212 |
+
pass
|
| 213 |
+
|
| 214 |
+
# TAB 3: Batch Processing
|
| 215 |
+
with tab3:
|
| 216 |
+
st.header("π Batch PDF Processing")
|
| 217 |
+
st.write("Upload multiple PDF files for batch causality analysis:")
|
| 218 |
+
|
| 219 |
+
batch_files = st.file_uploader(
|
| 220 |
+
"Choose PDF files",
|
| 221 |
+
type=["pdf"],
|
| 222 |
+
accept_multiple_files=True,
|
| 223 |
+
help="Upload multiple medical documents for batch analysis"
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
if batch_files and classifier:
|
| 227 |
+
st.write(f"**Selected files:** {len(batch_files)} PDFs")
|
| 228 |
+
|
| 229 |
+
for i, file in enumerate(batch_files, 1):
|
| 230 |
+
st.write(f"{i}. {file.name}")
|
| 231 |
+
|
| 232 |
+
if st.button("π Process All PDFs", type="primary", use_container_width=True):
|
| 233 |
+
# Create temporary paths for all files
|
| 234 |
+
batch_temp_paths = []
|
| 235 |
+
temp_dir = tempfile.gettempdir()
|
| 236 |
+
|
| 237 |
+
try:
|
| 238 |
+
# Save all files temporarily
|
| 239 |
+
for batch_file in batch_files:
|
| 240 |
+
temp_path = os.path.join(temp_dir, batch_file.name)
|
| 241 |
+
with open(temp_path, "wb") as tmp_f:
|
| 242 |
+
tmp_f.write(batch_file.getbuffer())
|
| 243 |
+
batch_temp_paths.append(temp_path)
|
| 244 |
+
|
| 245 |
+
# Process all files
|
| 246 |
+
with st.spinner(f"Processing {len(batch_files)} files..."):
|
| 247 |
+
batch_results = process_multiple_pdfs(batch_temp_paths, threshold=threshold)
|
| 248 |
+
|
| 249 |
+
# Display Batch Summary
|
| 250 |
+
st.subheader("π Batch Analysis Summary")
|
| 251 |
+
|
| 252 |
+
# Overall stats
|
| 253 |
+
total_files = len(batch_results)
|
| 254 |
+
successful = len([r for r in batch_results if 'error' not in r])
|
| 255 |
+
related_count = len([r for r in batch_results if r.get('final_classification') == 'related'])
|
| 256 |
+
|
| 257 |
+
stat_col1, stat_col2, stat_col3 = st.columns(3)
|
| 258 |
+
|
| 259 |
+
with stat_col1:
|
| 260 |
+
st.metric("Total Files", total_files)
|
| 261 |
+
|
| 262 |
+
with stat_col2:
|
| 263 |
+
st.metric("Successfully Processed", successful)
|
| 264 |
+
|
| 265 |
+
with stat_col3:
|
| 266 |
+
st.metric("Drug-Related Files", related_count)
|
| 267 |
+
|
| 268 |
+
# Individual Results
|
| 269 |
+
st.subheader("π Individual Results")
|
| 270 |
+
|
| 271 |
+
for i, res in enumerate(batch_results, 1):
|
| 272 |
+
if 'error' in res:
|
| 273 |
+
st.error(f"**{i}. {res['pdf_file']}:** Error - {res['error']}")
|
| 274 |
+
else:
|
| 275 |
+
classification = res['final_classification'].upper()
|
| 276 |
+
confidence = res.get('confidence_score', 0) * 100
|
| 277 |
+
color = "green" if res['final_classification'] == 'related' else "red"
|
| 278 |
+
|
| 279 |
+
st.markdown(f"**{i}. {res['pdf_file']}:** :{color}[{classification}] ({confidence:.1f}% confidence)")
|
| 280 |
+
|
| 281 |
+
# Download Batch Summary
|
| 282 |
+
st.subheader("πΎ Download Batch Report")
|
| 283 |
+
|
| 284 |
+
import json
|
| 285 |
+
batch_report = {
|
| 286 |
+
'summary': {
|
| 287 |
+
'total_files': total_files,
|
| 288 |
+
'successful': successful,
|
| 289 |
+
'related_count': related_count,
|
| 290 |
+
'threshold_used': threshold
|
| 291 |
+
},
|
| 292 |
+
'individual_results': batch_results
|
| 293 |
+
}
|
| 294 |
+
|
| 295 |
+
batch_json = json.dumps(batch_report, indent=2)
|
| 296 |
+
|
| 297 |
+
st.download_button(
|
| 298 |
+
label="π₯ Download Batch Summary",
|
| 299 |
+
data=batch_json,
|
| 300 |
+
file_name="batch_causality_summary.json",
|
| 301 |
+
mime="application/json"
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
# Raw Results Expander
|
| 305 |
+
with st.expander("π View Full Batch Results"):
|
| 306 |
+
st.json(batch_results)
|
| 307 |
+
|
| 308 |
+
except Exception as e:
|
| 309 |
+
st.error(f"Batch processing error: {str(e)}")
|
| 310 |
+
|
| 311 |
+
finally:
|
| 312 |
+
# Clean up all temp files
|
| 313 |
+
for temp_path in batch_temp_paths:
|
| 314 |
+
try:
|
| 315 |
+
os.remove(temp_path)
|
| 316 |
+
except:
|
| 317 |
+
pass
|
| 318 |
+
|
| 319 |
+
# Footer
|
| 320 |
+
st.markdown("---")
|
| 321 |
+
st.markdown(
|
| 322 |
+
"**Built with BioBERT for Pharmacovigilance** | "
|
| 323 |
+
"Developed for clinical decision support and regulatory compliance"
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
# Sidebar additional info
|
| 327 |
+
st.sidebar.markdown("---")
|
| 328 |
+
st.sidebar.markdown("### π Model Performance")
|
| 329 |
+
st.sidebar.markdown(
|
| 330 |
+
"- **F1 Score:** 97.59%\n"
|
| 331 |
+
"- **Accuracy:** 97.59%\n"
|
| 332 |
+
"- **Sensitivity:** 98.68%\n"
|
| 333 |
+
"- **Specificity:** 96.50%"
|
| 334 |
+
)
|
| 335 |
+
|
| 336 |
+
st.sidebar.markdown("### π₯ Clinical Use")
|
| 337 |
+
st.sidebar.markdown(
|
| 338 |
+
"This tool assists in:\n"
|
| 339 |
+
"- Adverse event detection\n"
|
| 340 |
+
"- Pharmacovigilance screening\n"
|
| 341 |
+
"- Clinical report analysis\n"
|
| 342 |
+
"- Regulatory compliance"
|
| 343 |
+
)
|
| 344 |
+
|
app/requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.1.0
|
| 2 |
+
transformers>=4.35.0
|
| 3 |
+
pandas
|
| 4 |
+
numpy
|
| 5 |
+
scikit-learn
|
| 6 |
+
nltk>=3.7
|
| 7 |
+
PyPDF2>=3.0.1
|
| 8 |
+
streamlit>=1.22.0
|
| 9 |
+
safetensors>=0.4.0
|
app/streamlit_app.py
ADDED
|
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import tempfile
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import json
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
import nltk
|
| 9 |
+
|
| 10 |
+
nltk.download('punkt')
|
| 11 |
+
|
| 12 |
+
# Add parent directory to path
|
| 13 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 14 |
+
|
| 15 |
+
# NOW THIS IMPORT WILL WORK!
|
| 16 |
+
from src.inference import (
|
| 17 |
+
CausalityClassifier,
|
| 18 |
+
extract_text_from_pdf,
|
| 19 |
+
classify_causality,
|
| 20 |
+
process_pdf_file,
|
| 21 |
+
process_multiple_pdfs
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
# SINGLE load_model function with caching
|
| 25 |
+
@st.cache_resource
|
| 26 |
+
def load_model():
|
| 27 |
+
"""Load CausalityClassifier model once and reuse across sessions"""
|
| 28 |
+
try:
|
| 29 |
+
return CausalityClassifier("models/production_model_final")
|
| 30 |
+
except Exception as e:
|
| 31 |
+
st.error(f"Failed to load model: {e}")
|
| 32 |
+
return None
|
| 33 |
+
|
| 34 |
+
# App Configuration
|
| 35 |
+
st.set_page_config(
|
| 36 |
+
page_title="Drug Causality Classifier",
|
| 37 |
+
page_icon="π",
|
| 38 |
+
layout="wide",
|
| 39 |
+
initial_sidebar_state="expanded"
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
# Main Title
|
| 43 |
+
st.title("π Drug Causality Classifier")
|
| 44 |
+
st.caption("BioBERT Model | F1 Score: 97.59% | Sensitivity: 98.68% | Specificity: 96.50%")
|
| 45 |
+
|
| 46 |
+
# Load model (cached)
|
| 47 |
+
classifier = load_model()
|
| 48 |
+
|
| 49 |
+
# Sidebar Configuration
|
| 50 |
+
st.sidebar.header("βοΈ Configuration")
|
| 51 |
+
threshold = st.sidebar.slider(
|
| 52 |
+
"Classification Threshold",
|
| 53 |
+
min_value=0.0,
|
| 54 |
+
max_value=1.0,
|
| 55 |
+
value=0.5,
|
| 56 |
+
step=0.05,
|
| 57 |
+
help="Higher threshold = stricter causality detection"
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
st.sidebar.info(
|
| 61 |
+
"**Threshold Guide:**\n"
|
| 62 |
+
"- 0.3-0.4: High sensitivity (catch all events)\n"
|
| 63 |
+
"- 0.5: Balanced performance\n"
|
| 64 |
+
"- 0.7-0.8: High precision (reduce false alarms)"
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
# Main Content
|
| 68 |
+
tab1, tab2, tab3 = st.tabs(["π Single Text", "π PDF Analysis", "π Batch Processing"])
|
| 69 |
+
|
| 70 |
+
# TAB 1: Single Text Classification
|
| 71 |
+
with tab1:
|
| 72 |
+
st.header("π Single Statement Classification")
|
| 73 |
+
st.write("Enter medical text to classify drug-adverse event causality:")
|
| 74 |
+
|
| 75 |
+
text_input = st.text_area(
|
| 76 |
+
"Medical Text:",
|
| 77 |
+
height=150,
|
| 78 |
+
placeholder="e.g., Patient developed severe nausea and vomiting 2 hours after taking Drug X. Clinical assessment confirmed drug-related causality."
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
col1, col2 = st.columns([2, 1])
|
| 82 |
+
|
| 83 |
+
with col1:
|
| 84 |
+
if st.button("π Classify Text", type="primary", use_container_width=True):
|
| 85 |
+
if text_input and classifier:
|
| 86 |
+
with st.spinner("Analyzing text..."):
|
| 87 |
+
result = classifier.predict(text_input, threshold)
|
| 88 |
+
|
| 89 |
+
# Display Results
|
| 90 |
+
st.subheader("π Results")
|
| 91 |
+
|
| 92 |
+
result_col1, result_col2 = st.columns(2)
|
| 93 |
+
|
| 94 |
+
with result_col1:
|
| 95 |
+
classification = result['prediction'].upper()
|
| 96 |
+
color = "green" if result['prediction'] == 'related' else "red"
|
| 97 |
+
st.markdown(f"**Classification:** :{color}[{classification}]")
|
| 98 |
+
|
| 99 |
+
with result_col2:
|
| 100 |
+
confidence_pct = result['confidence'] * 100
|
| 101 |
+
st.metric("Confidence", f"{confidence_pct:.1f}%")
|
| 102 |
+
|
| 103 |
+
# Probability Distribution
|
| 104 |
+
st.subheader("π Probability Distribution")
|
| 105 |
+
probs = result['probabilities']
|
| 106 |
+
|
| 107 |
+
# Progress bars
|
| 108 |
+
st.write("**Related (Drug-Caused):**")
|
| 109 |
+
st.progress(probs['related'], text=f"{probs['related']:.2%}")
|
| 110 |
+
|
| 111 |
+
st.write("**Not Related:**")
|
| 112 |
+
st.progress(probs['not_related'], text=f"{probs['not_related']:.2%}")
|
| 113 |
+
|
| 114 |
+
# Raw JSON Output
|
| 115 |
+
with st.expander("π View Raw Results"):
|
| 116 |
+
st.json(result)
|
| 117 |
+
|
| 118 |
+
elif not classifier:
|
| 119 |
+
st.error("Model not loaded properly.")
|
| 120 |
+
else:
|
| 121 |
+
st.warning("Please enter text to classify.")
|
| 122 |
+
|
| 123 |
+
with col2:
|
| 124 |
+
st.info(
|
| 125 |
+
"**Example Inputs:**\n\n"
|
| 126 |
+
"**Related:** _Patient developed rash after taking aspirin. Symptoms resolved after discontinuation._\n\n"
|
| 127 |
+
"**Not Related:** _Patient has a history of diabetes and hypertension. Takes metformin daily._"
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
# TAB 2: PDF Analysis
|
| 131 |
+
with tab2:
|
| 132 |
+
st.header("π PDF Document Analysis")
|
| 133 |
+
st.write("Upload a PDF document for comprehensive drug-adverse event analysis:")
|
| 134 |
+
|
| 135 |
+
pdf_file = st.file_uploader(
|
| 136 |
+
"Choose a PDF file",
|
| 137 |
+
type=["pdf"],
|
| 138 |
+
help="Upload medical documents, case reports, or clinical notes"
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
if pdf_file and classifier:
|
| 142 |
+
# Save uploaded file temporarily
|
| 143 |
+
temp_dir = tempfile.gettempdir()
|
| 144 |
+
temp_path = os.path.join(temp_dir, pdf_file.name)
|
| 145 |
+
|
| 146 |
+
with open(temp_path, "wb") as tmp_f:
|
| 147 |
+
tmp_f.write(pdf_file.getbuffer())
|
| 148 |
+
|
| 149 |
+
# Analysis Button
|
| 150 |
+
if st.button("π Analyze PDF", type="primary", use_container_width=True):
|
| 151 |
+
with st.spinner(f"Processing {pdf_file.name}..."):
|
| 152 |
+
try:
|
| 153 |
+
# Extract and classify
|
| 154 |
+
pdf_text = extract_text_from_pdf(temp_path)
|
| 155 |
+
results = classify_causality(pdf_text, threshold=threshold)
|
| 156 |
+
|
| 157 |
+
# Display Summary
|
| 158 |
+
st.subheader("π Analysis Summary")
|
| 159 |
+
|
| 160 |
+
summary_col1, summary_col2, summary_col3 = st.columns(3)
|
| 161 |
+
|
| 162 |
+
with summary_col1:
|
| 163 |
+
classification = results['final_classification'].upper()
|
| 164 |
+
color = "green" if results['final_classification'] == 'related' else "red"
|
| 165 |
+
st.markdown(f"**Overall:** :{color}[{classification}]")
|
| 166 |
+
|
| 167 |
+
with summary_col2:
|
| 168 |
+
confidence_pct = results['confidence_score'] * 100
|
| 169 |
+
st.metric("Confidence", f"{confidence_pct:.1f}%")
|
| 170 |
+
|
| 171 |
+
with summary_col3:
|
| 172 |
+
st.metric("Total Sentences", results['total_sentences'])
|
| 173 |
+
|
| 174 |
+
# Sentence Breakdown
|
| 175 |
+
st.subheader("π Sentence Analysis")
|
| 176 |
+
|
| 177 |
+
breakdown_col1, breakdown_col2 = st.columns(2)
|
| 178 |
+
|
| 179 |
+
with breakdown_col1:
|
| 180 |
+
st.metric("Related Sentences", results['related_sentences'])
|
| 181 |
+
|
| 182 |
+
with breakdown_col2:
|
| 183 |
+
st.metric("Not Related", results['not_related_sentences'])
|
| 184 |
+
|
| 185 |
+
# Top Related Sentences
|
| 186 |
+
if results['related_sentences'] > 0:
|
| 187 |
+
st.subheader("π― Top Related Sentences")
|
| 188 |
+
|
| 189 |
+
for i, sent_detail in enumerate(results.get('top_related_sentences', []), 1):
|
| 190 |
+
confidence = sent_detail['probability_related']
|
| 191 |
+
confidence_color = "green" if confidence > 0.7 else "orange" if confidence > 0.5 else "red"
|
| 192 |
+
|
| 193 |
+
st.markdown(f"**{i}.** ({confidence:.1%} confidence)")
|
| 194 |
+
st.markdown(f":{confidence_color}[{sent_detail['sentence']}]")
|
| 195 |
+
st.write("")
|
| 196 |
+
|
| 197 |
+
# Download Button
|
| 198 |
+
st.subheader("πΎ Download Report")
|
| 199 |
+
|
| 200 |
+
report_json = json.dumps(results, indent=2)
|
| 201 |
+
|
| 202 |
+
st.download_button(
|
| 203 |
+
label="π₯ Download JSON Report",
|
| 204 |
+
data=report_json,
|
| 205 |
+
file_name=f"{pdf_file.name}_causality_report.json",
|
| 206 |
+
mime="application/json"
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
# Raw Results Expander
|
| 210 |
+
with st.expander("π View Full Results"):
|
| 211 |
+
st.json(results)
|
| 212 |
+
|
| 213 |
+
except Exception as e:
|
| 214 |
+
st.error(f"Error processing PDF: {str(e)}")
|
| 215 |
+
st.info("Please ensure the PDF contains readable text and try again.")
|
| 216 |
+
|
| 217 |
+
# Clean up temp file
|
| 218 |
+
finally:
|
| 219 |
+
try:
|
| 220 |
+
os.remove(temp_path)
|
| 221 |
+
except:
|
| 222 |
+
pass
|
| 223 |
+
|
| 224 |
+
# TAB 3: Batch Processing
|
| 225 |
+
with tab3:
|
| 226 |
+
st.header("π Batch PDF Processing")
|
| 227 |
+
st.write("Upload multiple PDF files for batch causality analysis:")
|
| 228 |
+
|
| 229 |
+
batch_files = st.file_uploader(
|
| 230 |
+
"Choose PDF files",
|
| 231 |
+
type=["pdf"],
|
| 232 |
+
accept_multiple_files=True,
|
| 233 |
+
help="Upload multiple medical documents for batch analysis"
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
if batch_files and classifier:
|
| 237 |
+
st.write(f"**Selected files:** {len(batch_files)} PDFs")
|
| 238 |
+
|
| 239 |
+
for i, file in enumerate(batch_files, 1):
|
| 240 |
+
st.write(f"{i}. {file.name}")
|
| 241 |
+
|
| 242 |
+
if st.button("π Process All PDFs", type="primary", use_container_width=True):
|
| 243 |
+
# Create temporary paths for all files
|
| 244 |
+
batch_temp_paths = []
|
| 245 |
+
temp_dir = tempfile.gettempdir()
|
| 246 |
+
|
| 247 |
+
try:
|
| 248 |
+
# Save all files temporarily
|
| 249 |
+
for batch_file in batch_files:
|
| 250 |
+
temp_path = os.path.join(temp_dir, batch_file.name)
|
| 251 |
+
with open(temp_path, "wb") as tmp_f:
|
| 252 |
+
tmp_f.write(batch_file.getbuffer())
|
| 253 |
+
batch_temp_paths.append(temp_path)
|
| 254 |
+
|
| 255 |
+
# Process all files
|
| 256 |
+
with st.spinner(f"Processing {len(batch_files)} files..."):
|
| 257 |
+
batch_results = process_multiple_pdfs(batch_temp_paths, threshold=threshold)
|
| 258 |
+
|
| 259 |
+
# Display Batch Summary
|
| 260 |
+
st.subheader("π Batch Analysis Summary")
|
| 261 |
+
|
| 262 |
+
# Overall stats
|
| 263 |
+
total_files = len(batch_results)
|
| 264 |
+
successful = len([r for r in batch_results if 'error' not in r])
|
| 265 |
+
related_count = len([r for r in batch_results if r.get('final_classification') == 'related'])
|
| 266 |
+
|
| 267 |
+
stat_col1, stat_col2, stat_col3 = st.columns(3)
|
| 268 |
+
|
| 269 |
+
with stat_col1:
|
| 270 |
+
st.metric("Total Files", total_files)
|
| 271 |
+
|
| 272 |
+
with stat_col2:
|
| 273 |
+
st.metric("Successfully Processed", successful)
|
| 274 |
+
|
| 275 |
+
with stat_col3:
|
| 276 |
+
st.metric("Drug-Related Files", related_count)
|
| 277 |
+
|
| 278 |
+
# Individual Results
|
| 279 |
+
st.subheader("π Individual Results")
|
| 280 |
+
|
| 281 |
+
for i, res in enumerate(batch_results, 1):
|
| 282 |
+
if 'error' in res:
|
| 283 |
+
st.error(f"**{i}. {res['pdf_file']}:** Error - {res['error']}")
|
| 284 |
+
else:
|
| 285 |
+
classification = res['final_classification'].upper()
|
| 286 |
+
confidence = res.get('confidence_score', 0) * 100
|
| 287 |
+
color = "green" if res['final_classification'] == 'related' else "red"
|
| 288 |
+
|
| 289 |
+
st.markdown(f"**{i}. {res['pdf_file']}:** :{color}[{classification}] ({confidence:.1f}% confidence)")
|
| 290 |
+
|
| 291 |
+
# Download Batch Summary
|
| 292 |
+
st.subheader("πΎ Download Batch Report")
|
| 293 |
+
|
| 294 |
+
batch_report = {
|
| 295 |
+
'summary': {
|
| 296 |
+
'total_files': total_files,
|
| 297 |
+
'successful': successful,
|
| 298 |
+
'related_count': related_count,
|
| 299 |
+
'threshold_used': threshold
|
| 300 |
+
},
|
| 301 |
+
'individual_results': batch_results
|
| 302 |
+
}
|
| 303 |
+
|
| 304 |
+
batch_json = json.dumps(batch_report, indent=2)
|
| 305 |
+
|
| 306 |
+
st.download_button(
|
| 307 |
+
label="π₯ Download Batch Summary",
|
| 308 |
+
data=batch_json,
|
| 309 |
+
file_name="batch_causality_summary.json",
|
| 310 |
+
mime="application/json"
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
# Raw Results Expander
|
| 314 |
+
with st.expander("π View Full Batch Results"):
|
| 315 |
+
st.json(batch_results)
|
| 316 |
+
|
| 317 |
+
except Exception as e:
|
| 318 |
+
st.error(f"Batch processing error: {str(e)}")
|
| 319 |
+
|
| 320 |
+
finally:
|
| 321 |
+
# Clean up all temp files
|
| 322 |
+
for temp_path in batch_temp_paths:
|
| 323 |
+
try:
|
| 324 |
+
os.remove(temp_path)
|
| 325 |
+
except:
|
| 326 |
+
pass
|
| 327 |
+
|
| 328 |
+
# Footer
|
| 329 |
+
st.markdown("---")
|
| 330 |
+
st.markdown(
|
| 331 |
+
"**Built with BioBERT for Pharmacovigilance** | "
|
| 332 |
+
"Developed for clinical decision support and regulatory compliance"
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
# Sidebar additional info
|
| 336 |
+
st.sidebar.markdown("---")
|
| 337 |
+
st.sidebar.markdown("### π Model Performance")
|
| 338 |
+
st.sidebar.markdown(
|
| 339 |
+
"- **F1 Score:** 97.59%\n"
|
| 340 |
+
"- **Accuracy:** 97.59%\n"
|
| 341 |
+
"- **Sensitivity:** 98.68%\n"
|
| 342 |
+
"- **Specificity:** 96.50%"
|
| 343 |
+
)
|
| 344 |
+
|
| 345 |
+
st.sidebar.markdown("### π₯ Clinical Use")
|
| 346 |
+
st.sidebar.markdown(
|
| 347 |
+
"This tool assists in:\n"
|
| 348 |
+
"- Adverse event detection\n"
|
| 349 |
+
"- Pharmacovigilance screening\n"
|
| 350 |
+
"- Clinical report analysis\n"
|
| 351 |
+
"- Regulatory compliance"
|
| 352 |
+
)
|
models/production_model_final/config.json
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"BertForSequenceClassification"
|
| 4 |
+
],
|
| 5 |
+
"attention_probs_dropout_prob": 0.1,
|
| 6 |
+
"classifier_dropout": null,
|
| 7 |
+
"dtype": "float32",
|
| 8 |
+
"hidden_act": "gelu",
|
| 9 |
+
"hidden_dropout_prob": 0.1,
|
| 10 |
+
"hidden_size": 768,
|
| 11 |
+
"initializer_range": 0.02,
|
| 12 |
+
"intermediate_size": 3072,
|
| 13 |
+
"layer_norm_eps": 1e-12,
|
| 14 |
+
"max_position_embeddings": 512,
|
| 15 |
+
"model_type": "bert",
|
| 16 |
+
"num_attention_heads": 12,
|
| 17 |
+
"num_hidden_layers": 12,
|
| 18 |
+
"pad_token_id": 0,
|
| 19 |
+
"position_embedding_type": "absolute",
|
| 20 |
+
"problem_type": "single_label_classification",
|
| 21 |
+
"transformers_version": "4.57.1",
|
| 22 |
+
"type_vocab_size": 2,
|
| 23 |
+
"use_cache": true,
|
| 24 |
+
"vocab_size": 30522
|
| 25 |
+
}
|
models/production_model_final/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7f73202c120a52a4288c045e5713eeecbfe7b3431b5e15dafa3ded35b8ba18e4
|
| 3 |
+
size 437958648
|
models/production_model_final/special_tokens_map.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cls_token": "[CLS]",
|
| 3 |
+
"mask_token": "[MASK]",
|
| 4 |
+
"pad_token": "[PAD]",
|
| 5 |
+
"sep_token": "[SEP]",
|
| 6 |
+
"unk_token": "[UNK]"
|
| 7 |
+
}
|
models/production_model_final/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
models/production_model_final/tokenizer_config.json
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"added_tokens_decoder": {
|
| 3 |
+
"0": {
|
| 4 |
+
"content": "[PAD]",
|
| 5 |
+
"lstrip": false,
|
| 6 |
+
"normalized": false,
|
| 7 |
+
"rstrip": false,
|
| 8 |
+
"single_word": false,
|
| 9 |
+
"special": true
|
| 10 |
+
},
|
| 11 |
+
"1": {
|
| 12 |
+
"content": "[UNK]",
|
| 13 |
+
"lstrip": false,
|
| 14 |
+
"normalized": false,
|
| 15 |
+
"rstrip": false,
|
| 16 |
+
"single_word": false,
|
| 17 |
+
"special": true
|
| 18 |
+
},
|
| 19 |
+
"2": {
|
| 20 |
+
"content": "[CLS]",
|
| 21 |
+
"lstrip": false,
|
| 22 |
+
"normalized": false,
|
| 23 |
+
"rstrip": false,
|
| 24 |
+
"single_word": false,
|
| 25 |
+
"special": true
|
| 26 |
+
},
|
| 27 |
+
"3": {
|
| 28 |
+
"content": "[SEP]",
|
| 29 |
+
"lstrip": false,
|
| 30 |
+
"normalized": false,
|
| 31 |
+
"rstrip": false,
|
| 32 |
+
"single_word": false,
|
| 33 |
+
"special": true
|
| 34 |
+
},
|
| 35 |
+
"4": {
|
| 36 |
+
"content": "[MASK]",
|
| 37 |
+
"lstrip": false,
|
| 38 |
+
"normalized": false,
|
| 39 |
+
"rstrip": false,
|
| 40 |
+
"single_word": false,
|
| 41 |
+
"special": true
|
| 42 |
+
}
|
| 43 |
+
},
|
| 44 |
+
"clean_up_tokenization_spaces": true,
|
| 45 |
+
"cls_token": "[CLS]",
|
| 46 |
+
"do_basic_tokenize": true,
|
| 47 |
+
"do_lower_case": true,
|
| 48 |
+
"extra_special_tokens": {},
|
| 49 |
+
"mask_token": "[MASK]",
|
| 50 |
+
"model_max_length": 1000000000000000019884624838656,
|
| 51 |
+
"never_split": null,
|
| 52 |
+
"pad_token": "[PAD]",
|
| 53 |
+
"sep_token": "[SEP]",
|
| 54 |
+
"strip_accents": null,
|
| 55 |
+
"tokenize_chinese_chars": true,
|
| 56 |
+
"tokenizer_class": "BertTokenizer",
|
| 57 |
+
"unk_token": "[UNK]"
|
| 58 |
+
}
|
models/production_model_final/training_args.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ca6a57bb3665602515473f6c3e6aa96cb5505d7b8642beb6c8604c4a00aec451
|
| 3 |
+
size 5777
|
models/production_model_final/training_config.json
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"best_params": {
|
| 3 |
+
"learning_rate": 3.758180918998249e-05,
|
| 4 |
+
"num_train_epochs": 1,
|
| 5 |
+
"batch_size": 4,
|
| 6 |
+
"gradient_accumulation_steps": 4
|
| 7 |
+
},
|
| 8 |
+
"final_results": {
|
| 9 |
+
"accuracy": 0.9758909853249476,
|
| 10 |
+
"f1": 0.9758881040529953,
|
| 11 |
+
"precision": 0.9761185621296976,
|
| 12 |
+
"recall": 0.9758909853249476
|
| 13 |
+
},
|
| 14 |
+
"optuna_source": "Trial 1",
|
| 15 |
+
"training_date": "2025-10-25T16:06:34.368403"
|
| 16 |
+
}
|
models/production_model_final/vocab.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
requirements.txt
CHANGED
|
@@ -1,3 +1,11 @@
|
|
| 1 |
-
|
|
|
|
| 2 |
pandas
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.1.0
|
| 2 |
+
transformers>=4.35.0
|
| 3 |
pandas
|
| 4 |
+
numpy
|
| 5 |
+
scikit-learn
|
| 6 |
+
nltk>=3.7
|
| 7 |
+
PyPDF2>=3.0.1
|
| 8 |
+
streamlit>=1.22.0
|
| 9 |
+
safetensors>=0.4.0
|
| 10 |
+
boto3
|
| 11 |
+
# removed invalid line: "pip freeze > requirements.txt" is a shell command, not a requirement
|
src/__init__.py
ADDED
|
File without changes
|
src/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (147 Bytes). View file
|
|
|
src/__pycache__/inference.cpython-313.pyc
ADDED
|
Binary file (20.3 kB). View file
|
|
|
src/inference.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from pathlib import Path
import PyPDF2
import json
from datetime import datetime
from typing import Union, List, Dict
import re

# NLTK with robust error handling
import nltk
import ssl

# Some deployment environments have a broken/missing certificate store which
# makes nltk.download() fail with SSL errors. Fall back to unverified HTTPS
# only when the private CPython helper exists (AttributeError otherwise).
try:
    _create_unverified_https_context = ssl._create_unverified_context
except AttributeError:
    pass
else:
    ssl._create_default_https_context = _create_unverified_https_context

# Enhanced NLTK data download with retry
def download_nltk_data_robust():
    """Download NLTK data with multiple attempts and fallbacks.

    Ensures the 'punkt' and 'punkt_tab' tokenizer data are present, trying up
    to three times per package. All failures are logged and tolerated: callers
    fall back to the regex tokenizer defined below, so this never raises.
    """
    import os

    # Set NLTK data path explicitly
    # NOTE(review): hard-coded container home dir (HF Spaces 'appuser');
    # confirm this path if the app is deployed elsewhere.
    nltk_data_dir = '/home/appuser/nltk_data'
    if not os.path.exists(nltk_data_dir):
        try:
            os.makedirs(nltk_data_dir, exist_ok=True)
        except:
            # Best effort: if the dir cannot be created, NLTK's default
            # search paths are still used.
            pass

    # Make our directory the first place NLTK looks for data.
    if nltk_data_dir not in nltk.data.path:
        nltk.data.path.insert(0, nltk_data_dir)

    # 'punkt_tab' is required by newer NLTK releases in addition to 'punkt'.
    packages = ['punkt', 'punkt_tab']
    for package in packages:
        for attempt in range(3):  # Try 3 times
            try:
                # Skip the download entirely when the data is already present.
                nltk.data.find(f'tokenizers/{package}')
                print(f"β {package} already available")
                break
            except LookupError:
                try:
                    print(f"Downloading {package} (attempt {attempt + 1})...")
                    nltk.download(package, download_dir=nltk_data_dir, quiet=False)
                    print(f"β {package} downloaded successfully")
                    break
                except Exception as e:
                    # Network/SSL failures are tolerated; safe_sent_tokenize()
                    # below degrades to the regex splitter.
                    print(f"Warning: Could not download {package}: {e}")
                    if attempt == 2:
                        print(f"Failed to download {package} after 3 attempts")

# Download on import
download_nltk_data_robust()

# Fallback sentence tokenizer using regex
def simple_sentence_tokenize(text):
    """Regex-based sentence splitter used when NLTK is unavailable.

    Breaks the input on whitespace that follows '.', '!' or '?', strips each
    piece, and drops empty pieces.
    """
    out = []
    for chunk in re.split(r'(?<=[.!?])\s+', text):
        stripped = chunk.strip()
        if stripped:
            out.append(stripped)
    return out

# Safe sentence tokenization with fallback
def safe_sent_tokenize(text):
    """Split *text* into sentences with NLTK; degrade to regex on any failure.

    Any exception (missing punkt data, import failure, tokenizer error) is
    reported and answered with simple_sentence_tokenize().
    """
    try:
        from nltk.tokenize import sent_tokenize
        result = sent_tokenize(text)
    except Exception as e:
        print(f"NLTK tokenization failed ({e}), using fallback...")
        result = simple_sentence_tokenize(text)
    return result
class CausalityClassifier:
    """Binary drug-causality classifier wrapping a fine-tuned BioBERT checkpoint.

    Label 1 = "related" (drug-caused adverse event), label 0 = "not related".
    """

    # Max token length used during fine-tuning; longer inputs are truncated.
    MAX_LENGTH = 96

    def __init__(self, model_path='./models/production_model_final', threshold=0.5):
        """Load tokenizer and model from *model_path*.

        Args:
            model_path: directory containing a HF-format checkpoint.
            threshold: default probability of the "related" class above
                which an input is flagged as related.
        """
        self.model_path = Path(model_path)
        self.threshold = threshold
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(self.model_path)
        self.model.eval()  # inference only: disable dropout etc.

    def predict(self, text, return_probs=False, threshold=None):
        """Classify a single text.

        Args:
            text: input sentence or short paragraph.
            return_probs: include the per-class probabilities in the result.
            threshold: optional per-call override of ``self.threshold``
                (backward-compatible addition; defaults to the instance value).

        Returns:
            dict with 'prediction' ('related'/'not related'), 'confidence'
            (probability of the predicted class) and integer 'label', plus
            a 'probabilities' dict when *return_probs* is true.
        """
        active_threshold = self.threshold if threshold is None else threshold
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True,
                                padding=True, max_length=self.MAX_LENGTH)
        with torch.no_grad():
            outputs = self.model(**inputs)
        probs = torch.softmax(outputs.logits, dim=1).numpy()[0]
        # Decision is taken on the "related" probability only.
        pred = 1 if probs[1] > active_threshold else 0
        result = {
            'prediction': 'related' if pred == 1 else 'not related',
            'confidence': float(probs[pred]),
            'label': int(pred)
        }
        if return_probs:
            result['probabilities'] = {
                'not_related': float(probs[0]),
                'related': float(probs[1])
            }
        return result
def extract_text_from_pdf(pdf_path):
    """Extract and concatenate the text of every page of *pdf_path*.

    BUG FIX: ``PageObject.extract_text()`` may return None (e.g. scanned
    image pages); the original ``text += page.extract_text()`` raised
    TypeError in that case. Such pages are now skipped. Also builds the
    result with ''.join() instead of quadratic string concatenation.
    """
    with open(pdf_path, 'rb') as file:
        pdf_reader = PyPDF2.PdfReader(file)
        parts = []
        for page in pdf_reader.pages:
            page_text = page.extract_text()
            if page_text:
                parts.append(page_text)
        return "".join(parts)
# Cache of loaded classifiers keyed by model path so that repeated document
# analyses do not reload the (~400 MB) checkpoint from disk on every call.
_CLASSIFIER_CACHE = {}


def classify_causality(pdf_text, model_path='./models/production_model_final', threshold=0.5, verbose=False):
    """Sentence-level causality screen of a document.

    Every sentence is classified independently; the document is labelled
    'related' as soon as at least one sentence is predicted related at the
    given *threshold*.

    Args:
        pdf_text: full document text.
        model_path: checkpoint directory for CausalityClassifier.
        threshold: probability cutoff for the "related" class.
        verbose: print the sentence count when true.

    Returns:
        dict with the overall classification, the fraction of related
        sentences ('confidence_score'), sentence counts, the top-5 related
        sentences, and the threshold used.
    """
    # PERF FIX: reuse a cached classifier instead of reloading the model on
    # every call; the caller's threshold is applied to the cached instance.
    classifier = _CLASSIFIER_CACHE.get(model_path)
    if classifier is None:
        classifier = CausalityClassifier(model_path, threshold)
        _CLASSIFIER_CACHE[model_path] = classifier
    classifier.threshold = threshold

    # Use safe tokenization with fallback
    sentences = safe_sent_tokenize(pdf_text)

    if verbose:
        print(f"Tokenized {len(sentences)} sentences")

    related_count = 0
    sentence_details = []

    for sent in sentences:
        if not sent.strip():
            continue

        result = classifier.predict(sent, return_probs=True)
        if result['label'] == 1:
            related_count += 1
            sentence_details.append({
                'sentence': sent[:100],  # truncate for report readability
                'probability_related': result['probabilities']['related'],
                'confidence': result['confidence']
            })

    # Most confident related sentences first.
    sentence_details.sort(key=lambda x: x['probability_related'], reverse=True)

    return {
        'final_classification': 'related' if related_count > 0 else 'not related',
        'confidence_score': related_count / len(sentences) if sentences else 0,
        'related_sentences': related_count,
        'total_sentences': len(sentences),
        'top_related_sentences': sentence_details[:5],
        'threshold_used': threshold
    }
def process_pdf_file(pdf_path, model_path='./models/production_model_final', threshold=0.5, save_report=False, output_dir='./results'):
    """Analyze one PDF and optionally persist the JSON report to *output_dir*."""
    document_text = extract_text_from_pdf(pdf_path)
    report = classify_causality(document_text, model_path, threshold)
    report['pdf_file'] = str(Path(pdf_path).name)
    if save_report:
        out_dir = Path(output_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        report_path = out_dir / f"{Path(pdf_path).stem}_report.json"
        with open(report_path, 'w') as f:
            json.dump(report, f, indent=2)
    return report
def process_multiple_pdfs(pdf_paths, model_path='./models/production_model_final', threshold=0.5, save_reports=False, output_dir='./results'):
    """Analyze several PDFs; per-file failures are recorded, never raised."""
    reports = []
    for path in pdf_paths:
        try:
            reports.append(process_pdf_file(path, model_path, threshold, save_reports, output_dir))
        except Exception as exc:
            # Keep processing the remaining files; surface the failure in
            # the report entry for this one.
            reports.append({
                'pdf_file': str(Path(path).name),
                'error': str(exc),
                'final_classification': 'error'
            })
    return reports
streamlit_app.py
ADDED
|
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import streamlit as st
import tempfile
import os
import sys
import pandas as pd
import json
from pathlib import Path
import nltk

# ROBUSTNESS FIX: the original unconditional nltk.download('punkt') crashed
# the whole app when offline, and newer NLTK versions also need 'punkt_tab'.
# Make the download best-effort; src.inference has its own fallback tokenizer.
for _pkg in ('punkt', 'punkt_tab'):
    try:
        nltk.download(_pkg, quiet=True)
    except Exception:
        pass

# Add parent directory to path so the 'src' package is importable when this
# file is run from a repository subdirectory.
sys.path.insert(0, str(Path(__file__).parent.parent))

from src.inference import (
    CausalityClassifier,
    extract_text_from_pdf,
    classify_causality,
    process_pdf_file,
    process_multiple_pdfs
)

# SINGLE load_model function with caching
@st.cache_resource
def load_model():
    """Load CausalityClassifier model once and reuse across sessions.

    Returns None (after showing a Streamlit error) when loading fails so the
    rest of the UI can degrade gracefully instead of crashing.
    """
    try:
        # Relative path: assumes the app is started from the repository root.
        return CausalityClassifier("models/production_model_final")
    except Exception as e:
        st.error(f"Failed to load model: {e}")
        return None
# App Configuration
st.set_page_config(
    page_title="Drug Causality Classifier",
    page_icon="π",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Main Title
st.title("π Drug Causality Classifier")
st.caption("BioBERT Model | F1 Score: 97.59% | Sensitivity: 98.68% | Specificity: 96.50%")

# Load model (cached across sessions by st.cache_resource)
classifier = load_model()

# Sidebar Configuration
st.sidebar.header("βοΈ Configuration")
# Probability-of-"related" cutoff shared by all three tabs below.
threshold = st.sidebar.slider(
    "Classification Threshold",
    min_value=0.0,
    max_value=1.0,
    value=0.5,
    step=0.05,
    help="Higher threshold = stricter causality detection"
)

st.sidebar.info(
    "**Threshold Guide:**\n"
    "- 0.3-0.4: High sensitivity (catch all events)\n"
    "- 0.5: Balanced performance\n"
    "- 0.7-0.8: High precision (reduce false alarms)"
)
# Main Content
tab1, tab2, tab3 = st.tabs(["π Single Text", "π PDF Analysis", "π Batch Processing"])

# TAB 1: Single Text Classification
with tab1:
    st.header("π Single Statement Classification")
    st.write("Enter medical text to classify drug-adverse event causality:")

    text_input = st.text_area(
        "Medical Text:",
        height=150,
        placeholder="e.g., Patient developed severe nausea and vomiting 2 hours after taking Drug X. Clinical assessment confirmed drug-related causality."
    )

    col1, col2 = st.columns([2, 1])

    with col1:
        if st.button("π Classify Text", type="primary", use_container_width=True):
            if text_input and classifier:
                with st.spinner("Analyzing text..."):
                    # BUG FIX: the second positional argument of predict() is
                    # return_probs, not the threshold -- the sidebar threshold
                    # was silently ignored (and only accidentally enabled
                    # probabilities because 0.5 is truthy). Apply the threshold
                    # explicitly and request probabilities by keyword.
                    classifier.threshold = threshold
                    result = classifier.predict(text_input, return_probs=True)

                # Display Results
                st.subheader("π Results")

                result_col1, result_col2 = st.columns(2)

                with result_col1:
                    classification = result['prediction'].upper()
                    color = "green" if result['prediction'] == 'related' else "red"
                    st.markdown(f"**Classification:** :{color}[{classification}]")

                with result_col2:
                    confidence_pct = result['confidence'] * 100
                    st.metric("Confidence", f"{confidence_pct:.1f}%")

                # Probability Distribution
                st.subheader("π Probability Distribution")
                probs = result['probabilities']

                # Progress bars
                st.write("**Related (Drug-Caused):**")
                st.progress(probs['related'], text=f"{probs['related']:.2%}")

                st.write("**Not Related:**")
                st.progress(probs['not_related'], text=f"{probs['not_related']:.2%}")

                # Raw JSON Output
                with st.expander("π View Raw Results"):
                    st.json(result)

            elif not classifier:
                st.error("Model not loaded properly.")
            else:
                st.warning("Please enter text to classify.")

    with col2:
        st.info(
            "**Example Inputs:**\n\n"
            "**Related:** _Patient developed rash after taking aspirin. Symptoms resolved after discontinuation._\n\n"
            "**Not Related:** _Patient has a history of diabetes and hypertension. Takes metformin daily._"
        )
# TAB 2: PDF Analysis
with tab2:
    st.header("π PDF Document Analysis")
    st.write("Upload a PDF document for comprehensive drug-adverse event analysis:")

    pdf_file = st.file_uploader(
        "Choose a PDF file",
        type=["pdf"],
        help="Upload medical documents, case reports, or clinical notes"
    )

    if pdf_file and classifier:
        # Save uploaded file temporarily so PyPDF2 can open it by path.
        temp_dir = tempfile.gettempdir()
        temp_path = os.path.join(temp_dir, pdf_file.name)

        with open(temp_path, "wb") as tmp_f:
            tmp_f.write(pdf_file.getbuffer())

        # Analysis Button
        if st.button("π Analyze PDF", type="primary", use_container_width=True):
            with st.spinner(f"Processing {pdf_file.name}..."):
                try:
                    # Extract and classify with the sidebar threshold.
                    pdf_text = extract_text_from_pdf(temp_path)
                    results = classify_causality(pdf_text, threshold=threshold)

                    # Display Summary
                    st.subheader("π Analysis Summary")

                    summary_col1, summary_col2, summary_col3 = st.columns(3)

                    with summary_col1:
                        classification = results['final_classification'].upper()
                        color = "green" if results['final_classification'] == 'related' else "red"
                        st.markdown(f"**Overall:** :{color}[{classification}]")

                    with summary_col2:
                        confidence_pct = results['confidence_score'] * 100
                        st.metric("Confidence", f"{confidence_pct:.1f}%")

                    with summary_col3:
                        st.metric("Total Sentences", results['total_sentences'])

                    # Sentence Breakdown
                    st.subheader("π Sentence Analysis")

                    breakdown_col1, breakdown_col2 = st.columns(2)

                    with breakdown_col1:
                        st.metric("Related Sentences", results['related_sentences'])

                    with breakdown_col2:
                        # BUG FIX: classify_causality() never returns a
                        # 'not_related_sentences' key; the original lookup
                        # raised KeyError. Derive the count instead.
                        st.metric("Not Related", results['total_sentences'] - results['related_sentences'])

                    # Top Related Sentences
                    if results['related_sentences'] > 0:
                        st.subheader("π― Top Related Sentences")

                        for i, sent_detail in enumerate(results.get('top_related_sentences', []), 1):
                            confidence = sent_detail['probability_related']
                            confidence_color = "green" if confidence > 0.7 else "orange" if confidence > 0.5 else "red"

                            st.markdown(f"**{i}.** ({confidence:.1%} confidence)")
                            st.markdown(f":{confidence_color}[{sent_detail['sentence']}]")
                            st.write("")

                    # Download Button
                    st.subheader("πΎ Download Report")

                    report_json = json.dumps(results, indent=2)

                    st.download_button(
                        label="π₯ Download JSON Report",
                        data=report_json,
                        file_name=f"{pdf_file.name}_causality_report.json",
                        mime="application/json"
                    )

                    # Raw Results Expander
                    with st.expander("π View Full Results"):
                        st.json(results)

                except Exception as e:
                    st.error(f"Error processing PDF: {str(e)}")
                    st.info("Please ensure the PDF contains readable text and try again.")

                # Clean up temp file regardless of success/failure.
                finally:
                    try:
                        os.remove(temp_path)
                    except OSError:
                        # File may never have been created or already removed.
                        pass
# TAB 3: Batch Processing
with tab3:
    st.header("π Batch PDF Processing")
    st.write("Upload multiple PDF files for batch causality analysis:")

    batch_files = st.file_uploader(
        "Choose PDF files",
        type=["pdf"],
        accept_multiple_files=True,
        help="Upload multiple medical documents for batch analysis"
    )

    if batch_files and classifier:
        st.write(f"**Selected files:** {len(batch_files)} PDFs")

        # Echo the selection so the user can verify before processing.
        for i, file in enumerate(batch_files, 1):
            st.write(f"{i}. {file.name}")

        if st.button("π Process All PDFs", type="primary", use_container_width=True):
            # Create temporary paths for all files (PyPDF2 reads from disk).
            batch_temp_paths = []
            temp_dir = tempfile.gettempdir()

            try:
                # Save all files temporarily
                # NOTE(review): files are keyed by their original name, so two
                # uploads with the same filename would overwrite each other.
                for batch_file in batch_files:
                    temp_path = os.path.join(temp_dir, batch_file.name)
                    with open(temp_path, "wb") as tmp_f:
                        tmp_f.write(batch_file.getbuffer())
                    batch_temp_paths.append(temp_path)

                # Process all files; per-file errors come back as entries
                # with an 'error' key rather than raising.
                with st.spinner(f"Processing {len(batch_files)} files..."):
                    batch_results = process_multiple_pdfs(batch_temp_paths, threshold=threshold)

                # Display Batch Summary
                st.subheader("π Batch Analysis Summary")

                # Overall stats
                total_files = len(batch_results)
                successful = len([r for r in batch_results if 'error' not in r])
                related_count = len([r for r in batch_results if r.get('final_classification') == 'related'])

                stat_col1, stat_col2, stat_col3 = st.columns(3)

                with stat_col1:
                    st.metric("Total Files", total_files)

                with stat_col2:
                    st.metric("Successfully Processed", successful)

                with stat_col3:
                    st.metric("Drug-Related Files", related_count)

                # Individual Results
                st.subheader("π Individual Results")

                for i, res in enumerate(batch_results, 1):
                    if 'error' in res:
                        st.error(f"**{i}. {res['pdf_file']}:** Error - {res['error']}")
                    else:
                        classification = res['final_classification'].upper()
                        confidence = res.get('confidence_score', 0) * 100
                        color = "green" if res['final_classification'] == 'related' else "red"

                        st.markdown(f"**{i}. {res['pdf_file']}:** :{color}[{classification}] ({confidence:.1f}% confidence)")

                # Download Batch Summary
                st.subheader("πΎ Download Batch Report")

                batch_report = {
                    'summary': {
                        'total_files': total_files,
                        'successful': successful,
                        'related_count': related_count,
                        'threshold_used': threshold
                    },
                    'individual_results': batch_results
                }

                batch_json = json.dumps(batch_report, indent=2)

                st.download_button(
                    label="π₯ Download Batch Summary",
                    data=batch_json,
                    file_name="batch_causality_summary.json",
                    mime="application/json"
                )

                # Raw Results Expander
                with st.expander("π View Full Batch Results"):
                    st.json(batch_results)

            except Exception as e:
                st.error(f"Batch processing error: {str(e)}")

            finally:
                # Clean up all temp files; ignore individual removal failures.
                for temp_path in batch_temp_paths:
                    try:
                        os.remove(temp_path)
                    except:
                        pass
# Footer
st.markdown("---")
st.markdown(
    "**Built with BioBERT for Pharmacovigilance** | "
    "Developed for clinical decision support and regulatory compliance"
)

# Sidebar additional info: static metrics reported from the training run
# (see models/production_model_final/training_config.json).
st.sidebar.markdown("---")
st.sidebar.markdown("### π Model Performance")
st.sidebar.markdown(
    "- **F1 Score:** 97.59%\n"
    "- **Accuracy:** 97.59%\n"
    "- **Sensitivity:** 98.68%\n"
    "- **Specificity:** 96.50%"
)

st.sidebar.markdown("### π₯ Clinical Use")
st.sidebar.markdown(
    "This tool assists in:\n"
    "- Adverse event detection\n"
    "- Pharmacovigilance screening\n"
    "- Clinical report analysis\n"
    "- Regulatory compliance"
)