# DPDRS-BILSTM / app.py
# puneeth1's picture
# Update app.py
# c7f80fa verified
import streamlit as st
import torch
import pickle
import numpy as np
from pathlib import Path
from sklearn.preprocessing import LabelEncoder
from torch.nn.functional import softmax
from model_folder.modeling_bilstm import BiLSTM
from tensorflow.keras.preprocessing.sequence import pad_sequences
import json
import pandas as pd
# Constants
# Device used for inference: CUDA when PyTorch reports it available, else CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Load Model and Tokenizer
@st.cache_resource
def load_model_and_tokenizer():
    """Assemble the BiLSTM classifier and its helpers from ``model_folder``.

    Reads the pickled tokenizer, the JSON config and the saved weights,
    builds the network, switches it to evaluation mode and moves it to
    ``DEVICE``. Also prepares a ``LabelEncoder`` whose classes are the
    hard-coded condition names listed below.

    Returns:
        tuple: ``(model, tokenizer, le, config)`` where ``config`` is the
        dict parsed from ``config.json``.
    """
    artifacts_dir = Path("model_folder")

    # NOTE(review): pickle.load runs arbitrary code stored in the file, so
    # tokenizer.pkl must only ever come from a trusted source.
    with (artifacts_dir / "tokenizer.pkl").open("rb") as tokenizer_file:
        tokenizer = pickle.load(tokenizer_file)

    with (artifacts_dir / "config.json").open("r") as config_file:
        config = json.load(config_file)

    # One embedding row per tokenizer word index, +1 for the padding token.
    vocab_size = len(tokenizer.word_index) + 1
    embedding_dim = config["embedding_dim"]

    # All-zero placeholder of the right shape; presumably the real embedding
    # weights arrive with the state dict loaded below -- TODO confirm.
    placeholder_embeddings = torch.zeros((vocab_size, embedding_dim))

    model = BiLSTM(
        embedding_matrix=placeholder_embeddings,
        embed_size=embedding_dim,
        num_labels=config["num_labels"],
        hidden_size=config["hidden_size"],
        dropout=config["dropout_rate"],
    )
    state_dict = torch.load(
        artifacts_dir / "pytorch_model.bin", map_location=DEVICE
    )
    model.load_state_dict(state_dict)
    model.eval()
    model.to(DEVICE)

    # Class names in encoder order. They are kept exactly as written
    # (including "Bipolar Disorde") because the decoded label is later
    # compared against the dataset's 'condition' column.
    condition_names = (
        "ADHD", "Acne", "Anxiety", "Bipolar Disorde", "Birth Control",
        "Depression", "Diabetes, Type 2", "Emergency Contraception",
        "High Blood Pressure", "Insomnia", "Obesity", "Pain",
        "Vaginal Yeast Infection", "Weight Loss",
    )
    le = LabelEncoder()
    le.classes_ = np.array(condition_names)

    return model, tokenizer, le, config
# Predict Condition
def predict_condition(model, tokenizer, le, text, max_length):
    """Classify one piece of text and return the winning label.

    Args:
        model: Callable applied to a ``torch.long`` tensor of padded token
            ids; its output is treated as logits with classes along dim 1.
        tokenizer: Object providing ``texts_to_sequences``.
        le: Label encoder providing ``inverse_transform``.
        text: The text to classify.
        max_length: ``maxlen`` handed to ``pad_sequences``.

    Returns:
        tuple: ``(predicted_condition, confidence)`` -- the decoded label and
        the softmax probability of that label as a Python float.
    """
    token_ids = tokenizer.texts_to_sequences([text])
    model_input = torch.tensor(
        pad_sequences(token_ids, maxlen=max_length), dtype=torch.long
    ).to(DEVICE)

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        probabilities = softmax(model(model_input), dim=1)

    # .item() requires a single-element result, i.e. a batch of one text.
    predicted_class = torch.argmax(probabilities, dim=1).item()
    predicted_condition = le.inverse_transform([predicted_class])[0]
    confidence = probabilities[0, predicted_class].item()
    return predicted_condition, confidence
# Recommend Drugs
def recommend_drugs(predicted_condition, df, top_n=5):
    """Rank the drugs with the most positive reviews for a condition.

    Args:
        predicted_condition: Value matched exactly against ``df['condition']``.
        df: DataFrame with ``'condition'``, ``'rating'`` and ``'drugName'``
            columns.
        top_n: Maximum number of rows to return.

    Returns:
        DataFrame with columns ``drugName``, ``num_reviews`` and
        ``avg_rating``. Both statistics are computed only over reviews rated
        8 or higher; rows are sorted by review count, then average rating,
        both descending.
    """
    ranked = (
        df.loc[df['condition'] == predicted_condition]
        # Keep positive reviews only (rating >= 8); evaluated on the
        # condition subset, not on the whole frame.
        .loc[lambda subset: subset['rating'] >= 8]
        .groupby('drugName')
        .agg(
            num_reviews=('rating', 'size'),
            avg_rating=('rating', 'mean'),
        )
        .reset_index()
        .sort_values(by=['num_reviews', 'avg_rating'], ascending=False)
    )
    return ranked.head(top_n)
# Streamlit UI
@st.cache_data
def _load_dataset(csv_path='dataset/custom.csv'):
    """Read the drug-review CSV; cached so reruns do not re-parse the file."""
    return pd.read_csv(csv_path)


def main():
    """Render the page: text input, condition prediction, drug suggestions.

    Streamlit re-executes the script on every widget interaction, so both
    the dataset and the model are obtained through cached loaders. Nothing is
    predicted until the "Predict" button is pressed with non-blank input.
    """
    st.title("Condition Prediction and Drug Recommendation")
    st.write("Enter symptoms below to predict the medical condition and get drug recommendations.")

    df = _load_dataset()
    model, tokenizer, le, config = load_model_and_tokenizer()

    user_input = st.text_area("Enter your symptoms", "")

    if not st.button("Predict"):
        return

    if not user_input.strip():
        # Wording matches the "Enter your symptoms" prompt above.
        st.warning("Please enter your symptoms.")
        return

    with st.spinner("Predicting..."):
        predicted_condition, confidence = predict_condition(
            model, tokenizer, le, user_input, config["max_length"]
        )
    st.success(f"Predicted Condition: {predicted_condition} (Confidence: {confidence:.2f})")

    # Get drug recommendations
    top_drugs = recommend_drugs(predicted_condition, df, top_n=5)

    # Display recommended drugs
    st.write("**Top Recommended Drugs**")
    if top_drugs.empty:
        st.write("No drugs found for this condition.")
        return

    for drug in top_drugs.itertuples(index=False):
        st.write(
            f"- {drug.drugName} (Positive Reviews: {drug.num_reviews}, "
            f"Avg Rating: {drug.avg_rating:.2f})"
        )


if __name__ == "__main__":
    main()