Upload 2 files
Browse files- Drug_Prediction_and_Polypharmacy_System.ipynb +0 -0
- app_test.py +371 -0
Drug_Prediction_and_Polypharmacy_System.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
app_test.py
ADDED
|
@@ -0,0 +1,371 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pickle
|
| 5 |
+
import json
|
| 6 |
+
from transformers import AutoTokenizer, AutoModel
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
# Set page config
|
| 11 |
+
st.set_page_config(
|
| 12 |
+
page_title="Drug Prediction and Polypharmacy System",
|
| 13 |
+
page_icon="💊",
|
| 14 |
+
layout="wide"
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
# Model class definition - must match the training model architecture
|
| 18 |
+
class EnhancedMedicationModel(nn.Module):
|
| 19 |
+
def __init__(self, model_name, num_medications, num_polypharmacy_classes, num_disease_classes, dropout_rate=0.3):
|
| 20 |
+
super().__init__()
|
| 21 |
+
self.bert = AutoModel.from_pretrained(model_name)
|
| 22 |
+
self.dropout = nn.Dropout(dropout_rate)
|
| 23 |
+
hidden_size = self.bert.config.hidden_size
|
| 24 |
+
|
| 25 |
+
# Common representation layer
|
| 26 |
+
self.common_dense = nn.Linear(hidden_size, hidden_size)
|
| 27 |
+
|
| 28 |
+
# Task-specific layers with increased complexity
|
| 29 |
+
# Medication prediction head (multi-label)
|
| 30 |
+
self.medication_classifier = nn.Sequential(
|
| 31 |
+
nn.Linear(hidden_size, hidden_size//2),
|
| 32 |
+
nn.ReLU(),
|
| 33 |
+
nn.Dropout(dropout_rate),
|
| 34 |
+
nn.Linear(hidden_size//2, num_medications)
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
# Polypharmacy risk head (multi-class)
|
| 38 |
+
self.polypharmacy_classifier = nn.Sequential(
|
| 39 |
+
nn.Linear(hidden_size, hidden_size//2),
|
| 40 |
+
nn.ReLU(),
|
| 41 |
+
nn.Dropout(dropout_rate),
|
| 42 |
+
nn.Linear(hidden_size//2, num_polypharmacy_classes)
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
# Disease prediction head (multi-class)
|
| 46 |
+
self.disease_classifier = nn.Sequential(
|
| 47 |
+
nn.Linear(hidden_size, hidden_size//2),
|
| 48 |
+
nn.ReLU(),
|
| 49 |
+
nn.Dropout(dropout_rate),
|
| 50 |
+
nn.Linear(hidden_size//2, num_disease_classes)
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
# Apply weight initialization
|
| 54 |
+
self._init_weights()
|
| 55 |
+
|
| 56 |
+
def _init_weights(self):
|
| 57 |
+
# Initialize weights for better convergence
|
| 58 |
+
for module in [self.medication_classifier, self.polypharmacy_classifier,
|
| 59 |
+
self.disease_classifier, self.common_dense]:
|
| 60 |
+
if isinstance(module, nn.Sequential):
|
| 61 |
+
for layer in module:
|
| 62 |
+
if isinstance(layer, nn.Linear):
|
| 63 |
+
nn.init.xavier_normal_(layer.weight)
|
| 64 |
+
nn.init.zeros_(layer.bias)
|
| 65 |
+
elif isinstance(module, nn.Linear):
|
| 66 |
+
nn.init.xavier_normal_(module.weight)
|
| 67 |
+
nn.init.zeros_(layer.bias)
|
| 68 |
+
|
| 69 |
+
def forward(self, input_ids, attention_mask):
|
| 70 |
+
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
|
| 71 |
+
pooled_output = outputs.last_hidden_state[:, 0, :] # CLS token
|
| 72 |
+
pooled_output = self.dropout(pooled_output)
|
| 73 |
+
|
| 74 |
+
# Common representation
|
| 75 |
+
common_features = torch.relu(self.common_dense(pooled_output))
|
| 76 |
+
|
| 77 |
+
medication_logits = self.medication_classifier(common_features)
|
| 78 |
+
polypharmacy_logits = self.polypharmacy_classifier(common_features)
|
| 79 |
+
disease_logits = self.disease_classifier(common_features)
|
| 80 |
+
|
| 81 |
+
return medication_logits, polypharmacy_logits, disease_logits
|
| 82 |
+
|
| 83 |
+
@st.cache_resource
|
| 84 |
+
def load_model_and_resources():
|
| 85 |
+
"""Load model and necessary resources (cached for performance)"""
|
| 86 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 87 |
+
|
| 88 |
+
# Load model configuration - fixed file paths
|
| 89 |
+
with open('streamlit_model/model_config.json', 'r') as f:
|
| 90 |
+
model_config = json.load(f)
|
| 91 |
+
|
| 92 |
+
# Initialize model
|
| 93 |
+
model_name = model_config['model_name']
|
| 94 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 95 |
+
|
| 96 |
+
# Create model architecture
|
| 97 |
+
model = EnhancedMedicationModel(
|
| 98 |
+
model_name=model_name,
|
| 99 |
+
num_medications=model_config['num_medications'],
|
| 100 |
+
num_polypharmacy_classes=model_config['num_polypharmacy_classes'],
|
| 101 |
+
num_disease_classes=model_config['num_disease_classes'],
|
| 102 |
+
dropout_rate=0.3
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
# Load trained weights - fixed file path
|
| 106 |
+
model.load_state_dict(torch.load('streamlit_model/model_state_dict.pt', map_location=device))
|
| 107 |
+
model = model.to(device)
|
| 108 |
+
model.eval()
|
| 109 |
+
|
| 110 |
+
# Load encoders - fixed file path
|
| 111 |
+
with open('streamlit_model/label_encoders.pkl', 'rb') as f:
|
| 112 |
+
encoders = pickle.load(f)
|
| 113 |
+
|
| 114 |
+
# Load lookup data - fixed file path
|
| 115 |
+
with open('streamlit_model/lookup_data.pkl', 'rb') as f:
|
| 116 |
+
lookup_data = pickle.load(f)
|
| 117 |
+
|
| 118 |
+
return {
|
| 119 |
+
'model': model,
|
| 120 |
+
'tokenizer': tokenizer,
|
| 121 |
+
'mlb': encoders['mlb'],
|
| 122 |
+
'le_risk': encoders['le_risk'],
|
| 123 |
+
'le_disease': encoders['le_disease'],
|
| 124 |
+
'lookup_data': lookup_data,
|
| 125 |
+
'device': device
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
def predict_patient_health_profile(patient_data, resources):
|
| 129 |
+
"""
|
| 130 |
+
Predict health profile for a patient based on input data
|
| 131 |
+
"""
|
| 132 |
+
model = resources['model']
|
| 133 |
+
tokenizer = resources['tokenizer']
|
| 134 |
+
mlb = resources['mlb']
|
| 135 |
+
le_risk = resources['le_risk']
|
| 136 |
+
le_disease = resources['le_disease']
|
| 137 |
+
lookup_data = resources['lookup_data']
|
| 138 |
+
device = resources['device']
|
| 139 |
+
|
| 140 |
+
# Create text input
|
| 141 |
+
text_input = f"Patient age {patient_data['age']}, gender {patient_data['gender']}, blood group {patient_data['blood_group']}, weight {patient_data['weight']}kg. " + f"SYMPTOMS: {patient_data['symptoms']}. " + f"SEVERITY: {patient_data['severity']}."
|
| 142 |
+
|
| 143 |
+
# Tokenize
|
| 144 |
+
encoding = tokenizer(
|
| 145 |
+
text_input,
|
| 146 |
+
add_special_tokens=True,
|
| 147 |
+
max_length=256,
|
| 148 |
+
padding='max_length',
|
| 149 |
+
truncation=True,
|
| 150 |
+
return_tensors='pt'
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
# Move to device
|
| 154 |
+
input_ids = encoding['input_ids'].to(device)
|
| 155 |
+
attention_mask = encoding['attention_mask'].to(device)
|
| 156 |
+
|
| 157 |
+
# Get predictions
|
| 158 |
+
with torch.no_grad():
|
| 159 |
+
medication_logits, polypharmacy_logits, disease_logits = model(input_ids, attention_mask)
|
| 160 |
+
medication_preds = torch.sigmoid(medication_logits) > 0.5
|
| 161 |
+
polypharmacy_pred = torch.argmax(polypharmacy_logits, dim=1)
|
| 162 |
+
disease_pred = torch.argmax(disease_logits, dim=1)
|
| 163 |
+
|
| 164 |
+
# Convert predictions to human-readable format
|
| 165 |
+
predicted_medications = mlb.classes_[medication_preds[0].cpu().numpy()]
|
| 166 |
+
predicted_risk = le_risk.classes_[polypharmacy_pred.item()]
|
| 167 |
+
predicted_disease = le_disease.classes_[disease_pred.item()]
|
| 168 |
+
|
| 169 |
+
# Get medication probabilities for all medications
|
| 170 |
+
medication_probs = torch.sigmoid(medication_logits).cpu().numpy()[0]
|
| 171 |
+
med_prob_dict = {med: prob for med, prob in zip(mlb.classes_, medication_probs)}
|
| 172 |
+
|
| 173 |
+
# Sort medications by probability
|
| 174 |
+
sorted_meds = sorted(med_prob_dict.items(), key=lambda x: x[1], reverse=True)
|
| 175 |
+
top_meds = sorted_meds[:5] # Get top 5 medications
|
| 176 |
+
|
| 177 |
+
# Format medication results
|
| 178 |
+
med_results = []
|
| 179 |
+
for i, med in enumerate(predicted_medications[:3]):
|
| 180 |
+
med_details = {
|
| 181 |
+
'medication': med,
|
| 182 |
+
'dosage': 'Consult doctor',
|
| 183 |
+
'frequency': 'Consult doctor',
|
| 184 |
+
'instruction': 'Consult doctor',
|
| 185 |
+
'duration': 'As prescribed',
|
| 186 |
+
'confidence': float(med_prob_dict[med])
|
| 187 |
+
}
|
| 188 |
+
med_results.append(med_details)
|
| 189 |
+
|
| 190 |
+
# Get disease information
|
| 191 |
+
disease_causes = lookup_data['disease_causes_dict'].get(predicted_disease, "Unknown causes")
|
| 192 |
+
disease_prevention = lookup_data['disease_prevention_dict'].get(predicted_disease, "Consult healthcare provider")
|
| 193 |
+
|
| 194 |
+
# Get polypharmacy recommendation
|
| 195 |
+
polypharmacy_recommendation = lookup_data['polypharmacy_recommendation_dict'].get(
|
| 196 |
+
predicted_risk, "Consult healthcare provider"
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
# Get personalized health tip
|
| 200 |
+
age_decade = (patient_data['age'] // 10) * 10
|
| 201 |
+
health_tip_key = (predicted_disease, age_decade, patient_data['gender'])
|
| 202 |
+
personalized_health_tip = lookup_data['health_tips_dict'].get(
|
| 203 |
+
health_tip_key, "Maintain a balanced diet and regular exercise routine."
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
# Return comprehensive results
|
| 207 |
+
return {
|
| 208 |
+
'patient_name': patient_data['name'], # Include patient name in results
|
| 209 |
+
'predicted_disease': predicted_disease,
|
| 210 |
+
'disease_causes': disease_causes,
|
| 211 |
+
'disease_prevention': disease_prevention,
|
| 212 |
+
'medications': med_results,
|
| 213 |
+
'polypharmacy_risk': predicted_risk,
|
| 214 |
+
'polypharmacy_recommendation': polypharmacy_recommendation,
|
| 215 |
+
'personalized_health_tips': personalized_health_tip,
|
| 216 |
+
'medication_probabilities': {med: float(prob) for med, prob in top_meds}
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
def main():
|
| 220 |
+
# App title and description
|
| 221 |
+
st.title("🏥 Drug Prediction and Polypharmacy System")
|
| 222 |
+
st.markdown("Enter patient information to receive medication recommendations, disease prediction, and polypharmacy risk assessment.")
|
| 223 |
+
|
| 224 |
+
try:
|
| 225 |
+
# Load model and resources
|
| 226 |
+
with st.spinner("Loading medical model and resources..."):
|
| 227 |
+
resources = load_model_and_resources()
|
| 228 |
+
|
| 229 |
+
# Create two columns for input form
|
| 230 |
+
col1, col2 = st.columns(2)
|
| 231 |
+
|
| 232 |
+
# Patient information inputs
|
| 233 |
+
with col1:
|
| 234 |
+
st.subheader("Patient Information")
|
| 235 |
+
# Add patient name input field
|
| 236 |
+
name = st.text_input("Patient Name", value="John Doe")
|
| 237 |
+
age = st.number_input("Age", min_value=1, max_value=120, value=45)
|
| 238 |
+
gender = st.selectbox("Gender", options=["Male", "Female", "Other"])
|
| 239 |
+
blood_group = st.selectbox("Blood Group", options=["A+", "A-", "B+", "B-", "AB+", "AB-", "O+", "O-"])
|
| 240 |
+
weight = st.number_input("Weight (kg)", min_value=1.0, max_value=300.0, value=70.0, step=0.1)
|
| 241 |
+
|
| 242 |
+
with col2:
|
| 243 |
+
st.subheader("Symptoms Information")
|
| 244 |
+
|
| 245 |
+
# Common symptoms options
|
| 246 |
+
common_symptoms = [
|
| 247 |
+
"Headache", "Fever", "Fatigue", "Nausea", "Cough",
|
| 248 |
+
"Sore throat", "Shortness of breath", "Chest pain",
|
| 249 |
+
"Dizziness", "Abdominal pain", "Vomiting", "Diarrhea",
|
| 250 |
+
"Muscle ache", "Joint pain", "Rash", "Loss of appetite"
|
| 251 |
+
]
|
| 252 |
+
|
| 253 |
+
# Use multiselect for symptoms selection
|
| 254 |
+
selected_symptoms = st.multiselect(
|
| 255 |
+
"Select Symptoms",
|
| 256 |
+
options=common_symptoms,
|
| 257 |
+
default=["Headache", "Fever", "Fatigue"]
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
# Custom symptom input
|
| 261 |
+
custom_symptom = st.text_input("Add other symptom (if not in list)")
|
| 262 |
+
if custom_symptom:
|
| 263 |
+
selected_symptoms.append(custom_symptom)
|
| 264 |
+
|
| 265 |
+
# Convert selected symptoms to string format as expected by the model
|
| 266 |
+
symptoms = "; ".join(selected_symptoms)
|
| 267 |
+
|
| 268 |
+
# More compact severity selection
|
| 269 |
+
st.subheader("Symptom Severity")
|
| 270 |
+
|
| 271 |
+
# Define severity levels
|
| 272 |
+
severity_levels = {
|
| 273 |
+
"Very Mild": 1,
|
| 274 |
+
"Mild": 2,
|
| 275 |
+
"Moderate": 3,
|
| 276 |
+
"Severe": 4,
|
| 277 |
+
"Very Severe": 5
|
| 278 |
+
}
|
| 279 |
+
|
| 280 |
+
severity_dict = {}
|
| 281 |
+
|
| 282 |
+
# Create a more compact layout with 2 columns for severity selection
|
| 283 |
+
if selected_symptoms:
|
| 284 |
+
cols = st.columns(2)
|
| 285 |
+
for i, symptom in enumerate(selected_symptoms):
|
| 286 |
+
# Alternate between columns
|
| 287 |
+
with cols[i % 2]:
|
| 288 |
+
severity_option = st.selectbox(
|
| 289 |
+
f"{symptom}",
|
| 290 |
+
options=list(severity_levels.keys()),
|
| 291 |
+
index=1 # Default to "Mild"
|
| 292 |
+
)
|
| 293 |
+
severity_dict[symptom] = severity_levels[severity_option]
|
| 294 |
+
|
| 295 |
+
# Convert severity dict to string format as expected by the model
|
| 296 |
+
severity = "; ".join([f"{symptom}:{score}" for symptom, score in severity_dict.items()])
|
| 297 |
+
|
| 298 |
+
# Submit button
|
| 299 |
+
if st.button("Generate Health Profile", type="primary"):
|
| 300 |
+
with st.spinner("Analyzing patient data and generating health profile..."):
|
| 301 |
+
# Prepare patient data
|
| 302 |
+
patient_data = {
|
| 303 |
+
'name': name, # Include name in patient data
|
| 304 |
+
'age': age,
|
| 305 |
+
'gender': gender,
|
| 306 |
+
'blood_group': blood_group,
|
| 307 |
+
'weight': weight,
|
| 308 |
+
'symptoms': symptoms,
|
| 309 |
+
'severity': severity
|
| 310 |
+
}
|
| 311 |
+
|
| 312 |
+
# Get prediction
|
| 313 |
+
prediction = predict_patient_health_profile(patient_data, resources)
|
| 314 |
+
|
| 315 |
+
# Display results in three columns
|
| 316 |
+
st.subheader(f"🔍 Health Profile Analysis Results for {prediction['patient_name']}")
|
| 317 |
+
|
| 318 |
+
col1, col2, col3 = st.columns([1, 1, 1])
|
| 319 |
+
|
| 320 |
+
# Column 1: Disease information
|
| 321 |
+
with col1:
|
| 322 |
+
st.markdown("### 🦠 Disease Prediction")
|
| 323 |
+
st.markdown(f"**Predicted Disease**: {prediction['predicted_disease']}")
|
| 324 |
+
|
| 325 |
+
with st.expander("Disease Causes"):
|
| 326 |
+
st.write(prediction['disease_causes'])
|
| 327 |
+
|
| 328 |
+
with st.expander("Prevention Methods"):
|
| 329 |
+
st.write(prediction['disease_prevention'])
|
| 330 |
+
|
| 331 |
+
# Column 2: Medication recommendations
|
| 332 |
+
with col2:
|
| 333 |
+
st.markdown("### 💊 Medication Recommendations")
|
| 334 |
+
for i, med in enumerate(prediction['medications']):
|
| 335 |
+
st.markdown(f"**{i+1}. {med['medication']}** (Confidence: {med['confidence']:.2f})")
|
| 336 |
+
med_details = f"""
|
| 337 |
+
- **Dosage:** {med['dosage']}
|
| 338 |
+
- **Frequency:** {med['frequency']}
|
| 339 |
+
- **Instructions:** {med['instruction']}
|
| 340 |
+
- **Duration:** {med['duration']}
|
| 341 |
+
"""
|
| 342 |
+
st.markdown(med_details)
|
| 343 |
+
st.divider()
|
| 344 |
+
|
| 345 |
+
# Column 3: Risk assessment and health tips
|
| 346 |
+
with col3:
|
| 347 |
+
st.markdown("### ⚠️ Polypharmacy Assessment")
|
| 348 |
+
risk_color = "green" if prediction['polypharmacy_risk'] == "Low" else "orange" if prediction['polypharmacy_risk'] == "Medium" else "red"
|
| 349 |
+
st.markdown(f"**Risk Level**: <span style='color:{risk_color};font-weight:bold;'>{prediction['polypharmacy_risk']}</span>",
|
| 350 |
+
unsafe_allow_html=True)
|
| 351 |
+
st.markdown(f"**Recommendation**: {prediction['polypharmacy_recommendation']}")
|
| 352 |
+
|
| 353 |
+
st.markdown("### 🌿 Personalized Health Tips")
|
| 354 |
+
st.info(prediction['personalized_health_tips'])
|
| 355 |
+
|
| 356 |
+
# Display medication probabilities as text with progress bars
|
| 357 |
+
st.subheader("Medication Confidence Scores")
|
| 358 |
+
med_names = list(prediction['medication_probabilities'].keys())
|
| 359 |
+
med_probs = list(prediction['medication_probabilities'].values())
|
| 360 |
+
|
| 361 |
+
# Display each medication with its confidence score as text and progress bar
|
| 362 |
+
for med_name, med_prob in zip(med_names, med_probs):
|
| 363 |
+
st.text(f"{med_name}: {med_prob:.2f}")
|
| 364 |
+
st.progress(med_prob)
|
| 365 |
+
|
| 366 |
+
except Exception as e:
|
| 367 |
+
st.error(f"An error occurred: {str(e)}")
|
| 368 |
+
st.error("Please make sure all model files are correctly placed in the 'streamlit_model' directory")
|
| 369 |
+
|
| 370 |
+
if __name__ == "__main__":
|
| 371 |
+
main()
|