|
|
import streamlit as st
|
|
|
import torch
|
|
|
import numpy as np
|
|
|
import pickle
|
|
|
import json
|
|
|
from transformers import AutoTokenizer, AutoModel
|
|
|
import torch.nn as nn
|
|
|
import os
|
|
|
|
|
|
|
|
|
# Browser-tab title and icon plus wide layout for the whole app.
st.set_page_config(
    page_title="Drug Prediction and Polypharmacy System",
    page_icon="💊",
    layout="wide",
)
|
|
|
|
|
|
|
|
|
class EnhancedMedicationModel(nn.Module):
    """Multi-task classifier on top of a pretrained transformer encoder.

    The encoder's hidden state at sequence position 0 is passed through
    dropout and a shared dense layer. It then feeds three independent heads:
    medication logits (one per medication label), polypharmacy-risk class
    logits and disease class logits.

    Args:
        model_name: Identifier or path passed to ``AutoModel.from_pretrained``.
        num_medications: Output width of the medication head.
        num_polypharmacy_classes: Output width of the polypharmacy-risk head.
        num_disease_classes: Output width of the disease head.
        dropout_rate: Dropout probability after the encoder and inside each head.
    """

    def __init__(self, model_name, num_medications, num_polypharmacy_classes, num_disease_classes, dropout_rate=0.3):
        super().__init__()
        # Pretrained encoder; _init_weights deliberately leaves it untouched.
        self.bert = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout_rate)
        hidden_size = self.bert.config.hidden_size

        # Shared projection feeding all three task heads.
        self.common_dense = nn.Linear(hidden_size, hidden_size)

        # Attribute names (and the head layer indices) form the state_dict keys
        # that saved checkpoints are loaded against; keep them unchanged.
        self.medication_classifier = self._make_head(hidden_size, num_medications, dropout_rate)
        self.polypharmacy_classifier = self._make_head(hidden_size, num_polypharmacy_classes, dropout_rate)
        self.disease_classifier = self._make_head(hidden_size, num_disease_classes, dropout_rate)

        self._init_weights()

    @staticmethod
    def _make_head(hidden_size, num_outputs, dropout_rate):
        """Build one Linear -> ReLU -> Dropout -> Linear classification head.

        The layer order (indices 0-3) determines parameter names such as
        ``0.weight`` / ``3.bias`` in saved state dicts, so it must not change.
        """
        return nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_size // 2, num_outputs),
        )

    def _init_weights(self):
        """Xavier-normal weights and zero biases for every Linear outside the encoder."""
        for block in (self.medication_classifier, self.polypharmacy_classifier,
                      self.disease_classifier, self.common_dense):
            # .modules() yields the block itself and all nested submodules, so a
            # bare Linear (common_dense) and the Sequential heads are handled
            # uniformly and each layer's *own* bias is the one zeroed.
            for layer in block.modules():
                if isinstance(layer, nn.Linear):
                    nn.init.xavier_normal_(layer.weight)
                    if layer.bias is not None:
                        nn.init.zeros_(layer.bias)

    def forward(self, input_ids, attention_mask):
        """Run the encoder and all heads.

        Returns:
            tuple ``(medication_logits, polypharmacy_logits, disease_logits)``,
            each with one row per input sequence. The app applies a per-label
            sigmoid to the first and argmax to the other two.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Hidden state at position 0 serves as the pooled sequence representation.
        pooled_output = self.dropout(outputs.last_hidden_state[:, 0, :])

        common_features = torch.relu(self.common_dense(pooled_output))

        medication_logits = self.medication_classifier(common_features)
        polypharmacy_logits = self.polypharmacy_classifier(common_features)
        disease_logits = self.disease_classifier(common_features)

        return medication_logits, polypharmacy_logits, disease_logits
|
|
|
|
|
|
@st.cache_resource
def load_model_and_resources(model_dir="streamlit_model"):
    """Load the trained model, tokenizer, label encoders and lookup tables.

    Cached with ``st.cache_resource``, so loading runs once per distinct
    ``model_dir`` instead of on every Streamlit rerun.

    Args:
        model_dir: Directory holding ``model_config.json``,
            ``model_state_dict.pt``, ``label_encoders.pkl`` and
            ``lookup_data.pkl``. Defaults to ``"streamlit_model"``.

    Returns:
        dict with keys ``model`` (moved to ``device`` and in eval mode),
        ``tokenizer``, ``mlb``, ``le_risk``, ``le_disease`` (taken from the
        label-encoder pickle), ``lookup_data`` and ``device``.

    Raises:
        Any error from opening/parsing the files (e.g. ``FileNotFoundError``,
        ``KeyError`` for a missing config key) propagates to the caller.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def artifact(filename):
        return os.path.join(model_dir, filename)

    with open(artifact("model_config.json"), "r", encoding="utf-8") as config_file:
        model_config = json.load(config_file)

    model_name = model_config["model_name"]
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    model = EnhancedMedicationModel(
        model_name=model_name,
        num_medications=model_config["num_medications"],
        num_polypharmacy_classes=model_config["num_polypharmacy_classes"],
        num_disease_classes=model_config["num_disease_classes"],
        # Falls back to 0.3 when the config does not record it; dropout is
        # inactive anyway once model.eval() is called below.
        dropout_rate=model_config.get("dropout_rate", 0.3),
    )

    # SECURITY NOTE(review): torch.load and pickle.load can execute arbitrary
    # code embedded in the file. Only point model_dir at trusted artifacts.
    state_dict = torch.load(artifact("model_state_dict.pt"), map_location=device)
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()

    with open(artifact("label_encoders.pkl"), "rb") as encoders_file:
        encoders = pickle.load(encoders_file)

    with open(artifact("lookup_data.pkl"), "rb") as lookup_file:
        lookup_data = pickle.load(lookup_file)

    return {
        'model': model,
        'tokenizer': tokenizer,
        'mlb': encoders['mlb'],
        'le_risk': encoders['le_risk'],
        'le_disease': encoders['le_disease'],
        'lookup_data': lookup_data,
        'device': device,
    }
|
|
|
|
|
|
def predict_patient_health_profile(patient_data, resources, *, threshold=0.5,
                                   max_medications=3, top_k=5, max_length=256):
    """Predict disease, medications and polypharmacy risk for one patient.

    Args:
        patient_data: dict with keys ``name``, ``age``, ``gender``,
            ``blood_group``, ``weight``, ``symptoms`` and ``severity``.
        resources: dict as returned by ``load_model_and_resources`` (keys
            ``model``, ``tokenizer``, ``mlb``, ``le_risk``, ``le_disease``,
            ``lookup_data``, ``device``).
        threshold: Sigmoid probability a medication must exceed to be
            recommended.
        max_medications: Maximum number of recommended medications returned.
        top_k: How many highest-probability medications are reported in
            ``medication_probabilities``.
        max_length: Token length the prompt is padded/truncated to.

    Returns:
        dict containing:

        - ``patient_name``.
        - ``predicted_disease`` with ``disease_causes`` and ``disease_prevention``.
        - ``medications``: recommended medications, most confident first.
        - ``polypharmacy_risk`` and ``polypharmacy_recommendation``.
        - ``personalized_health_tips``.
        - ``medication_probabilities``: the ``top_k`` medication probabilities,
          in descending order.
    """
    model = resources['model']
    tokenizer = resources['tokenizer']
    mlb = resources['mlb']
    le_risk = resources['le_risk']
    le_disease = resources['le_disease']
    lookup_data = resources['lookup_data']
    device = resources['device']

    # NOTE(review): this prompt layout presumably mirrors the training-data
    # format; keep it byte-identical unless the model is retrained.
    text_input = (
        f"Patient age {patient_data['age']}, gender {patient_data['gender']}, "
        f"blood group {patient_data['blood_group']}, weight {patient_data['weight']}kg. "
        f"SYMPTOMS: {patient_data['symptoms']}. "
        f"SEVERITY: {patient_data['severity']}."
    )

    encoding = tokenizer(
        text_input,
        add_special_tokens=True,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    with torch.no_grad():
        medication_logits, polypharmacy_logits, disease_logits = model(input_ids, attention_mask)

    # A single prompt is encoded, so there is one row: take row 0 / .item().
    medication_probs = torch.sigmoid(medication_logits)[0].cpu().numpy()
    predicted_risk = le_risk.classes_[torch.argmax(polypharmacy_logits, dim=1).item()]
    predicted_disease = le_disease.classes_[torch.argmax(disease_logits, dim=1).item()]

    # Every medication ranked by probability; the sort is stable, so ties keep
    # the label-encoder order.
    ranked_meds = sorted(zip(mlb.classes_, medication_probs), key=lambda pair: pair[1], reverse=True)
    top_meds = ranked_meds[:top_k]

    # Recommendations: medications above the threshold, most confident first,
    # so the UI's "1., 2., 3." numbering reflects confidence.
    med_results = [
        {
            'medication': med,
            'dosage': 'Consult doctor',
            'frequency': 'Consult doctor',
            'instruction': 'Consult doctor',
            'duration': 'As prescribed',
            'confidence': float(prob),
        }
        for med, prob in ranked_meds
        if prob > threshold
    ][:max_medications]

    disease_causes = lookup_data['disease_causes_dict'].get(predicted_disease, "Unknown causes")
    disease_prevention = lookup_data['disease_prevention_dict'].get(predicted_disease, "Consult healthcare provider")
    polypharmacy_recommendation = lookup_data['polypharmacy_recommendation_dict'].get(
        predicted_risk, "Consult healthcare provider"
    )

    # Health tips are keyed by (disease, age rounded down to the decade, gender).
    age_decade = (int(patient_data['age']) // 10) * 10
    personalized_health_tip = lookup_data['health_tips_dict'].get(
        (predicted_disease, age_decade, patient_data['gender']),
        "Maintain a balanced diet and regular exercise routine.",
    )

    return {
        'patient_name': patient_data['name'],
        'predicted_disease': predicted_disease,
        'disease_causes': disease_causes,
        'disease_prevention': disease_prevention,
        'medications': med_results,
        'polypharmacy_risk': predicted_risk,
        'polypharmacy_recommendation': polypharmacy_recommendation,
        'personalized_health_tips': personalized_health_tip,
        'medication_probabilities': {med: float(prob) for med, prob in top_meds},
    }
|
|
|
|
|
|
def _patient_details_inputs():
    """Render the basic patient-information widgets.

    Returns:
        tuple ``(name, age, gender, blood_group, weight)``.
    """
    st.subheader("Patient Information")
    name = st.text_input("Patient Name", value="John Doe")
    age = st.number_input("Age", min_value=1, max_value=120, value=45)
    gender = st.selectbox("Gender", options=["Male", "Female", "Other"])
    blood_group = st.selectbox("Blood Group", options=["A+", "A-", "B+", "B-", "AB+", "AB-", "O+", "O-"])
    weight = st.number_input("Weight (kg)", min_value=1.0, max_value=300.0, value=70.0, step=0.1)
    return name, age, gender, blood_group, weight


def _symptom_inputs():
    """Render the symptom picker and return the selected symptoms (no duplicates)."""
    st.subheader("Symptoms Information")

    common_symptoms = [
        "Headache", "Fever", "Fatigue", "Nausea", "Cough",
        "Sore throat", "Shortness of breath", "Chest pain",
        "Dizziness", "Abdominal pain", "Vomiting", "Diarrhea",
        "Muscle ache", "Joint pain", "Rash", "Loss of appetite",
    ]
    selected_symptoms = st.multiselect(
        "Select Symptoms",
        options=common_symptoms,
        default=["Headache", "Fever", "Fatigue"],
    )

    custom_symptom = st.text_input("Add other symptom (if not in list)").strip()
    # Ignore blank input and symptoms already selected. An exact duplicate
    # would also create two identically-labelled severity selectboxes, which
    # Streamlit rejects as a duplicate widget.
    already_selected = {symptom.casefold() for symptom in selected_symptoms}
    if custom_symptom and custom_symptom.casefold() not in already_selected:
        selected_symptoms.append(custom_symptom)
    return selected_symptoms


def _severity_inputs(selected_symptoms):
    """Render one severity selector per symptom.

    Returns:
        str of the form ``"Symptom:score; Symptom:score"`` (scores 1-5), or
        ``""`` when no symptoms are selected.
    """
    st.subheader("Symptom Severity")

    severity_levels = {
        "Very Mild": 1,
        "Mild": 2,
        "Moderate": 3,
        "Severe": 4,
        "Very Severe": 5,
    }
    severity_dict = {}
    if selected_symptoms:
        cols = st.columns(2)
        for i, symptom in enumerate(selected_symptoms):
            # Alternate between the two columns.
            with cols[i % 2]:
                severity_option = st.selectbox(
                    f"{symptom}",
                    options=list(severity_levels.keys()),
                    index=1,
                )
            severity_dict[symptom] = severity_levels[severity_option]

    return "; ".join(f"{symptom}:{score}" for symptom, score in severity_dict.items())


def _render_health_profile(prediction):
    """Display a prediction dict from ``predict_patient_health_profile``."""
    st.subheader(f"📋 Health Profile Analysis Results for {prediction['patient_name']}")

    col1, col2, col3 = st.columns([1, 1, 1])

    with col1:
        st.markdown("### 🦠 Disease Prediction")
        st.markdown(f"**Predicted Disease**: {prediction['predicted_disease']}")
        with st.expander("Disease Causes"):
            st.write(prediction['disease_causes'])
        with st.expander("Prevention Methods"):
            st.write(prediction['disease_prevention'])

    with col2:
        st.markdown("### 💊 Medication Recommendations")
        if not prediction['medications']:
            st.info("No medication reached the confidence threshold. Please consult a doctor.")
        for rank, med in enumerate(prediction['medications'], start=1):
            st.markdown(f"**{rank}. {med['medication']}** (Confidence: {med['confidence']:.2f})")
            st.markdown(
                f"- **Dosage:** {med['dosage']}\n"
                f"- **Frequency:** {med['frequency']}\n"
                f"- **Instructions:** {med['instruction']}\n"
                f"- **Duration:** {med['duration']}"
            )
            st.divider()

    with col3:
        st.markdown("### ⚠️ Polypharmacy Assessment")
        risk = prediction['polypharmacy_risk']
        # Any risk label other than "Low"/"Medium" is shown in red.
        risk_color = {"Low": "green", "Medium": "orange"}.get(risk, "red")
        st.markdown(f"**Risk Level**: <span style='color:{risk_color};font-weight:bold;'>{risk}</span>",
                    unsafe_allow_html=True)
        st.markdown(f"**Recommendation**: {prediction['polypharmacy_recommendation']}")

        st.markdown("### 🌿 Personalized Health Tips")
        st.info(prediction['personalized_health_tips'])

    st.subheader("Medication Confidence Scores")
    for med_name, med_prob in prediction['medication_probabilities'].items():
        st.text(f"{med_name}: {med_prob:.2f}")
        st.progress(med_prob)


def main():
    """Streamlit entry point.

    Loads the cached model resources, collects patient data and symptoms, and
    renders the health profile when the user clicks the button.
    """
    st.title("🏥 Drug Prediction and Polypharmacy System")
    st.markdown("Enter patient information to receive medication recommendations, disease prediction, and polypharmacy risk assessment.")

    # Loading failures are almost always missing or corrupt artifacts, so they
    # get the "check model files" hint; prediction failures are reported separately.
    try:
        with st.spinner("Loading medical model and resources..."):
            resources = load_model_and_resources()
    except Exception as e:
        st.error(f"An error occurred: {str(e)}")
        st.error("Please make sure all model files are correctly placed in the 'streamlit_model' directory")
        return

    col1, col2 = st.columns(2)
    with col1:
        name, age, gender, blood_group, weight = _patient_details_inputs()
    with col2:
        selected_symptoms = _symptom_inputs()
    symptoms = "; ".join(selected_symptoms)

    severity = _severity_inputs(selected_symptoms)

    if not st.button("Generate Health Profile", type="primary"):
        return
    if not selected_symptoms:
        st.warning("Please select or enter at least one symptom.")
        return

    patient_data = {
        'name': name,
        'age': age,
        'gender': gender,
        'blood_group': blood_group,
        'weight': weight,
        'symptoms': symptoms,
        'severity': severity,
    }
    try:
        with st.spinner("Analyzing patient data and generating health profile..."):
            prediction = predict_patient_health_profile(patient_data, resources)
        _render_health_profile(prediction)
    except Exception as e:
        st.error(f"An error occurred while generating the health profile: {str(e)}")


if __name__ == "__main__":
    main()
|
|
|
|