# puneeth1's picture
# Create app.py
# eeeca54 verified
import base64
import html
import io

import gradio as gr
import joblib
import matplotlib.pyplot as plt
import pandas as pd
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# One identifier feeds both from_pretrained() calls so the classifier and its
# tokenizer are always resolved from the same checkpoint.
_CHECKPOINT = "modernbert-disease-classifier"

# Sequence-classification model and the tokenizer that goes with it.
model = AutoModelForSequenceClassification.from_pretrained(_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)

# Side assets read from the working directory: the review table used for drug
# ranking, and the encoder that maps class indices back to condition labels.
# NOTE(review): joblib.load unpickles the file - only ship a trusted label_encoder.pkl.
drug_df = pd.read_csv("drug_recommendation_data.csv")
label_encoder = joblib.load("label_encoder.pkl")

# Use the GPU when torch reports one, otherwise stay on the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
model.to(device)
# Accent colours cycled through by display position (1st, 2nd, 3rd, ... entry).
_DRUG_COLORS = ("#e6194b", "#3cb44b", "#ffe119", "#4363d8", "#f58231")


def recommend_drugs(condition, df, top_n=3, min_rating=8):
    """Build an HTML list of the most-reviewed, highly rated drugs for a condition.

    Rows of ``df`` whose ``condition`` column equals ``condition`` and whose
    ``rating`` is at least ``min_rating`` are grouped by ``drugName``. Drugs
    are ranked by the number of such reviews, then by mean rating, then by
    summed ``usefulCount`` (all descending), and the first ``top_n`` are
    rendered.

    Args:
        condition: Condition label to look up in ``df['condition']``.
        df: DataFrame with ``condition``, ``drugName``, ``rating`` and
            ``usefulCount`` columns.
        top_n: Maximum number of drugs to list.
        min_rating: Lowest rating a review needs in order to be counted.

    Returns:
        An HTML string: a ``<ul>`` with one ``<li>`` per recommended drug, or
        a ``<p>`` notice when the condition has no rows or no review at or
        above ``min_rating``.
    """
    # Labels and drug names come from data files, so escape them before
    # embedding them in markup.
    safe_condition = html.escape(str(condition))

    condition_df = df[df['condition'] == condition]
    if condition_df.empty:
        return f"<p style='color:red;'>No data available for condition: {safe_condition}</p>"

    high_rated = condition_df[condition_df['rating'] >= min_rating]
    if high_rated.empty:
        return (
            f"<p style='color:orange;'>No high-rated drugs available for condition: "
            f"{safe_condition}</p>"
        )

    drug_stats = (
        high_rated.groupby('drugName')
        .agg(
            num_reviews=('rating', 'size'),
            avg_rating=('rating', 'mean'),
            total_usefulness=('usefulCount', 'sum'),
        )
        .reset_index()
        .sort_values(
            by=['num_reviews', 'avg_rating', 'total_usefulness'],
            ascending=False,
        )
        .head(top_n)
    )

    items = []
    # enumerate() gives the display rank. The frame's own index still holds the
    # pre-sort group position, which would hand the same colour to two entries
    # whenever those positions collide modulo the palette size.
    for position, drug in enumerate(drug_stats.itertuples(index=False)):
        color = _DRUG_COLORS[position % len(_DRUG_COLORS)]
        drug_name = html.escape(str(drug.drugName))
        items.append(
            f"<li style='margin-bottom: 8px;'>"
            f"<strong style='color:{color};'>{drug_name}</strong> - "
            f"Rating: <span style='color:lightgreen;'>{drug.avg_rating:.1f}</span>, "
            f"Useful Votes: {drug.total_usefulness:,}"
            f"</li>"
        )
    return "<ul style='list-style-type: none; padding: 0;'>" + "".join(items) + "</ul>"
def predict_and_recommend(text):
    """Classify a patient review and render the prediction with drug suggestions.

    The review is tokenized (truncated to 256 tokens), run through the
    module-level ``model`` without gradient tracking, and the arg-max class is
    mapped back to a condition label with ``label_encoder``. The label is then
    passed to ``recommend_drugs`` together with ``drug_df``.

    Args:
        text: Free-text review entered by the user.

    Returns:
        An HTML string: the result card on success, a prompt when ``text`` is
        empty or whitespace-only, or a red error paragraph if any step of the
        prediction raises.
    """
    # A blank review carries no signal, yet arg-max would still return some
    # class; ask for input instead of showing an arbitrary prediction.
    if text is None or (isinstance(text, str) and not text.strip()):
        return "<p style='color:orange;'>Please enter a patient review.</p>"

    try:
        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=256)
        # Tensors must live on the same device as the model.
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = model(**inputs)
        predicted_class = torch.argmax(outputs.logits, dim=-1).item()
        predicted_condition = label_encoder.inverse_transform([predicted_class])[0]
        recommendations = recommend_drugs(predicted_condition, drug_df, top_n=3)
    except Exception as e:
        # UI boundary: report the failure in the page rather than raising.
        # The message is escaped because it is embedded in markup.
        return f"<p style='color:red;'>Error: {html.escape(str(e))}</p>"

    # The label originates from a data file; escape it before embedding.
    condition_html = html.escape(str(predicted_condition))
    return (
        "\n"
        '<div style="\n'
        "font-family:Arial, sans-serif;\n"
        "background:#2c2c2c;\n"
        "padding:20px;\n"
        "border-radius:8px;\n"
        "color:#f1f1f1;\n"
        "line-height:1.6;\n"
        '">\n'
        f'<h2 style="color:#f39c12; margin-bottom: 10px;">Predicted Condition: {condition_html}</h2>\n'
        '<h3 style="color:#16a085; margin-bottom: 15px;">Recommended Drugs:</h3>\n'
        f"{recommendations}\n"
        "</div>\n"
    )
# (Optional) Add additional features such as chatbot or visualization here.
# Gradio front end: a multi-line text box in, a single HTML panel out.
_review_box = gr.Textbox(
    label="Patient Review",
    placeholder="Enter medical review...",
    lines=4,
)
_results_panel = gr.HTML(label="Results")

interface = gr.Interface(
    fn=predict_and_recommend,
    inputs=_review_box,
    outputs=_results_panel,
    title="🩺 Disease Classifier & Drug Recommender",
    description="Enter a patient review to predict a medical condition and get drug recommendations.",
    theme="JohnSmith9982/small_and_pretty",
)

# Launch the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()