Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,14 +1,14 @@
|
|
| 1 |
# Install necessary libraries
|
| 2 |
-
#
|
| 3 |
|
| 4 |
import gradio as gr
|
| 5 |
from transformers import pipeline
|
| 6 |
import pandas as pd
|
|
|
|
| 7 |
import pdfplumber
|
| 8 |
import torch
|
| 9 |
-
from torchvision import transforms
|
| 10 |
-
from PIL import Image
|
| 11 |
import timm
|
|
|
|
| 12 |
|
| 13 |
# Load pre-trained model for zero-shot classification
|
| 14 |
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
|
|
@@ -20,11 +20,44 @@ image_model.eval()
|
|
| 20 |
# Initialize patient database
|
| 21 |
patients_db = []
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
# Function to register patients
|
| 24 |
def register_patient(name, age, gender):
|
| 25 |
-
if not name or age <= 0 or gender not in ["Male", "Female", "Other"]:
|
| 26 |
-
return "❌ Invalid input. Please check the details and try again."
|
| 27 |
-
|
| 28 |
patient_id = len(patients_db) + 1
|
| 29 |
patients_db.append({
|
| 30 |
"ID": patient_id,
|
|
@@ -35,74 +68,68 @@ def register_patient(name, age, gender):
|
|
| 35 |
"Diagnosis": "",
|
| 36 |
"Action Plan": "",
|
| 37 |
"Medications": "",
|
|
|
|
| 38 |
"Tests": ""
|
| 39 |
})
|
| 40 |
return f"✅ Patient {name} registered successfully. Patient ID: {patient_id}"
|
| 41 |
|
| 42 |
# Function to analyze text reports
|
| 43 |
def analyze_report(patient_id, report_text):
|
| 44 |
-
|
| 45 |
-
return "❌ Report text cannot be empty."
|
| 46 |
-
|
| 47 |
-
candidate_labels = ["anemia", "viral infection", "liver disease", "kidney disease", "diabetes"]
|
| 48 |
result = classifier(report_text, candidate_labels)
|
| 49 |
diagnosis = result['labels'][0]
|
| 50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
# Store diagnosis in the database
|
| 53 |
for patient in patients_db:
|
| 54 |
if patient["ID"] == patient_id:
|
| 55 |
patient["Diagnosis"] = diagnosis
|
| 56 |
patient["Action Plan"] = action_plan
|
|
|
|
|
|
|
| 57 |
break
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
|
|
|
| 62 |
|
| 63 |
# Function to extract text from PDF reports
|
| 64 |
def extract_pdf_report(pdf):
|
| 65 |
text = ""
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
text += page.extract_text()
|
| 70 |
-
except Exception as e:
|
| 71 |
-
return f"❌ Error extracting text from PDF: {str(e)}"
|
| 72 |
-
|
| 73 |
return text
|
| 74 |
|
| 75 |
# Function to analyze uploaded images (X-ray/CT-scan)
|
| 76 |
def analyze_image(patient_id, img):
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
else:
|
| 100 |
-
return "❌ Patient ID not found."
|
| 101 |
-
|
| 102 |
-
return f"🔍 Diagnosis from image: {diagnosis}"
|
| 103 |
-
|
| 104 |
-
except Exception as e:
|
| 105 |
-
return f"❌ Error analyzing image: {str(e)}"
|
| 106 |
|
| 107 |
# Function to display the dashboard
|
| 108 |
def show_dashboard():
|
|
@@ -170,9 +197,9 @@ with gr.Blocks() as demo:
|
|
| 170 |
with gr.TabItem("Analyze Report (PDF)"):
|
| 171 |
pdf_report_interface.render()
|
| 172 |
with gr.TabItem("Analyze Image (X-ray/CT)"):
|
| 173 |
-
|
| 174 |
image_interface.render()
|
| 175 |
with gr.TabItem("Dashboard"):
|
| 176 |
dashboard_interface.render()
|
| 177 |
|
| 178 |
demo.launch(share=True)
|
|
|
|
|
|
| 1 |
# Install necessary libraries
|
| 2 |
+
#!pip install gradio transformers pandas PyPDF2 pdfplumber torch torchvision timm sentencepiece
|
| 3 |
|
| 4 |
import gradio as gr
|
| 5 |
from transformers import pipeline
|
| 6 |
import pandas as pd
|
| 7 |
+
import PyPDF2
|
| 8 |
import pdfplumber
|
| 9 |
import torch
|
|
|
|
|
|
|
| 10 |
import timm
|
| 11 |
+
from PIL import Image
|
| 12 |
|
| 13 |
# Load pre-trained model for zero-shot classification
|
| 14 |
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
|
|
|
|
| 20 |
# Initialize patient database
|
| 21 |
patients_db = []
|
| 22 |
|
| 23 |
+
# Disease and medication mapping
|
| 24 |
+
disease_details = {
|
| 25 |
+
"anemia": {
|
| 26 |
+
"medication": "Iron supplements (e.g., Ferrous sulfate)",
|
| 27 |
+
"precaution": "Increase intake of iron-rich foods like spinach and red meat."
|
| 28 |
+
},
|
| 29 |
+
"viral infection": {
|
| 30 |
+
"medication": "Antiviral drugs (e.g., Oseltamivir for flu)",
|
| 31 |
+
"precaution": "Rest, stay hydrated, and avoid close contact with others."
|
| 32 |
+
},
|
| 33 |
+
"liver disease": {
|
| 34 |
+
"medication": "Hepatoprotective drugs (e.g., Ursodeoxycholic acid)",
|
| 35 |
+
"precaution": "Avoid alcohol and maintain a balanced diet."
|
| 36 |
+
},
|
| 37 |
+
"kidney disease": {
|
| 38 |
+
"medication": "Angiotensin-converting enzyme inhibitors (e.g., Lisinopril)",
|
| 39 |
+
"precaution": "Monitor salt intake and stay hydrated."
|
| 40 |
+
},
|
| 41 |
+
"diabetes": {
|
| 42 |
+
"medication": "Metformin or insulin therapy",
|
| 43 |
+
"precaution": "Follow a low-sugar diet and exercise regularly."
|
| 44 |
+
},
|
| 45 |
+
"hypertension": {
|
| 46 |
+
"medication": "Antihypertensive drugs (e.g., Amlodipine)",
|
| 47 |
+
"precaution": "Reduce salt intake and manage stress."
|
| 48 |
+
},
|
| 49 |
+
"COVID-19": {
|
| 50 |
+
"medication": "Supportive care, antiviral drugs (e.g., Remdesivir in severe cases)",
|
| 51 |
+
"precaution": "Follow isolation protocols, wear a mask, and stay hydrated."
|
| 52 |
+
},
|
| 53 |
+
"pneumonia": {
|
| 54 |
+
"medication": "Antibiotics (e.g., Amoxicillin) if bacterial",
|
| 55 |
+
"precaution": "Rest, avoid smoking, and stay hydrated."
|
| 56 |
+
}
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
# Function to register patients
|
| 60 |
def register_patient(name, age, gender):
|
|
|
|
|
|
|
|
|
|
| 61 |
patient_id = len(patients_db) + 1
|
| 62 |
patients_db.append({
|
| 63 |
"ID": patient_id,
|
|
|
|
| 68 |
"Diagnosis": "",
|
| 69 |
"Action Plan": "",
|
| 70 |
"Medications": "",
|
| 71 |
+
"Precautions": "",
|
| 72 |
"Tests": ""
|
| 73 |
})
|
| 74 |
return f"✅ Patient {name} registered successfully. Patient ID: {patient_id}"
|
| 75 |
|
| 76 |
# Function to analyze text reports
|
| 77 |
def analyze_report(patient_id, report_text):
|
| 78 |
+
candidate_labels = list(disease_details.keys())
|
|
|
|
|
|
|
|
|
|
| 79 |
result = classifier(report_text, candidate_labels)
|
| 80 |
diagnosis = result['labels'][0]
|
| 81 |
+
|
| 82 |
+
# Fetch medication and precaution
|
| 83 |
+
medication = disease_details[diagnosis]["medication"]
|
| 84 |
+
precaution = disease_details[diagnosis]["precaution"]
|
| 85 |
+
action_plan = f"You might have {diagnosis}. Please consult a doctor for confirmation."
|
| 86 |
|
| 87 |
# Store diagnosis in the database
|
| 88 |
for patient in patients_db:
|
| 89 |
if patient["ID"] == patient_id:
|
| 90 |
patient["Diagnosis"] = diagnosis
|
| 91 |
patient["Action Plan"] = action_plan
|
| 92 |
+
patient["Medications"] = medication
|
| 93 |
+
patient["Precautions"] = precaution
|
| 94 |
break
|
| 95 |
+
|
| 96 |
+
return (f"🔍 Diagnosis: {diagnosis}\n"
|
| 97 |
+
f"🩺 Medications: {medication}\n"
|
| 98 |
+
f"⚠️ Precautions: {precaution}\n"
|
| 99 |
+
f"💡 {action_plan}")
|
| 100 |
|
| 101 |
# Function to extract text from PDF reports
|
| 102 |
def extract_pdf_report(pdf):
|
| 103 |
text = ""
|
| 104 |
+
with pdfplumber.open(pdf.name) as pdf_file:
|
| 105 |
+
for page in pdf_file.pages:
|
| 106 |
+
text += page.extract_text()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
return text
|
| 108 |
|
| 109 |
# Function to analyze uploaded images (X-ray/CT-scan)
|
| 110 |
def analyze_image(patient_id, img):
|
| 111 |
+
image = Image.open(img).convert('RGB')
|
| 112 |
+
transform = torch.nn.Sequential(
|
| 113 |
+
torch.nn.Upsample(size=(224, 224)),
|
| 114 |
+
torch.nn.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 115 |
+
)
|
| 116 |
+
image_tensor = transform(torch.unsqueeze(torch.tensor(image), 0))
|
| 117 |
+
|
| 118 |
+
# Run the image through the model (for simplicity, assuming ResNet50 output)
|
| 119 |
+
output = image_model(image_tensor)
|
| 120 |
+
_, predicted = torch.max(output, 1)
|
| 121 |
+
|
| 122 |
+
# Map prediction to a label
|
| 123 |
+
labels = {0: "Normal", 1: "Pneumonia", 2: "Liver Disorder", 3: "COVID-19"}
|
| 124 |
+
diagnosis = labels.get(predicted.item(), "Unknown")
|
| 125 |
+
|
| 126 |
+
# Store diagnosis in the database
|
| 127 |
+
for patient in patients_db:
|
| 128 |
+
if patient["ID"] == patient_id:
|
| 129 |
+
patient["Diagnosis"] = diagnosis
|
| 130 |
+
break
|
| 131 |
+
|
| 132 |
+
return f"🔍 Diagnosis from image: {diagnosis}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
|
| 134 |
# Function to display the dashboard
|
| 135 |
def show_dashboard():
|
|
|
|
| 197 |
with gr.TabItem("Analyze Report (PDF)"):
|
| 198 |
pdf_report_interface.render()
|
| 199 |
with gr.TabItem("Analyze Image (X-ray/CT)"):
|
|
|
|
| 200 |
image_interface.render()
|
| 201 |
with gr.TabItem("Dashboard"):
|
| 202 |
dashboard_interface.render()
|
| 203 |
|
| 204 |
demo.launch(share=True)
|
| 205 |
+
|