import html
import json
import os
import pickle

import numpy as np
import streamlit as st
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

# Set page config (must be the first Streamlit command the script executes)
st.set_page_config(
    page_title="Drug Prediction and Polypharmacy System",
    page_icon="💊",
    layout="wide",
)

# Directory holding the exported model artifacts (config, weights, encoders, lookups).
MODEL_DIR = "streamlit_model"

# Symptoms offered in the multiselect; anything else can be typed in manually.
COMMON_SYMPTOMS = [
    "Headache", "Fever", "Fatigue", "Nausea", "Cough", "Sore throat",
    "Shortness of breath", "Chest pain", "Dizziness", "Abdominal pain",
    "Vomiting", "Diarrhea", "Muscle ache", "Joint pain", "Rash", "Loss of appetite",
]

# Severity label shown in the UI -> integer score written into the model's input text.
SEVERITY_LEVELS = {
    "Very Mild": 1,
    "Mild": 2,
    "Moderate": 3,
    "Severe": 4,
    "Very Severe": 5,
}


# Model class definition - must match the training model architecture
class EnhancedMedicationModel(nn.Module):
    """Multi-task classifier on top of a pretrained transformer encoder.

    The first-token ([CLS]) embedding passes through dropout and a shared
    dense+ReLU layer, then feeds three heads: medication (multi-label logits),
    polypharmacy risk (multi-class logits) and disease (multi-class logits).

    The module structure (attribute names and Sequential layer order) must stay
    identical to the training code, otherwise ``load_state_dict`` fails on
    mismatched keys.
    """

    def __init__(self, model_name, num_medications, num_polypharmacy_classes,
                 num_disease_classes, dropout_rate=0.3):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout_rate)
        hidden_size = self.bert.config.hidden_size

        # Common representation layer shared by all task heads
        self.common_dense = nn.Linear(hidden_size, hidden_size)

        # Medication prediction head (multi-label)
        self.medication_classifier = self._make_head(hidden_size, num_medications, dropout_rate)
        # Polypharmacy risk head (multi-class)
        self.polypharmacy_classifier = self._make_head(
            hidden_size, num_polypharmacy_classes, dropout_rate
        )
        # Disease prediction head (multi-class)
        self.disease_classifier = self._make_head(hidden_size, num_disease_classes, dropout_rate)

        self._init_weights()

    @staticmethod
    def _make_head(hidden_size, num_outputs, dropout_rate):
        """Build one Linear-ReLU-Dropout-Linear task head (layer order fixes state_dict keys)."""
        return nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_size // 2, num_outputs),
        )

    def _init_weights(self):
        """Xavier-normal weights and zero biases for every Linear in the heads and shared layer.

        ``Module.modules()`` yields the module itself plus all submodules, so this
        covers both the Sequential heads and the bare ``common_dense`` Linear, and
        each Linear's own bias is the one zeroed.
        """
        for module in (self.medication_classifier, self.polypharmacy_classifier,
                       self.disease_classifier, self.common_dense):
            for layer in module.modules():
                if isinstance(layer, nn.Linear):
                    nn.init.xavier_normal_(layer.weight)
                    nn.init.zeros_(layer.bias)

    def forward(self, input_ids, attention_mask):
        """Return ``(medication_logits, polypharmacy_logits, disease_logits)``."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.last_hidden_state[:, 0, :]  # CLS token
        pooled_output = self.dropout(pooled_output)

        common_features = torch.relu(self.common_dense(pooled_output))

        return (
            self.medication_classifier(common_features),
            self.polypharmacy_classifier(common_features),
            self.disease_classifier(common_features),
        )


@st.cache_resource
def load_model_and_resources(model_dir=MODEL_DIR):
    """Load model and necessary resources (cached for performance).

    Args:
        model_dir: directory containing ``model_config.json``,
            ``model_state_dict.pt``, ``label_encoders.pkl`` and ``lookup_data.pkl``.

    Returns:
        dict with keys 'model', 'tokenizer', 'mlb', 'le_risk', 'le_disease',
        'lookup_data' and 'device'.

    Raises:
        OSError / KeyError / pickle or torch load errors if an artifact is
        missing or malformed.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    with open(os.path.join(model_dir, "model_config.json"), "r", encoding="utf-8") as f:
        model_config = json.load(f)

    model_name = model_config["model_name"]
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    model = EnhancedMedicationModel(
        model_name=model_name,
        num_medications=model_config["num_medications"],
        num_polypharmacy_classes=model_config["num_polypharmacy_classes"],
        num_disease_classes=model_config["num_disease_classes"],
        dropout_rate=0.3,
    )

    # The trained weights replace the freshly initialised/pretrained ones.
    state_dict = torch.load(os.path.join(model_dir, "model_state_dict.pt"), map_location=device)
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()

    # SECURITY NOTE: pickle.load can execute arbitrary code embedded in the file.
    # Only ship/load artifacts produced by your own training pipeline.
    with open(os.path.join(model_dir, "label_encoders.pkl"), "rb") as f:
        encoders = pickle.load(f)
    with open(os.path.join(model_dir, "lookup_data.pkl"), "rb") as f:
        lookup_data = pickle.load(f)

    return {
        "model": model,
        "tokenizer": tokenizer,
        "mlb": encoders["mlb"],
        "le_risk": encoders["le_risk"],
        "le_disease": encoders["le_disease"],
        "lookup_data": lookup_data,
        "device": device,
    }


def predict_patient_health_profile(patient_data, resources, threshold=0.5,
                                   max_medications=3, top_k=5):
    """Predict a health profile for a single patient.

    Args:
        patient_data: dict with keys 'name', 'age', 'gender', 'blood_group',
            'weight', 'symptoms' and 'severity'.
        resources: dict returned by ``load_model_and_resources``.
        threshold: sigmoid probability a medication must exceed to be recommended.
        max_medications: maximum number of medication recommendations returned.
        top_k: number of highest-probability medications reported in
            'medication_probabilities' (regardless of ``threshold``).

    Returns:
        dict with the predicted disease, its causes/prevention, medication
        recommendations (most confident first), polypharmacy risk and
        recommendation, a personalized tip and the top-k medication probabilities.
    """
    model = resources["model"]
    tokenizer = resources["tokenizer"]
    mlb = resources["mlb"]
    le_risk = resources["le_risk"]
    le_disease = resources["le_disease"]
    lookup_data = resources["lookup_data"]
    device = resources["device"]

    # The wording below must match the text format the model was trained on.
    text_input = (
        f"Patient age {patient_data['age']}, gender {patient_data['gender']}, "
        f"blood group {patient_data['blood_group']}, weight {patient_data['weight']}kg. "
        f"SYMPTOMS: {patient_data['symptoms']}. "
        f"SEVERITY: {patient_data['severity']}."
    )

    encoding = tokenizer(
        text_input,
        add_special_tokens=True,
        max_length=256,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)

    with torch.no_grad():
        medication_logits, polypharmacy_logits, disease_logits = model(input_ids, attention_mask)

    # Single input -> take row 0 of the batch.
    medication_probs = torch.sigmoid(medication_logits)[0].cpu().numpy()
    predicted_risk = le_risk.classes_[torch.argmax(polypharmacy_logits, dim=1).item()]
    predicted_disease = le_disease.classes_[torch.argmax(disease_logits, dim=1).item()]

    med_prob_dict = {med: float(prob) for med, prob in zip(mlb.classes_, medication_probs)}
    sorted_meds = sorted(med_prob_dict.items(), key=lambda item: item[1], reverse=True)

    # Recommendations are taken from the probability-sorted list so the most
    # confident medications above the threshold come first.
    med_results = [
        {
            "medication": med,
            "dosage": "Consult doctor",
            "frequency": "Consult doctor",
            "instruction": "Consult doctor",
            "duration": "As prescribed",
            "confidence": prob,
        }
        for med, prob in sorted_meds
        if prob > threshold
    ][:max_medications]

    disease_causes = lookup_data["disease_causes_dict"].get(predicted_disease, "Unknown causes")
    disease_prevention = lookup_data["disease_prevention_dict"].get(
        predicted_disease, "Consult healthcare provider"
    )
    polypharmacy_recommendation = lookup_data["polypharmacy_recommendation_dict"].get(
        predicted_risk, "Consult healthcare provider"
    )

    # Health tips are keyed by (disease, age rounded down to the decade, gender).
    age_decade = (patient_data["age"] // 10) * 10
    health_tip_key = (predicted_disease, age_decade, patient_data["gender"])
    personalized_health_tip = lookup_data["health_tips_dict"].get(
        health_tip_key, "Maintain a balanced diet and regular exercise routine."
    )

    return {
        "patient_name": patient_data["name"],
        "predicted_disease": predicted_disease,
        "disease_causes": disease_causes,
        "disease_prevention": disease_prevention,
        "medications": med_results,
        "polypharmacy_risk": predicted_risk,
        "polypharmacy_recommendation": polypharmacy_recommendation,
        "personalized_health_tips": personalized_health_tip,
        "medication_probabilities": dict(sorted_meds[:top_k]),
    }


def _patient_info_inputs():
    """Render the patient-information widgets and return their values as a dict."""
    st.subheader("Patient Information")
    name = st.text_input("Patient Name", value="John Doe")
    age = st.number_input("Age", min_value=1, max_value=120, value=45)
    gender = st.selectbox("Gender", options=["Male", "Female", "Other"])
    blood_group = st.selectbox(
        "Blood Group", options=["A+", "A-", "B+", "B-", "AB+", "AB-", "O+", "O-"]
    )
    weight = st.number_input(
        "Weight (kg)", min_value=1.0, max_value=300.0, value=70.0, step=0.1
    )
    return {
        "name": name,
        "age": age,
        "gender": gender,
        "blood_group": blood_group,
        "weight": weight,
    }


def _symptom_inputs():
    """Render the symptom widgets and return the de-duplicated list of symptoms."""
    st.subheader("Symptoms Information")
    selected = st.multiselect(
        "Select Symptoms",
        options=COMMON_SYMPTOMS,
        default=["Headache", "Fever", "Fatigue"],
    )
    custom_symptom = st.text_input("Add other symptom (if not in list)").strip()

    # Copy rather than mutate the widget's return value, and skip duplicates:
    # a repeated symptom would produce two identical severity selectboxes,
    # which Streamlit rejects as a duplicate widget.
    symptoms = list(selected)
    if custom_symptom and custom_symptom not in symptoms:
        symptoms.append(custom_symptom)
    return symptoms


def _severity_inputs(symptoms):
    """Render one severity selectbox per symptom; return the 'symptom:score; ...' string."""
    st.subheader("Symptom Severity")
    severity_dict = {}
    if symptoms:
        cols = st.columns(2)
        for i, symptom in enumerate(symptoms):
            # Alternate between the two columns for a compact layout
            with cols[i % 2]:
                severity_option = st.selectbox(
                    f"{symptom}",
                    options=list(SEVERITY_LEVELS),
                    index=1,  # Default to "Mild"
                )
            severity_dict[symptom] = SEVERITY_LEVELS[severity_option]
    return "; ".join(f"{symptom}:{score}" for symptom, score in severity_dict.items())


def _render_results(prediction):
    """Display the dict produced by ``predict_patient_health_profile``."""
    st.subheader(f"🔍 Health Profile Analysis Results for {prediction['patient_name']}")
    col1, col2, col3 = st.columns([1, 1, 1])

    # Column 1: Disease information
    with col1:
        st.markdown("### 🦠 Disease Prediction")
        st.markdown(f"**Predicted Disease**: {prediction['predicted_disease']}")
        with st.expander("Disease Causes"):
            st.write(prediction["disease_causes"])
        with st.expander("Prevention Methods"):
            st.write(prediction["disease_prevention"])

    # Column 2: Medication recommendations
    with col2:
        st.markdown("### 💊 Medication Recommendations")
        if not prediction["medications"]:
            st.write("No medication exceeded the confidence threshold.")
        for i, med in enumerate(prediction["medications"], start=1):
            st.markdown(f"**{i}. {med['medication']}** (Confidence: {med['confidence']:.2f})")
            st.markdown(
                "\n".join([
                    f"- **Dosage:** {med['dosage']}",
                    f"- **Frequency:** {med['frequency']}",
                    f"- **Instructions:** {med['instruction']}",
                    f"- **Duration:** {med['duration']}",
                ])
            )
            st.divider()

    # Column 3: Risk assessment and health tips
    with col3:
        st.markdown("### ⚠️ Polypharmacy Assessment")
        risk = prediction["polypharmacy_risk"]
        risk_color = {"Low": "green", "Medium": "orange"}.get(risk, "red")
        # Label is escaped because it is injected into raw HTML.
        st.markdown(
            f"**Risk Level**: <span style='color:{risk_color}'>{html.escape(str(risk))}</span>",
            unsafe_allow_html=True,
        )
        st.markdown(f"**Recommendation**: {prediction['polypharmacy_recommendation']}")
        st.markdown("### 🌿 Personalized Health Tips")
        st.info(prediction["personalized_health_tips"])

    # Medication probabilities as text plus a progress bar (probabilities are in [0, 1])
    st.subheader("Medication Confidence Scores")
    for med_name, med_prob in prediction["medication_probabilities"].items():
        st.text(f"{med_name}: {med_prob:.2f}")
        st.progress(med_prob)


def main():
    """Streamlit entry point: collect patient input, run the model, show the results."""
    st.title("🏥 Drug Prediction and Polypharmacy System")
    st.markdown(
        "Enter patient information to receive medication recommendations, "
        "disease prediction, and polypharmacy risk assessment."
    )

    # UI boundary: any loading failure is reported to the user instead of a traceback.
    try:
        with st.spinner("Loading medical model and resources..."):
            resources = load_model_and_resources()
    except Exception as e:
        st.error(f"An error occurred: {str(e)}")
        st.error(
            "Please make sure all model files are correctly placed in the "
            "'streamlit_model' directory"
        )
        return

    col1, col2 = st.columns(2)
    with col1:
        patient_info = _patient_info_inputs()
    with col2:
        selected_symptoms = _symptom_inputs()

    severity = _severity_inputs(selected_symptoms)

    if st.button("Generate Health Profile", type="primary"):
        if not selected_symptoms:
            st.warning("Please select or enter at least one symptom.")
            return

        patient_data = {
            **patient_info,
            # Symptoms joined with "; " as expected by the model's text format
            "symptoms": "; ".join(selected_symptoms),
            "severity": severity,
        }
        try:
            with st.spinner("Analyzing patient data and generating health profile..."):
                prediction = predict_patient_health_profile(patient_data, resources)
            _render_results(prediction)
        except Exception as e:
            st.error(f"An error occurred: {str(e)}")


if __name__ == "__main__":
    main()