Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import requests | |
| from utils.ai71_utils import get_ai71_response | |
| from datetime import datetime | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
| import supervision as sv | |
| import matplotlib.pyplot as plt | |
| import io | |
| import os | |
| from inference_sdk import InferenceHTTPClient | |
| from bs4 import BeautifulSoup | |
| import tensorflow as tf | |
| import pandas as pd | |
| from sklearn.feature_extraction.text import CountVectorizer | |
| from sklearn.model_selection import train_test_split | |
| from sklearn.linear_model import LogisticRegression | |
| from sklearn.metrics import accuracy_score, classification_report | |
| import nltk | |
| import re | |
| from nltk.tokenize import word_tokenize | |
| from nltk.corpus import stopwords | |
| import joblib # Import joblib for loading the logistic regression model | |
# --- Download the NLTK resources this app relies on, if not already present ---
# nltk.data.find raises LookupError when a resource is missing locally; only in
# that case is the corresponding package downloaded.
# NOTE(review): some newer NLTK releases look up 'punkt_tab' for word_tokenize;
# confirm against the NLTK version this Space pins.
for _resource_path, _package_name in (
    ('tokenizers/punkt', 'punkt'),
    ('corpora/stopwords', 'stopwords'),
):
    try:
        nltk.data.find(_resource_path)
    except LookupError:
        nltk.download(_package_name)
# --- Preprocess text function (moved outside session state) ---
def preprocess_text(text):
    """Normalise free-text symptoms before they are vectorised.

    The text is lower-cased, every character that is not a letter, digit,
    whitespace or comma is replaced by a space, the result is tokenised with
    NLTK's ``word_tokenize``, English stop words are dropped, and the
    remaining tokens are joined back with single spaces.

    Args:
        text: Raw input string.

    Returns:
        The cleaned, space-separated token string.
    """
    lowered = text.lower()
    # Commas are deliberately kept by the pattern; everything else that is not
    # alphanumeric or whitespace becomes a space.
    scrubbed = re.sub(r'[^a-zA-Z0-9\s\,]', ' ', lowered)
    english_stop_words = set(stopwords.words('english'))
    kept_tokens = filter(
        lambda token: token not in english_stop_words,
        word_tokenize(scrubbed),
    )
    return ' '.join(kept_tokens)
st.title("Medi Scape Dashboard")

# --- Session State Initialization ---
# Streamlit re-executes this script on every interaction, so the Keras model is
# stored in st.session_state and only loaded when the key is not there yet.
if 'disease_model' not in st.session_state:
    # Assuming all models are in the root directory of your Hugging Face Space
    model_path = 'FINAL_MODEL.h5'
    try:
        st.session_state.disease_model = tf.keras.models.load_model(model_path)
        print("Disease model loaded successfully!")
    except FileNotFoundError:
        st.error("Disease classification model not found. Please ensure 'FINAL_MODEL.h5' is in the correct directory.")
        st.session_state.disease_model = None
    except PermissionError:
        # The message names the file that is actually being loaded.
        st.error("Permission error accessing 'FINAL_MODEL.h5'. Please ensure the file is not being used by another process.")
        st.session_state.disease_model = None
    except Exception as e:
        # Same fallback as the vectorizer and logistic-regression loaders:
        # any other load failure is reported and the app keeps running with
        # the model marked as unavailable.
        # NOTE(review): depending on the Keras version a missing file may
        # surface as OSError or ValueError rather than FileNotFoundError.
        st.error(f"An error occurred while loading the disease classification model: {e}")
        st.session_state.disease_model = None
# Load the pickled vectorizer into session state (only when the key is absent).
# NOTE(review): unpickling executes code from the file; keep 'vectorizer.pkl'
# a trusted, repository-controlled artifact.
if 'vectorizer' not in st.session_state:
    vectorizer_path = "vectorizer.pkl"
    try:
        _loaded_vectorizer = pd.read_pickle(vectorizer_path)
        st.session_state.vectorizer = _loaded_vectorizer
        print("Vectorizer loaded successfully!")
    except FileNotFoundError:
        st.session_state.vectorizer = None
        st.error("Vectorizer file not found. Please ensure 'vectorizer.pkl' is in the same directory as this app.")
    except Exception as e:
        st.session_state.vectorizer = None
        st.error(f"An error occurred while loading the vectorizer: {e}")
# Load the logistic regression model using joblib (only when the key is absent).
# The session key and the user-facing messages call it the "LLM model"; the
# object loaded here is the joblib-serialised logistic regression classifier.
if 'model_llm' not in st.session_state:
    llm_model_path = "logistic_regression_model.pkl"
    try:
        _loaded_classifier = joblib.load(llm_model_path)
        st.session_state.model_llm = _loaded_classifier
        print("LLM model loaded successfully!")
    except FileNotFoundError:
        st.session_state.model_llm = None
        st.error("LLM model file not found. Please ensure 'logistic_regression_model.pkl' is in the correct directory.")
    except Exception as e:
        st.session_state.model_llm = None
        st.error(f"An error occurred while loading the LLM model: {e}")
# --- End of Session State Initialization ---
# Disease classification model used by the "Disease Detection" page.
# Reuse the instance that the session-state initialisation stored under
# 'disease_model' instead of deserialising FINAL_MODEL.h5 again: this script
# is re-executed on every interaction, so a second load_model call here would
# repeat the expensive load on each rerun. The value is None when loading
# failed (the failure was already reported when the session was initialised).
disease_model = st.session_state.get("disease_model")
# Sidebar Navigation
# `page` holds the label of the selected entry and drives the page routing
# further down in this script.
_page_options = [
    "Home",
    "AI Chatbot Diagnosis",
    "Drug Identification",
    "Disease Detection",
    "Outbreak Alert",
]
st.sidebar.title("Navigation")
page = st.sidebar.radio("Go to", _page_options)
# Access secrets using st.secrets
# CLIENT is always bound: it stays None when the inference secrets are not
# configured, so later code can test for availability instead of hitting a
# NameError on an undefined global.
CLIENT = None
if "INFERENCE_API_URL" in st.secrets and "INFERENCE_API_KEY" in st.secrets:
    # Initialize the Inference Client
    CLIENT = InferenceHTTPClient(
        api_url=st.secrets["INFERENCE_API_URL"],
        api_key=st.secrets["INFERENCE_API_KEY"],
    )
else:
    st.error("Please make sure to set your secrets in the Streamlit secrets settings.")
# Function to preprocess the image
def preprocess_image(image_path):
    """Outline the detected contours of an image file on a copy of it.

    Pipeline: grayscale -> Gaussian blur -> Otsu binarisation -> dilate/erode
    -> Canny edges -> rotation of the edge map by an estimated skew angle ->
    external contours, which are drawn in green on a copy of the image as it
    was loaded.

    Args:
        image_path: Path of the image file to read with ``cv2.imread``.

    Returns:
        A copy of the loaded image (as returned by ``cv2.imread``) with the
        contours drawn on it. When no edge pixels are found, an unmodified
        copy is returned.

    Raises:
        ValueError: If ``cv2.imread`` cannot read an image from ``image_path``.
    """
    # Load the image; cv2.imread signals failure by returning None.
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"Could not read an image from {image_path!r}")

    # Convert to grayscale
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    # Remove noise
    blurred = cv2.GaussianBlur(gray, (5, 5), 0)
    # Thresholding/Binarization (Otsu picks the threshold, so the 0 is unused)
    _, binary = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    # Dilation and Erosion
    # NOTE(review): with a 1x1 kernel these two steps look like no-ops;
    # confirm whether a larger kernel was intended.
    kernel = np.ones((1, 1), np.uint8)
    dilated = cv2.dilate(binary, kernel, iterations=1)
    eroded = cv2.erode(dilated, kernel, iterations=1)
    # Edge detection
    edges = cv2.Canny(eroded, 100, 200)

    # Deskewing: estimate the angle from the minimum-area rectangle around
    # all edge pixels.
    coords = np.column_stack(np.where(edges > 0))
    if coords.size == 0:
        # No edge pixels: minAreaRect cannot work on an empty point set and
        # there is nothing to outline.
        return image.copy()
    # NOTE(review): np.where yields (row, col) pairs and the angle range
    # returned by minAreaRect differs between OpenCV releases; verify the
    # branch below against the installed version.
    angle = cv2.minAreaRect(coords)[-1]
    if angle < -45:
        angle = -(90 + angle)
    else:
        angle = -angle
    (h, w) = edges.shape[:2]
    center = (w // 2, h // 2)
    M = cv2.getRotationMatrix2D(center, angle, 1.0)
    deskewed = cv2.warpAffine(edges, M, (w, h), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REPLICATE)

    # Find contours on the rotated edge map
    contours, _ = cv2.findContours(deskewed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    # Draw contours on a copy of the original (unrotated) image
    contour_image = image.copy()
    cv2.drawContours(contour_image, contours, -1, (0, 255, 0), 2)
    return contour_image
def get_x1(detection):
    """Return the first coordinate of the first box in ``detection.xyxy``.

    Used as the sort key that orders detections along the x axis.
    """
    first_box = detection.xyxy[0]
    return first_box[0]
# --- Prediction function (using session state) ---
def predict_disease(symptoms):
    """Predict a disease label for a free-text symptom description.

    Uses the vectorizer and classifier stored in ``st.session_state``
    (keys ``vectorizer`` and ``model_llm``).

    Args:
        symptoms: Raw symptom text entered by the user.

    Returns:
        The first element of the classifier's prediction, or ``None`` (after
        showing an error in the UI) when either artifact is not loaded.
    """
    if st.session_state.vectorizer is None or st.session_state.model_llm is None:
        st.error("Unable to make prediction. Vectorizer or LLM model is not loaded.")
        return None

    cleaned = preprocess_text(symptoms)
    features = st.session_state.vectorizer.transform([cleaned])
    # predict returns an array with one entry per input row; only one row was
    # passed in, so its single entry is the answer.
    predicted_labels = st.session_state.model_llm.predict(features)
    return predicted_labels[0]
# --- New function to analyze X-ray with LLM ---
def analyze_xray_with_llm(predicted_class):
    """Ask the LLM for a patient-oriented summary of the predicted condition.

    Builds a prompt around ``predicted_class``, sends it through
    ``get_ai71_response`` and writes a heading followed by the response to the
    Streamlit page. Returns nothing.
    """
    # The prompt starts and ends with a newline, hence the empty first and
    # last entries.
    prompt_lines = [
        "",
        f"Based on a chest X-ray analysis, the predicted condition is {predicted_class}.",
        "Please provide a concise summary of this condition, including:",
        "- A brief description of the condition.",
        "- Common symptoms associated with it.",
        "- Potential causes.",
        "- General treatment approaches.",
        "- Any other relevant information for a patient.",
        "",
    ]
    summary = get_ai71_response("\n".join(prompt_lines))
    st.write("## LLM Analysis of X-ray Results:")
    st.write(summary)
# --- Functions for Symptom Detection ---
def precaution(label):
    """Return the precautions listed for a disease as a bulleted string.

    Reads ``disease_precaution.csv`` from the working directory and matches
    ``label`` case-insensitively against its ``Disease`` column.

    Args:
        label: Disease name; converted with ``str`` before matching.

    Returns:
        One ``"- <precaution>"`` line per non-empty value found in the
        ``Precaution_1`` .. ``Precaution_4`` columns of the matching rows, or
        ``"No precautions found."`` when nothing matches or every cell is
        empty.
    """
    dataset_precau = pd.read_csv("disease_precaution.csv", encoding='latin1')  # Make sure this file is in the same directory
    wanted = str(label).lower()
    matching_rows = dataset_precau[dataset_precau["Disease"].str.lower() == wanted]
    if matching_rows.empty:
        return "No precautions found."

    precaution_columns = ["Precaution_1", "Precaution_2", "Precaution_3", "Precaution_4"]
    cells = matching_rows[precaution_columns].values.flatten().tolist()
    # Empty CSV cells are read as NaN; skip them so they are not rendered as
    # "- nan" bullets.
    entries = [str(cell) for cell in cells if pd.notna(cell)]
    if not entries:
        return "No precautions found."
    return "\n".join(f"- {entry}" for entry in entries)
def occurance(label):
    """Return the occurrence notes recorded for a disease, one per line.

    Reads ``disease_riskFactors.csv`` from the working directory and matches
    ``label`` case-insensitively against its ``DNAME`` column.

    Args:
        label: Disease name; converted with ``str`` before matching.

    Returns:
        The non-empty ``OCCUR`` values of the matching rows joined with
        newlines, or ``"No occurrences found."`` when there are none.
    """
    dataset_occur = pd.read_csv("disease_riskFactors.csv", encoding='latin1')
    wanted = str(label).lower()
    matching_rows = dataset_occur[dataset_occur["DNAME"].str.lower() == wanted]
    # Empty cells are read as NaN (a float); joining them would raise
    # TypeError, so they are dropped and the rest is coerced to str.
    occurrences = [str(value) for value in matching_rows["OCCUR"].tolist() if pd.notna(value)]
    if not occurrences:
        return "No occurrences found."
    return "\n".join(occurrences)
# --- Page renderers ---
# Each sidebar page is rendered by its own function; the dispatch table at the
# end of this block calls the one matching the `page` chosen in the sidebar.


def _render_home():
    """Render the static landing page describing the app."""
    st.markdown("## Welcome to Medi Scape")
    st.write("Medi Scape is an AI-powered healthcare application designed to streamline the process of understanding and managing medical information. It leverages advanced AI models to provide features such as prescription analysis, disease detection from chest X-rays, and symptom-based diagnosis assistance.")
    st.markdown("## Features")
    st.write("Medi Scape provides various AI-powered tools for remote healthcare, including:")
    features = [
        "**AI Chatbot Diagnosis:** Interact with an AI chatbot for preliminary diagnosis and medical information.",
        "**Drug Identification:** Upload a prescription image to identify medications and access relevant details.",
        "**Doctor's Handwriting Identification:** Our system can accurately recognize and process doctor's handwriting.",
        "**Disease Detection:** Upload a chest X-ray image to detect potential diseases.",
        "**Outbreak Alert:** Stay informed about potential disease outbreaks in your area.",
    ]
    for feature in features:
        st.markdown(f"- {feature}")
    st.markdown("## How it Works")
    steps = [
        "**Upload:** You can upload a prescription image for drug identification or a chest X-ray image for disease detection.",
        "**Process:** Our AI models will analyze the image and extract relevant information.",
        "**Results:** You will receive identified drug names, uses, side effects, and more, or a potential disease diagnosis.",
    ]
    for i, step in enumerate(steps, 1):
        st.markdown(f"{i}. {step}")
    st.markdown("## Key Features")
    key_features = [
        "**AI-Powered:** Leverages advanced AI models for accurate analysis and diagnosis.",
        "**User-Friendly:** Simple and intuitive interface for easy navigation and interaction.",
        "**Secure:** Your data is protected and handled with confidentiality.",
    ]
    for feature in key_features:
        st.markdown(f"- {feature}")
    st.markdown("Please use the sidebar to navigate to different features.")


def _render_chatbot_diagnosis():
    """Render the symptom form with its two diagnosis pipelines.

    The left button runs the logistic-regression pipeline (prediction,
    precautions, occurrence); the right button sends the symptoms to the LLM.
    Input that is empty or whitespace-only is rejected with a hint.
    """
    st.write("Enter your symptoms separated by commas:")
    symptoms_input = st.text_area("Symptoms:")
    # Whitespace-only input carries no symptoms, so treat it like empty input.
    has_symptoms = bool(symptoms_input and symptoms_input.strip())
    col1, col2 = st.columns(2)
    with col1:
        if st.button("Diagnose with Regression Model"):
            if has_symptoms:
                # --- Pipeline 1: regression prediction plus CSV look-ups ---
                regression_prediction = predict_disease(symptoms_input)
                if regression_prediction is not None:
                    st.write("## Logistic Regression Prediction:")
                    st.write(regression_prediction)
                    st.write("## Precautions:")
                    st.write(precaution(regression_prediction))
                    st.write("## Occurrence:")
                    st.write(occurance(regression_prediction))
            else:
                st.write("Please enter your symptoms.")
    with col2:
        if st.button("Diagnose with LLM"):
            if has_symptoms:
                # --- Pipeline 2 (LLM only) ---
                prompt = "\n".join([
                    f"The user is experiencing the following symptoms: {symptoms_input}.",
                    "Based on these symptoms, provide a detailed explanation of possible conditions, including",
                    "potential causes, common symptoms, and general treatment approaches. Also, suggest when",
                    "a patient should consult a doctor.",
                ])
                llm_response = get_ai71_response(prompt)
                st.write("## LLM Diagnosis:")
                st.write(llm_response)
            else:
                st.write("Please enter your symptoms.")


def _render_drug_identification():
    """Render the prescription upload, handwriting inference and LLM analysis."""
    # Local import: tempfile is not among this file's top-level imports.
    import tempfile

    st.write("Upload a prescription image for drug identification.")
    uploaded_file = st.file_uploader("Upload prescription", type=["png", "jpg", "jpeg"])
    if uploaded_file is None:
        st.info("Please upload a prescription image to proceed.")
        return

    # Display the uploaded image
    image = Image.open(uploaded_file)
    st.image(image, caption="Uploaded Prescription", use_column_width=True)
    if not st.button("Process Prescription"):
        return

    # CLIENT is only created when the inference secrets are configured, so
    # look it up defensively instead of assuming the global exists.
    client = globals().get("CLIENT")
    if client is None:
        st.error("Prescription processing is unavailable because the inference client is not configured.")
        return

    # preprocess_image works on a file path, so write the upload to a
    # per-request temporary file (a fixed name would be shared by concurrent
    # sessions) and remove it even if preprocessing fails.
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
        temp_image_path = tmp.name
    try:
        # JPEG cannot store alpha or palette modes, which PNG uploads may use.
        image.convert("RGB").save(temp_image_path, format="JPEG")
        preprocessed_image = preprocess_image(temp_image_path)
    finally:
        os.remove(temp_image_path)

    # Perform inference
    result_doch1 = client.infer(preprocessed_image, model_id="doctor-s-handwriting/1")
    # Extract labels and detections
    labels = [item["class"] for item in result_doch1["predictions"]]
    detections = sv.Detections.from_inference(result_doch1)
    if not labels:
        st.warning("No handwriting could be recognised in this image.")
        return

    # Sort detections and labels by the x coordinate returned by get_x1
    sorted_indices = sorted(range(len(detections)), key=lambda i: get_x1(detections[i]))
    sorted_detections = [detections[i] for i in sorted_indices]
    sorted_labels = [labels[i] for i in sorted_indices]
    resulting_string = ''.join(sorted_labels)

    # Display results
    st.subheader("Processed Prescription")
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
    # Plot bounding boxes
    image_with_boxes = preprocessed_image.copy()
    for detection in sorted_detections:
        x1, y1, x2, y2 = detection.xyxy[0]
        cv2.rectangle(image_with_boxes, (int(x1), int(y1)), (int(x2), int(y2)), (255, 0, 0), 2)
    ax1.imshow(cv2.cvtColor(image_with_boxes, cv2.COLOR_BGR2RGB))
    ax1.set_title("Bounding Boxes")
    ax1.axis('off')
    # Plot labels
    image_with_labels = preprocessed_image.copy()
    for detection, label in zip(sorted_detections, sorted_labels):
        x1, y1, x2, y2 = detection.xyxy[0]
        cv2.putText(image_with_labels, label, (int(x1), int(y1) - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
    ax2.imshow(cv2.cvtColor(image_with_labels, cv2.COLOR_BGR2RGB))
    ax2.set_title("Labels")
    ax2.axis('off')
    st.pyplot(fig)
    # pyplot keeps every figure alive until it is closed; release it so
    # figures do not accumulate across reruns.
    plt.close(fig)

    st.write("Extracted Text from Prescription:", resulting_string)
    # Prepare prompt for LLM
    prompt = "\n".join([
        "Analyze the following prescription text:",
        f"{resulting_string}",
        "Please provide:",
        "1. Identified drug name(s)",
        "2. Full name of each identified drug",
        "3. Primary uses of each drug",
        "4. Common side effects",
        "5. Recommended dosage (if identifiable from the text)",
        "6. Any warnings or precautions",
        "7. Potential interactions with other medications (if multiple drugs are identified)",
        "8. Any additional relevant information for the patient",
        "If any part of the prescription is unclear or seems incomplete, please mention that and provide information about possible interpretations or matches. Always emphasize the importance of consulting a healthcare professional for accurate interpretation and advice.",
    ])
    # Get LLM response
    llm_response = get_ai71_response(prompt)
    st.subheader("AI Analysis of the Prescription")
    st.write(llm_response)


def _render_disease_detection():
    """Render the chest X-ray upload and classify it with ``disease_model``."""
    st.write("Upload a chest X-ray image for disease detection.")
    uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_image is None or disease_model is None:
        st.write("Please upload an image file or ensure the disease model is loaded.")
        return

    # Build the model input: RGB, resized to 150x150, scaled to [0, 1] and
    # given a leading batch axis -> shape (1, 150, 150, 3).
    img_opened = Image.open(uploaded_image).convert('RGB')
    image_pred = cv2.resize(np.array(img_opened), (150, 150))
    image_pred = image_pred / 255.0
    image_pred = np.expand_dims(image_pred, axis=0)

    # Predict using the model and take the index of the highest score
    prediction = disease_model.predict(image_pred)
    predicted_ = np.argmax(prediction)
    # Index 0 and 1 have dedicated labels; every other index is reported as
    # "Pneumonia".
    class_names = ("Covid", "Normal Chest X-ray", "Pneumonia")
    predicted_class = class_names[min(int(predicted_), 2)]

    # The array shown here is the resized, rescaled model input.
    st.image(image_pred, caption='Input image by user', use_column_width=True)
    st.write("Prediction Classes for different types:")
    st.write("COVID: 0")
    st.write("Normal Chest X-ray: 1")
    st.write("Pneumonia: 2")
    st.write("\n")
    st.write("DETECTED DISEASE DISPLAY")
    st.write(f"Predicted Class : {predicted_}")
    st.write(predicted_class)
    # Analyze X-ray results with LLM
    analyze_xray_with_llm(predicted_class)


def _render_outbreak_alert():
    """Fetch the WHO events page and list up to five items as dated links."""
    # Local import: urljoin is not among this file's top-level imports.
    from urllib.parse import urljoin

    st.markdown("## **Disease Outbreak News (from WHO)**")
    # Fetch WHO news page
    url = "https://www.who.int/news-room/events"
    try:
        # Without a timeout a stalled connection would block this rerun
        # indefinitely.
        response = requests.get(url, timeout=10)
        response.raise_for_status()  # Raise an exception for bad status codes
    except requests.RequestException as exc:
        st.error(f"Could not fetch news from WHO: {exc}")
        return

    soup = BeautifulSoup(response.content, 'html.parser')
    # Find news articles (adjust selectors if WHO website changes)
    articles = soup.find_all('div', class_='list-view--item')
    if not articles:
        st.info("No news articles were found on the WHO page.")
        return

    for article in articles[:5]:  # Display the top 5 news articles
        title_element = article.find('a', class_='link-container')
        if title_element is None:
            st.write("Could not find article details.")
            continue

        title = title_element.text.strip()
        # Resolve the href against the page URL so relative links still point
        # at the WHO site; a missing href falls back to the page itself.
        link = urljoin(url, title_element.get('href', ''))
        date_element = article.find('span', class_='date')
        date = date_element.text.strip() if date_element else "Date not found"
        # Format date; keep the original text if it is not "day month year".
        try:
            formatted_date = datetime.strptime(date, "%d %B %Y").strftime("%Y-%m-%d")
        except ValueError:
            formatted_date = date
        # Display news item in a card-like container
        with st.container():
            st.markdown(f"**{formatted_date}**")
            st.markdown(f"[{title}]({link})")
            st.markdown("---")


# Route the sidebar selection to its renderer.
_PAGE_RENDERERS = {
    "Home": _render_home,
    "AI Chatbot Diagnosis": _render_chatbot_diagnosis,
    "Drug Identification": _render_drug_identification,
    "Disease Detection": _render_disease_detection,
    "Outbreak Alert": _render_outbreak_alert,
}
_render_page = _PAGE_RENDERERS.get(page)
if _render_page is not None:
    _render_page()
# Auto-scroll to the bottom of the chat container
# NOTE(review): this injects a <script> element through st.markdown; confirm
# that the deployed Streamlit version actually executes it and that an element
# with class 'st-chat-container' exists on the page.
_auto_scroll_html = "\n".join([
    "",
    "<script>",
    "const chatContainer = document.querySelector('.st-chat-container');",
    "if (chatContainer) {",
    "chatContainer.scrollTop = chatContainer.scrollHeight;",
    "}",
    "</script>",
    "",
])
st.markdown(_auto_scroll_html, unsafe_allow_html=True)