Sumitkumar098 commited on
Commit
7a8bc6d
·
verified ·
1 Parent(s): d186cf8

Upload 2 files

Browse files
Drug_Prediction_and_Polypharmacy_System.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
app_test.py ADDED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import numpy as np
4
+ import pickle
5
+ import json
6
+ from transformers import AutoTokenizer, AutoModel
7
+ import torch.nn as nn
8
+ import os
9
+
10
+ # Set page config
11
+ st.set_page_config(
12
+ page_title="Drug Prediction and Polypharmacy System",
13
+ page_icon="💊",
14
+ layout="wide"
15
+ )
16
+
17
+ # Model class definition - must match the training model architecture
18
+ class EnhancedMedicationModel(nn.Module):
19
+ def __init__(self, model_name, num_medications, num_polypharmacy_classes, num_disease_classes, dropout_rate=0.3):
20
+ super().__init__()
21
+ self.bert = AutoModel.from_pretrained(model_name)
22
+ self.dropout = nn.Dropout(dropout_rate)
23
+ hidden_size = self.bert.config.hidden_size
24
+
25
+ # Common representation layer
26
+ self.common_dense = nn.Linear(hidden_size, hidden_size)
27
+
28
+ # Task-specific layers with increased complexity
29
+ # Medication prediction head (multi-label)
30
+ self.medication_classifier = nn.Sequential(
31
+ nn.Linear(hidden_size, hidden_size//2),
32
+ nn.ReLU(),
33
+ nn.Dropout(dropout_rate),
34
+ nn.Linear(hidden_size//2, num_medications)
35
+ )
36
+
37
+ # Polypharmacy risk head (multi-class)
38
+ self.polypharmacy_classifier = nn.Sequential(
39
+ nn.Linear(hidden_size, hidden_size//2),
40
+ nn.ReLU(),
41
+ nn.Dropout(dropout_rate),
42
+ nn.Linear(hidden_size//2, num_polypharmacy_classes)
43
+ )
44
+
45
+ # Disease prediction head (multi-class)
46
+ self.disease_classifier = nn.Sequential(
47
+ nn.Linear(hidden_size, hidden_size//2),
48
+ nn.ReLU(),
49
+ nn.Dropout(dropout_rate),
50
+ nn.Linear(hidden_size//2, num_disease_classes)
51
+ )
52
+
53
+ # Apply weight initialization
54
+ self._init_weights()
55
+
56
+ def _init_weights(self):
57
+ # Initialize weights for better convergence
58
+ for module in [self.medication_classifier, self.polypharmacy_classifier,
59
+ self.disease_classifier, self.common_dense]:
60
+ if isinstance(module, nn.Sequential):
61
+ for layer in module:
62
+ if isinstance(layer, nn.Linear):
63
+ nn.init.xavier_normal_(layer.weight)
64
+ nn.init.zeros_(layer.bias)
65
+ elif isinstance(module, nn.Linear):
66
+ nn.init.xavier_normal_(module.weight)
67
+ nn.init.zeros_(layer.bias)
68
+
69
+ def forward(self, input_ids, attention_mask):
70
+ outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
71
+ pooled_output = outputs.last_hidden_state[:, 0, :] # CLS token
72
+ pooled_output = self.dropout(pooled_output)
73
+
74
+ # Common representation
75
+ common_features = torch.relu(self.common_dense(pooled_output))
76
+
77
+ medication_logits = self.medication_classifier(common_features)
78
+ polypharmacy_logits = self.polypharmacy_classifier(common_features)
79
+ disease_logits = self.disease_classifier(common_features)
80
+
81
+ return medication_logits, polypharmacy_logits, disease_logits
82
+
83
+ @st.cache_resource
84
+ def load_model_and_resources():
85
+ """Load model and necessary resources (cached for performance)"""
86
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
87
+
88
+ # Load model configuration - fixed file paths
89
+ with open('streamlit_model/model_config.json', 'r') as f:
90
+ model_config = json.load(f)
91
+
92
+ # Initialize model
93
+ model_name = model_config['model_name']
94
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
95
+
96
+ # Create model architecture
97
+ model = EnhancedMedicationModel(
98
+ model_name=model_name,
99
+ num_medications=model_config['num_medications'],
100
+ num_polypharmacy_classes=model_config['num_polypharmacy_classes'],
101
+ num_disease_classes=model_config['num_disease_classes'],
102
+ dropout_rate=0.3
103
+ )
104
+
105
+ # Load trained weights - fixed file path
106
+ model.load_state_dict(torch.load('streamlit_model/model_state_dict.pt', map_location=device))
107
+ model = model.to(device)
108
+ model.eval()
109
+
110
+ # Load encoders - fixed file path
111
+ with open('streamlit_model/label_encoders.pkl', 'rb') as f:
112
+ encoders = pickle.load(f)
113
+
114
+ # Load lookup data - fixed file path
115
+ with open('streamlit_model/lookup_data.pkl', 'rb') as f:
116
+ lookup_data = pickle.load(f)
117
+
118
+ return {
119
+ 'model': model,
120
+ 'tokenizer': tokenizer,
121
+ 'mlb': encoders['mlb'],
122
+ 'le_risk': encoders['le_risk'],
123
+ 'le_disease': encoders['le_disease'],
124
+ 'lookup_data': lookup_data,
125
+ 'device': device
126
+ }
127
+
128
+ def predict_patient_health_profile(patient_data, resources):
129
+ """
130
+ Predict health profile for a patient based on input data
131
+ """
132
+ model = resources['model']
133
+ tokenizer = resources['tokenizer']
134
+ mlb = resources['mlb']
135
+ le_risk = resources['le_risk']
136
+ le_disease = resources['le_disease']
137
+ lookup_data = resources['lookup_data']
138
+ device = resources['device']
139
+
140
+ # Create text input
141
+ text_input = f"Patient age {patient_data['age']}, gender {patient_data['gender']}, blood group {patient_data['blood_group']}, weight {patient_data['weight']}kg. " + f"SYMPTOMS: {patient_data['symptoms']}. " + f"SEVERITY: {patient_data['severity']}."
142
+
143
+ # Tokenize
144
+ encoding = tokenizer(
145
+ text_input,
146
+ add_special_tokens=True,
147
+ max_length=256,
148
+ padding='max_length',
149
+ truncation=True,
150
+ return_tensors='pt'
151
+ )
152
+
153
+ # Move to device
154
+ input_ids = encoding['input_ids'].to(device)
155
+ attention_mask = encoding['attention_mask'].to(device)
156
+
157
+ # Get predictions
158
+ with torch.no_grad():
159
+ medication_logits, polypharmacy_logits, disease_logits = model(input_ids, attention_mask)
160
+ medication_preds = torch.sigmoid(medication_logits) > 0.5
161
+ polypharmacy_pred = torch.argmax(polypharmacy_logits, dim=1)
162
+ disease_pred = torch.argmax(disease_logits, dim=1)
163
+
164
+ # Convert predictions to human-readable format
165
+ predicted_medications = mlb.classes_[medication_preds[0].cpu().numpy()]
166
+ predicted_risk = le_risk.classes_[polypharmacy_pred.item()]
167
+ predicted_disease = le_disease.classes_[disease_pred.item()]
168
+
169
+ # Get medication probabilities for all medications
170
+ medication_probs = torch.sigmoid(medication_logits).cpu().numpy()[0]
171
+ med_prob_dict = {med: prob for med, prob in zip(mlb.classes_, medication_probs)}
172
+
173
+ # Sort medications by probability
174
+ sorted_meds = sorted(med_prob_dict.items(), key=lambda x: x[1], reverse=True)
175
+ top_meds = sorted_meds[:5] # Get top 5 medications
176
+
177
+ # Format medication results
178
+ med_results = []
179
+ for i, med in enumerate(predicted_medications[:3]):
180
+ med_details = {
181
+ 'medication': med,
182
+ 'dosage': 'Consult doctor',
183
+ 'frequency': 'Consult doctor',
184
+ 'instruction': 'Consult doctor',
185
+ 'duration': 'As prescribed',
186
+ 'confidence': float(med_prob_dict[med])
187
+ }
188
+ med_results.append(med_details)
189
+
190
+ # Get disease information
191
+ disease_causes = lookup_data['disease_causes_dict'].get(predicted_disease, "Unknown causes")
192
+ disease_prevention = lookup_data['disease_prevention_dict'].get(predicted_disease, "Consult healthcare provider")
193
+
194
+ # Get polypharmacy recommendation
195
+ polypharmacy_recommendation = lookup_data['polypharmacy_recommendation_dict'].get(
196
+ predicted_risk, "Consult healthcare provider"
197
+ )
198
+
199
+ # Get personalized health tip
200
+ age_decade = (patient_data['age'] // 10) * 10
201
+ health_tip_key = (predicted_disease, age_decade, patient_data['gender'])
202
+ personalized_health_tip = lookup_data['health_tips_dict'].get(
203
+ health_tip_key, "Maintain a balanced diet and regular exercise routine."
204
+ )
205
+
206
+ # Return comprehensive results
207
+ return {
208
+ 'patient_name': patient_data['name'], # Include patient name in results
209
+ 'predicted_disease': predicted_disease,
210
+ 'disease_causes': disease_causes,
211
+ 'disease_prevention': disease_prevention,
212
+ 'medications': med_results,
213
+ 'polypharmacy_risk': predicted_risk,
214
+ 'polypharmacy_recommendation': polypharmacy_recommendation,
215
+ 'personalized_health_tips': personalized_health_tip,
216
+ 'medication_probabilities': {med: float(prob) for med, prob in top_meds}
217
+ }
218
+
219
+ def main():
220
+ # App title and description
221
+ st.title("🏥 Drug Prediction and Polypharmacy System")
222
+ st.markdown("Enter patient information to receive medication recommendations, disease prediction, and polypharmacy risk assessment.")
223
+
224
+ try:
225
+ # Load model and resources
226
+ with st.spinner("Loading medical model and resources..."):
227
+ resources = load_model_and_resources()
228
+
229
+ # Create two columns for input form
230
+ col1, col2 = st.columns(2)
231
+
232
+ # Patient information inputs
233
+ with col1:
234
+ st.subheader("Patient Information")
235
+ # Add patient name input field
236
+ name = st.text_input("Patient Name", value="John Doe")
237
+ age = st.number_input("Age", min_value=1, max_value=120, value=45)
238
+ gender = st.selectbox("Gender", options=["Male", "Female", "Other"])
239
+ blood_group = st.selectbox("Blood Group", options=["A+", "A-", "B+", "B-", "AB+", "AB-", "O+", "O-"])
240
+ weight = st.number_input("Weight (kg)", min_value=1.0, max_value=300.0, value=70.0, step=0.1)
241
+
242
+ with col2:
243
+ st.subheader("Symptoms Information")
244
+
245
+ # Common symptoms options
246
+ common_symptoms = [
247
+ "Headache", "Fever", "Fatigue", "Nausea", "Cough",
248
+ "Sore throat", "Shortness of breath", "Chest pain",
249
+ "Dizziness", "Abdominal pain", "Vomiting", "Diarrhea",
250
+ "Muscle ache", "Joint pain", "Rash", "Loss of appetite"
251
+ ]
252
+
253
+ # Use multiselect for symptoms selection
254
+ selected_symptoms = st.multiselect(
255
+ "Select Symptoms",
256
+ options=common_symptoms,
257
+ default=["Headache", "Fever", "Fatigue"]
258
+ )
259
+
260
+ # Custom symptom input
261
+ custom_symptom = st.text_input("Add other symptom (if not in list)")
262
+ if custom_symptom:
263
+ selected_symptoms.append(custom_symptom)
264
+
265
+ # Convert selected symptoms to string format as expected by the model
266
+ symptoms = "; ".join(selected_symptoms)
267
+
268
+ # More compact severity selection
269
+ st.subheader("Symptom Severity")
270
+
271
+ # Define severity levels
272
+ severity_levels = {
273
+ "Very Mild": 1,
274
+ "Mild": 2,
275
+ "Moderate": 3,
276
+ "Severe": 4,
277
+ "Very Severe": 5
278
+ }
279
+
280
+ severity_dict = {}
281
+
282
+ # Create a more compact layout with 2 columns for severity selection
283
+ if selected_symptoms:
284
+ cols = st.columns(2)
285
+ for i, symptom in enumerate(selected_symptoms):
286
+ # Alternate between columns
287
+ with cols[i % 2]:
288
+ severity_option = st.selectbox(
289
+ f"{symptom}",
290
+ options=list(severity_levels.keys()),
291
+ index=1 # Default to "Mild"
292
+ )
293
+ severity_dict[symptom] = severity_levels[severity_option]
294
+
295
+ # Convert severity dict to string format as expected by the model
296
+ severity = "; ".join([f"{symptom}:{score}" for symptom, score in severity_dict.items()])
297
+
298
+ # Submit button
299
+ if st.button("Generate Health Profile", type="primary"):
300
+ with st.spinner("Analyzing patient data and generating health profile..."):
301
+ # Prepare patient data
302
+ patient_data = {
303
+ 'name': name, # Include name in patient data
304
+ 'age': age,
305
+ 'gender': gender,
306
+ 'blood_group': blood_group,
307
+ 'weight': weight,
308
+ 'symptoms': symptoms,
309
+ 'severity': severity
310
+ }
311
+
312
+ # Get prediction
313
+ prediction = predict_patient_health_profile(patient_data, resources)
314
+
315
+ # Display results in three columns
316
+ st.subheader(f"🔍 Health Profile Analysis Results for {prediction['patient_name']}")
317
+
318
+ col1, col2, col3 = st.columns([1, 1, 1])
319
+
320
+ # Column 1: Disease information
321
+ with col1:
322
+ st.markdown("### 🦠 Disease Prediction")
323
+ st.markdown(f"**Predicted Disease**: {prediction['predicted_disease']}")
324
+
325
+ with st.expander("Disease Causes"):
326
+ st.write(prediction['disease_causes'])
327
+
328
+ with st.expander("Prevention Methods"):
329
+ st.write(prediction['disease_prevention'])
330
+
331
+ # Column 2: Medication recommendations
332
+ with col2:
333
+ st.markdown("### 💊 Medication Recommendations")
334
+ for i, med in enumerate(prediction['medications']):
335
+ st.markdown(f"**{i+1}. {med['medication']}** (Confidence: {med['confidence']:.2f})")
336
+ med_details = f"""
337
+ - **Dosage:** {med['dosage']}
338
+ - **Frequency:** {med['frequency']}
339
+ - **Instructions:** {med['instruction']}
340
+ - **Duration:** {med['duration']}
341
+ """
342
+ st.markdown(med_details)
343
+ st.divider()
344
+
345
+ # Column 3: Risk assessment and health tips
346
+ with col3:
347
+ st.markdown("### ⚠️ Polypharmacy Assessment")
348
+ risk_color = "green" if prediction['polypharmacy_risk'] == "Low" else "orange" if prediction['polypharmacy_risk'] == "Medium" else "red"
349
+ st.markdown(f"**Risk Level**: <span style='color:{risk_color};font-weight:bold;'>{prediction['polypharmacy_risk']}</span>",
350
+ unsafe_allow_html=True)
351
+ st.markdown(f"**Recommendation**: {prediction['polypharmacy_recommendation']}")
352
+
353
+ st.markdown("### 🌿 Personalized Health Tips")
354
+ st.info(prediction['personalized_health_tips'])
355
+
356
+ # Display medication probabilities as text with progress bars
357
+ st.subheader("Medication Confidence Scores")
358
+ med_names = list(prediction['medication_probabilities'].keys())
359
+ med_probs = list(prediction['medication_probabilities'].values())
360
+
361
+ # Display each medication with its confidence score as text and progress bar
362
+ for med_name, med_prob in zip(med_names, med_probs):
363
+ st.text(f"{med_name}: {med_prob:.2f}")
364
+ st.progress(med_prob)
365
+
366
+ except Exception as e:
367
+ st.error(f"An error occurred: {str(e)}")
368
+ st.error("Please make sure all model files are correctly placed in the 'streamlit_model' directory")
369
+
370
+ if __name__ == "__main__":
371
+ main()