File size: 16,144 Bytes
7a8bc6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
import streamlit as st
import torch
import numpy as np
import pickle
import json
from transformers import AutoTokenizer, AutoModel
import torch.nn as nn
import os

# Set page config: browser-tab title, tab icon and a wide (full-width) layout.
# The icon is the pill emoji; the earlier value was the mis-decoded byte
# sequence of that same character, which is not a valid emoji.
st.set_page_config(
    page_title="Drug Prediction and Polypharmacy System", 
    page_icon="💊",
    layout="wide"
)

# Model class definition - must match the training model architecture
class EnhancedMedicationModel(nn.Module):
    """Multi-task classifier on top of a pretrained transformer encoder.

    The encoder's first-token (CLS) embedding is passed through dropout and a
    shared dense layer, then fed to three independent heads:

    * ``medication_classifier``   - multi-label medication logits
    * ``polypharmacy_classifier`` - multi-class polypharmacy-risk logits
    * ``disease_classifier``      - multi-class disease logits

    Attribute names and layer order must stay in sync with the training-time
    architecture, because saved weights are matched by state-dict key.

    Args:
        model_name: Identifier passed to ``AutoModel.from_pretrained``.
        num_medications: Output width of the medication head.
        num_polypharmacy_classes: Output width of the polypharmacy head.
        num_disease_classes: Output width of the disease head.
        dropout_rate: Dropout probability used after the encoder and inside
            every head.
    """

    def __init__(self, model_name, num_medications, num_polypharmacy_classes, num_disease_classes, dropout_rate=0.3):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout_rate)
        hidden_size = self.bert.config.hidden_size

        # Common representation layer shared by all three tasks
        self.common_dense = nn.Linear(hidden_size, hidden_size)

        # Task-specific heads (all share the same two-layer shape)
        self.medication_classifier = self._build_head(hidden_size, num_medications, dropout_rate)
        self.polypharmacy_classifier = self._build_head(hidden_size, num_polypharmacy_classes, dropout_rate)
        self.disease_classifier = self._build_head(hidden_size, num_disease_classes, dropout_rate)

        # Apply weight initialization
        self._init_weights()

    @staticmethod
    def _build_head(hidden_size, num_outputs, dropout_rate):
        """Return a Linear -> ReLU -> Dropout -> Linear classification head."""
        return nn.Sequential(
            nn.Linear(hidden_size, hidden_size//2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_size//2, num_outputs)
        )

    def _init_weights(self):
        """Xavier-initialise weights and zero the biases of every Linear layer
        in the task heads and in the shared dense layer.

        ``Module.modules()`` yields the module itself plus its descendants, so
        the same loop handles both the ``nn.Sequential`` heads and the bare
        ``nn.Linear`` shared layer. Each layer's own bias is zeroed; the shared
        layer is no longer skipped because of a stale loop variable.
        """
        for block in (self.medication_classifier, self.polypharmacy_classifier,
                      self.disease_classifier, self.common_dense):
            for layer in block.modules():
                if isinstance(layer, nn.Linear):
                    nn.init.xavier_normal_(layer.weight)
                    nn.init.zeros_(layer.bias)

    def forward(self, input_ids, attention_mask):
        """Run the encoder and all three heads.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Attention mask forwarded to the encoder.

        Returns:
            Tuple ``(medication_logits, polypharmacy_logits, disease_logits)``
            of raw, un-normalised scores.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.last_hidden_state[:, 0, :]  # CLS token
        pooled_output = self.dropout(pooled_output)

        # Common representation
        common_features = torch.relu(self.common_dense(pooled_output))

        medication_logits = self.medication_classifier(common_features)
        polypharmacy_logits = self.polypharmacy_classifier(common_features)
        disease_logits = self.disease_classifier(common_features)

        return medication_logits, polypharmacy_logits, disease_logits

@st.cache_resource
def load_model_and_resources():
    """Assemble everything inference needs and return it as one dict.

    Reads the model configuration, tokenizer, trained weights, label encoders
    and lookup tables from the ``streamlit_model`` directory (resolved against
    the current working directory). The result is cached by Streamlit, so the
    files are only read on the first call.

    Returns:
        dict with keys ``model``, ``tokenizer``, ``mlb``, ``le_risk``,
        ``le_disease``, ``lookup_data`` and ``device``.
    """
    # Use the GPU when torch can see one, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    with open('streamlit_model/model_config.json', 'r') as config_file:
        config = json.load(config_file)

    base_model = config['model_name']
    tokenizer = AutoTokenizer.from_pretrained(base_model)

    # Rebuild the architecture first, then pour the trained weights into it.
    network = EnhancedMedicationModel(
        model_name=base_model,
        num_medications=config['num_medications'],
        num_polypharmacy_classes=config['num_polypharmacy_classes'],
        num_disease_classes=config['num_disease_classes'],
        dropout_rate=0.3
    )
    trained_weights = torch.load('streamlit_model/model_state_dict.pt', map_location=device)
    network.load_state_dict(trained_weights)
    network = network.to(device)
    network.eval()  # inference mode: disables dropout

    # NOTE: unpickling runs arbitrary code; these files are presumed to be
    # trusted artifacts shipped with the app - confirm before pointing this
    # at externally supplied files.
    with open('streamlit_model/label_encoders.pkl', 'rb') as encoder_file:
        label_encoders = pickle.load(encoder_file)

    with open('streamlit_model/lookup_data.pkl', 'rb') as lookup_file:
        lookup_tables = pickle.load(lookup_file)

    bundle = {'model': network, 'tokenizer': tokenizer}
    for encoder_name in ('mlb', 'le_risk', 'le_disease'):
        bundle[encoder_name] = label_encoders[encoder_name]
    bundle['lookup_data'] = lookup_tables
    bundle['device'] = device
    return bundle

def predict_patient_health_profile(patient_data, resources):
    """Predict disease, medications and polypharmacy risk for one patient.

    Args:
        patient_data: Mapping with the keys ``name``, ``age``, ``gender``,
            ``blood_group``, ``weight``, ``symptoms`` and ``severity``.
            ``age`` must support integer floor division.
        resources: Mapping with the keys ``model``, ``tokenizer``, ``mlb``,
            ``le_risk``, ``le_disease``, ``lookup_data`` and ``device`` (the
            structure returned by ``load_model_and_resources``).

    Returns:
        dict with the keys ``patient_name``, ``predicted_disease``,
        ``disease_causes``, ``disease_prevention``, ``medications``,
        ``polypharmacy_risk``, ``polypharmacy_recommendation``,
        ``personalized_health_tips`` and ``medication_probabilities``.

        ``medications`` holds at most three entries: the medications whose
        probability exceeds 0.5, ordered from most to least confident.
        ``medication_probabilities`` maps the five most probable medications
        (regardless of threshold) to their probability.
    """
    decision_threshold = 0.5   # a medication counts as predicted above this
    max_recommendations = 3    # number of medications returned in detail
    max_scored = 5             # number of entries in medication_probabilities

    model = resources['model']
    tokenizer = resources['tokenizer']
    mlb = resources['mlb']
    le_risk = resources['le_risk']
    le_disease = resources['le_disease']
    lookup_data = resources['lookup_data']
    device = resources['device']

    # Create text input
    text_input = (
        f"Patient age {patient_data['age']}, gender {patient_data['gender']}, "
        f"blood group {patient_data['blood_group']}, weight {patient_data['weight']}kg. "
        f"SYMPTOMS: {patient_data['symptoms']}. "
        f"SEVERITY: {patient_data['severity']}."
    )

    # Tokenize
    encoding = tokenizer(
        text_input,
        add_special_tokens=True,
        max_length=256,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )

    # Move to device
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    # Get predictions (sigmoid is computed once and reused for both the
    # thresholded recommendations and the reported probabilities)
    with torch.no_grad():
        medication_logits, polypharmacy_logits, disease_logits = model(input_ids, attention_mask)
        medication_probs = torch.sigmoid(medication_logits)[0].cpu().numpy()
        polypharmacy_index = torch.argmax(polypharmacy_logits, dim=1).item()
        disease_index = torch.argmax(disease_logits, dim=1).item()

    # Convert predictions to human-readable format
    predicted_risk = le_risk.classes_[polypharmacy_index]
    predicted_disease = le_disease.classes_[disease_index]

    # Rank every medication by probability, highest first
    med_prob_dict = {med: float(prob) for med, prob in zip(mlb.classes_, medication_probs)}
    sorted_meds = sorted(med_prob_dict.items(), key=lambda item: item[1], reverse=True)
    top_meds = sorted_meds[:max_scored]

    # Detailed recommendations: the most confident medications above the
    # threshold. Selecting from the ranked list (instead of the first few in
    # label order) ensures the truncation keeps the strongest predictions.
    confident_meds = [(med, prob) for med, prob in sorted_meds if prob > decision_threshold]
    med_results = [
        {
            'medication': med,
            'dosage': 'Consult doctor',
            'frequency': 'Consult doctor',
            'instruction': 'Consult doctor',
            'duration': 'As prescribed',
            'confidence': prob
        }
        for med, prob in confident_meds[:max_recommendations]
    ]

    # Get disease information
    disease_causes = lookup_data['disease_causes_dict'].get(predicted_disease, "Unknown causes")
    disease_prevention = lookup_data['disease_prevention_dict'].get(predicted_disease, "Consult healthcare provider")

    # Get polypharmacy recommendation
    polypharmacy_recommendation = lookup_data['polypharmacy_recommendation_dict'].get(
        predicted_risk, "Consult healthcare provider"
    )

    # Get personalized health tip; tips are keyed by (disease, decade, gender)
    age_decade = (patient_data['age'] // 10) * 10
    health_tip_key = (predicted_disease, age_decade, patient_data['gender'])
    personalized_health_tip = lookup_data['health_tips_dict'].get(
        health_tip_key, "Maintain a balanced diet and regular exercise routine."
    )

    # Return comprehensive results
    return {
        'patient_name': patient_data['name'],
        'predicted_disease': predicted_disease,
        'disease_causes': disease_causes,
        'disease_prevention': disease_prevention,
        'medications': med_results,
        'polypharmacy_risk': predicted_risk,
        'polypharmacy_recommendation': polypharmacy_recommendation,
        'personalized_health_tips': personalized_health_tip,
        'medication_probabilities': dict(top_meds)
    }

def _collect_patient_inputs():
    """Render the two-column input form and return the values as a dict.

    Returns:
        dict with the keys ``name``, ``age``, ``gender``, ``blood_group``,
        ``weight``, ``symptoms`` (symptoms joined with "; ") and ``severity``
        ("symptom:score" pairs joined with "; ").
    """
    col1, col2 = st.columns(2)

    # Patient information inputs
    with col1:
        st.subheader("Patient Information")
        name = st.text_input("Patient Name", value="John Doe")
        age = st.number_input("Age", min_value=1, max_value=120, value=45)
        gender = st.selectbox("Gender", options=["Male", "Female", "Other"])
        blood_group = st.selectbox("Blood Group", options=["A+", "A-", "B+", "B-", "AB+", "AB-", "O+", "O-"])
        weight = st.number_input("Weight (kg)", min_value=1.0, max_value=300.0, value=70.0, step=0.1)

    with col2:
        st.subheader("Symptoms Information")

        # Common symptoms options
        common_symptoms = [
            "Headache", "Fever", "Fatigue", "Nausea", "Cough",
            "Sore throat", "Shortness of breath", "Chest pain",
            "Dizziness", "Abdominal pain", "Vomiting", "Diarrhea",
            "Muscle ache", "Joint pain", "Rash", "Loss of appetite"
        ]

        # Copy the selection so the list owned by the widget is not mutated.
        selected_symptoms = list(st.multiselect(
            "Select Symptoms",
            options=common_symptoms,
            default=["Headache", "Fever", "Fatigue"]
        ))

        # Custom symptom input. Skipping a symptom that is already selected
        # keeps the per-symptom severity widgets below unique.
        custom_symptom = st.text_input("Add other symptom (if not in list)").strip()
        if custom_symptom and custom_symptom not in selected_symptoms:
            selected_symptoms.append(custom_symptom)

        st.subheader("Symptom Severity")

        # Severity label -> numeric score sent to the model
        severity_levels = {
            "Very Mild": 1,
            "Mild": 2,
            "Moderate": 3,
            "Severe": 4,
            "Very Severe": 5
        }

        severity_dict = {}

        # Compact layout: severity pickers alternate between two columns
        if selected_symptoms:
            cols = st.columns(2)
            for i, symptom in enumerate(selected_symptoms):
                with cols[i % 2]:
                    severity_option = st.selectbox(
                        symptom,
                        options=list(severity_levels),
                        index=1,  # Default to "Mild"
                        key=f"severity_{symptom}"  # explicit, per-symptom widget identity
                    )
                    severity_dict[symptom] = severity_levels[severity_option]

    return {
        'name': name,
        'age': age,
        'gender': gender,
        'blood_group': blood_group,
        'weight': weight,
        # String formats expected by the model prompt
        'symptoms': "; ".join(selected_symptoms),
        'severity': "; ".join(f"{symptom}:{score}" for symptom, score in severity_dict.items())
    }


def _render_prediction(prediction):
    """Display a prediction dict (as produced by
    ``predict_patient_health_profile``) in three columns plus a list of
    confidence bars."""
    st.subheader(f"🔍 Health Profile Analysis Results for {prediction['patient_name']}")

    col1, col2, col3 = st.columns([1, 1, 1])

    # Column 1: Disease information
    with col1:
        st.markdown("### 🦠 Disease Prediction")
        st.markdown(f"**Predicted Disease**: {prediction['predicted_disease']}")

        with st.expander("Disease Causes"):
            st.write(prediction['disease_causes'])

        with st.expander("Prevention Methods"):
            st.write(prediction['disease_prevention'])

    # Column 2: Medication recommendations
    with col2:
        st.markdown("### 💊 Medication Recommendations")
        for position, med in enumerate(prediction['medications'], start=1):
            st.markdown(f"**{position}. {med['medication']}** (Confidence: {med['confidence']:.2f})")
            st.markdown("\n\n".join([
                f"- **Dosage:** {med['dosage']}",
                f"- **Frequency:** {med['frequency']}",
                f"- **Instructions:** {med['instruction']}",
                f"- **Duration:** {med['duration']}",
            ]))
            st.divider()

    # Column 3: Risk assessment and health tips
    with col3:
        st.markdown("### ⚠️ Polypharmacy Assessment")
        risk = prediction['polypharmacy_risk']
        # Any level other than Low/Medium is shown in red.
        risk_color = {"Low": "green", "Medium": "orange"}.get(risk, "red")
        st.markdown(f"**Risk Level**: <span style='color:{risk_color};font-weight:bold;'>{risk}</span>",
                    unsafe_allow_html=True)
        st.markdown(f"**Recommendation**: {prediction['polypharmacy_recommendation']}")

        st.markdown("### 🌿 Personalized Health Tips")
        st.info(prediction['personalized_health_tips'])

    # Display each medication with its confidence score as text and progress bar
    st.subheader("Medication Confidence Scores")
    for med_name, med_prob in prediction['medication_probabilities'].items():
        st.text(f"{med_name}: {med_prob:.2f}")
        st.progress(med_prob)


def main():
    """Streamlit entry point: load the model, show the input form and, when
    the button is pressed, run the prediction and render the results.

    Failures while loading the model are reported together with a hint about
    the ``streamlit_model`` directory; failures in the form or the prediction
    are reported on their own, so the hint is only shown when it applies.
    """
    # App title and description
    st.title("🏥 Drug Prediction and Polypharmacy System")
    st.markdown("Enter patient information to receive medication recommendations, disease prediction, and polypharmacy risk assessment.")

    # Load model and resources
    try:
        with st.spinner("Loading medical model and resources..."):
            resources = load_model_and_resources()
    except Exception as e:
        st.error(f"An error occurred: {str(e)}")
        st.error("Please make sure all model files are correctly placed in the 'streamlit_model' directory")
        return

    try:
        patient_data = _collect_patient_inputs()

        # Submit button
        if not st.button("Generate Health Profile", type="primary"):
            return

        # Without symptoms the model prompt would be empty of clinical content.
        if not patient_data['symptoms']:
            st.warning("Please select or enter at least one symptom.")
            return

        with st.spinner("Analyzing patient data and generating health profile..."):
            prediction = predict_patient_health_profile(patient_data, resources)
            _render_prediction(prediction)
    except Exception as e:
        st.error(f"An error occurred: {str(e)}")

# Build the page only when this file is executed as the main script;
# importing the module defines the functions without rendering anything
# beyond the module-level page configuration.
if __name__ == "__main__":
    main()