"""Streamlit app: predict a medical condition from text and recommend drugs.

The app loads a BiLSTM classifier plus its tokenizer from ``model_folder``,
classifies the text typed by the user, and then lists the drugs with the most
positive reviews for the predicted condition from ``dataset/custom.csv``.
"""

import json
import pickle
from pathlib import Path

import numpy as np
import pandas as pd
import streamlit as st
import torch
from sklearn.preprocessing import LabelEncoder
from tensorflow.keras.preprocessing.sequence import pad_sequences
from torch.nn.functional import softmax

from model_folder.modeling_bilstm import BiLSTM

# Constants
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_DIR = Path("model_folder")
DATASET_PATH = "dataset/custom.csv"

# Class names in the index order used by the classifier head.
# NOTE(review): "Bipolar Disorde" looks truncated, but it is kept verbatim
# because recommend_drugs() matches these labels exactly against the dataset's
# 'condition' column -- presumably the dataset uses the same spelling; confirm
# before "fixing" it.
CONDITION_LABELS = (
    "ADHD",
    "Acne",
    "Anxiety",
    "Bipolar Disorde",
    "Birth Control",
    "Depression",
    "Diabetes, Type 2",
    "Emergency Contraception",
    "High Blood Pressure",
    "Insomnia",
    "Obesity",
    "Pain",
    "Vaginal Yeast Infection",
    "Weight Loss",
)


# Load Model and Tokenizer
@st.cache_resource
def load_model_and_tokenizer():
    """Load the classifier, tokenizer, label encoder and model config.

    Cached with ``st.cache_resource`` so the files are read once rather than
    on every Streamlit rerun.

    Returns:
        A ``(model, tokenizer, label_encoder, config)`` tuple. ``model`` is a
        ``BiLSTM`` in eval mode on ``DEVICE``; ``config`` is the dict parsed
        from ``config.json``.
    """
    tokenizer_path = MODEL_DIR / "tokenizer.pkl"
    config_path = MODEL_DIR / "config.json"
    model_weights_path = MODEL_DIR / "pytorch_model.bin"

    # SECURITY: pickle.load executes arbitrary code embedded in the file.
    # Only ship a tokenizer.pkl that comes from a trusted source.
    with open(tokenizer_path, "rb") as f:
        tokenizer = pickle.load(f)

    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    # +1 for the padding token (index 0 is not present in word_index).
    vocab_size = len(tokenizer.word_index) + 1

    # Zero-filled placeholder of shape (vocab_size, embedding_dim); presumably
    # the real embedding weights are restored by load_state_dict below.
    embedding_matrix = torch.zeros((vocab_size, config["embedding_dim"]))

    model = BiLSTM(
        embedding_matrix=embedding_matrix,
        embed_size=config["embedding_dim"],
        num_labels=config["num_labels"],
        hidden_size=config["hidden_size"],
        dropout=config["dropout_rate"],
    )
    model.load_state_dict(torch.load(model_weights_path, map_location=DEVICE))
    model.eval()
    model.to(DEVICE)

    # LabelEncoder is populated directly instead of being fitted, so that
    # class index i maps to CONDITION_LABELS[i].
    le = LabelEncoder()
    le.classes_ = np.array(CONDITION_LABELS)

    return model, tokenizer, le, config


@st.cache_data
def load_dataset(path: str = DATASET_PATH) -> pd.DataFrame:
    """Read the review dataset from *path*, cached across Streamlit reruns."""
    return pd.read_csv(path)


# Predict Condition
def predict_condition(model, tokenizer, le, text, max_length):
    """Classify *text* and return the predicted condition with its confidence.

    Args:
        model: Classifier called on a ``(1, max_length)`` long tensor; its
            output is treated as per-class logits.
        tokenizer: Object providing ``texts_to_sequences``.
        le: Label encoder used to map the class index back to its name.
        text: Raw user text.
        max_length: Sequence length passed to ``pad_sequences``.

    Returns:
        ``(predicted_condition, confidence)`` where ``confidence`` is the
        softmax probability of the predicted class.
    """
    sequence = tokenizer.texts_to_sequences([text])
    # No padding/truncating arguments are passed, so the library defaults apply.
    padded_sequence = pad_sequences(sequence, maxlen=max_length)
    tensor = torch.tensor(padded_sequence, dtype=torch.long).to(DEVICE)

    with torch.no_grad():
        logits = model(tensor)
        probabilities = softmax(logits, dim=1)

    predicted_class = torch.argmax(probabilities, dim=1).item()
    predicted_condition = le.inverse_transform([predicted_class])[0]
    confidence = probabilities[0][predicted_class].item()
    return predicted_condition, confidence


# Recommend Drugs
def recommend_drugs(predicted_condition, df, top_n=5, *, min_rating=8):
    """Return the drugs with the most positive reviews for a condition.

    Args:
        predicted_condition: Value matched exactly against ``df['condition']``.
        df: DataFrame with ``condition``, ``rating`` and ``drugName`` columns.
        top_n: Maximum number of drugs to return.
        min_rating: Lowest rating that counts as a positive review.

    Returns:
        DataFrame with columns ``drugName``, ``num_reviews`` and
        ``avg_rating``, sorted by review count and then average rating (both
        descending). Empty when no positive review matches.
    """
    is_positive_match = (df["condition"] == predicted_condition) & (
        df["rating"] >= min_rating
    )
    drug_stats = (
        df[is_positive_match]
        .groupby("drugName")
        .agg(
            num_reviews=("rating", "size"),  # number of positive reviews
            avg_rating=("rating", "mean"),  # mean rating of those reviews
        )
        .reset_index()
        .sort_values(by=["num_reviews", "avg_rating"], ascending=False)
    )
    return drug_stats.head(top_n)


# Streamlit UI
def main():
    """Render the page, run the prediction and show drug recommendations."""
    st.title("Condition Prediction and Drug Recommendation")
    st.write(
        "Enter symptoms below to predict the medical condition and get drug recommendations."
    )

    df = load_dataset()
    model, tokenizer, le, config = load_model_and_tokenizer()

    user_input = st.text_area("Enter your symptoms", "")

    if not st.button("Predict"):
        return

    if not user_input.strip():
        st.warning("Please enter your symptoms.")
        return

    with st.spinner("Predicting..."):
        predicted_condition, confidence = predict_condition(
            model, tokenizer, le, user_input, config["max_length"]
        )
    st.success(
        f"Predicted Condition: {predicted_condition} (Confidence: {confidence:.2f})"
    )

    top_drugs = recommend_drugs(predicted_condition, df, top_n=5)

    st.write("**Top Recommended Drugs**")
    if top_drugs.empty:
        st.write("No drugs found for this condition.")
        return

    for row in top_drugs.itertuples(index=False):
        st.write(
            f"- {row.drugName} (Positive Reviews: {row.num_reviews}, "
            f"Avg Rating: {row.avg_rating:.2f})"
        )


if __name__ == "__main__":
    main()