# RadiologyScanAI / app.py
# Last commit 897c5ee by VaneshDev: "Update app.py" (verified). File size: 10.4 kB.
import html
import logging
import os

import fitz  # PyMuPDF for better PDF parsing
import gradio as gr
import torch
from PIL import Image
from torchvision import models, transforms
# Set up logging. The level defaults to DEBUG for detailed output; set the
# LOG_LEVEL environment variable (e.g. LOG_LEVEL=INFO) to reduce verbosity.
_requested_log_level = os.getenv("LOG_LEVEL", "DEBUG").strip().upper()
# getLevelName returns an int only for registered level names; an unknown name
# would make basicConfig raise ValueError at startup, so fall back to DEBUG.
_log_level = (
    _requested_log_level
    if isinstance(logging.getLevelName(_requested_log_level), int)
    else "DEBUG"
)
logging.basicConfig(level=_log_level)
logger = logging.getLogger(__name__)
# Class labels for the classifier head, one per output unit. Order matters:
# the model is built with len(conditions) outputs and output index i is
# reported under conditions[i].
conditions = [
    "Normal",
    "Pneumonia",
    "Cancer",
    "TB",
    "Other",
    "Coronary Artery Disease",
    "Aortic Aneurysm",
    "Stroke",
    "Peripheral Artery Disease",
    "Brain Tumor",
    "Alzheimer's Disease",
    "Multiple Sclerosis",
    "Epilepsy",
    "COPD",
    "Lung Cancer",
    "Pulmonary Embolism",
    "Fractures",
    "Arthritis",
    "Osteoporosis",
    "Appendicitis",
    "Gallstones",
    "Kidney Stones",
    "Infections",
    "Abdominal Aortic Aneurysm",
    "Diverticulitis",
]
# Display text for every condition label: a short description and a
# recommendation, each stored under its own key in a per-condition dict.
# The table below is (label, description, recommendation), in label order.
condition_details = {
    label: {"description": description, "recommendation": recommendation}
    for label, description, recommendation in (
        ("Normal", "No abnormal signs detected.", "Routine check-ups recommended."),
        ("Pneumonia", "Lung inflammation detected, possibly infectious.", "Seek medical attention for treatment."),
        ("Cancer", "Suspicious masses suggest cancer; further imaging needed.", "Consult an oncologist."),
        ("TB", "Cavitary lesions indicate tuberculosis.", "Immediate medical evaluation required."),
        ("Other", "Unclear abnormality; further investigation needed.", "Consult a radiologist."),
        ("Coronary Artery Disease", "Narrowing of coronary arteries detected.", "Cardiology consultation required."),
        ("Aortic Aneurysm", "Abnormal enlargement of the aorta.", "Vascular surgery evaluation."),
        ("Stroke", "Signs of brain ischemia or hemorrhage.", "Urgent neurological evaluation."),
        ("Peripheral Artery Disease", "Reduced blood flow in peripheral arteries.", "Vascular specialist consultation."),
        ("Brain Tumor", "Abnormal mass in the brain detected.", "Consult a neurosurgeon."),
        ("Alzheimer's Disease", "Signs of neurodegenerative changes.", "Neurology consultation."),
        ("Multiple Sclerosis", "Demyelinating lesions in the CNS.", "Neurology consultation."),
        ("Epilepsy", "Signs of seizure activity.", "Neurology consultation."),
        ("COPD", "Lung damage from COPD observed.", "Pulmonary consultation."),
        ("Lung Cancer", "Malignant lung masses detected.", "Oncology consultation."),
        ("Pulmonary Embolism", "Blockage in pulmonary arteries.", "Urgent medical attention."),
        ("Fractures", "Bone break detected.", "Orthopedic evaluation."),
        ("Arthritis", "Joint inflammation detected.", "Rheumatology consultation."),
        ("Osteoporosis", "Reduced bone density observed.", "Bone health specialist consultation."),
        ("Appendicitis", "Inflammation of the appendix.", "Surgical evaluation."),
        ("Gallstones", "Stones in the gallbladder detected.", "Gastroenterology consultation."),
        ("Kidney Stones", "Stones in the kidneys detected.", "Urology consultation."),
        ("Infections", "Signs of infection observed.", "Infectious disease consultation."),
        ("Abdominal Aortic Aneurysm", "Enlargement of the abdominal aorta.", "Vascular surgery evaluation."),
        ("Diverticulitis", "Inflammation of diverticula in the colon.", "Gastroenterology consultation."),
    )
}
# Load and configure the model: an ImageNet-pretrained DenseNet-121 whose
# classifier is replaced by a new linear layer with one output per entry in
# `conditions`. That new layer is freshly initialised, so predictions are only
# meaningful once fine-tuned weights are loaded from MODEL_PATH.
def _build_densenet121():
    """Return an ImageNet-pretrained DenseNet-121 on both new and old torchvision."""
    try:
        # torchvision >= 0.13 selects pretrained weights via `weights=`.
        return models.densenet121(weights="IMAGENET1K_V1")
    except (TypeError, AttributeError):
        # Older torchvision has no `weights` parameter: the keyword is forwarded
        # to the DenseNet constructor, which raises TypeError. Fall back to the
        # legacy `pretrained=True` flag.
        return models.densenet121(pretrained=True)


model = _build_densenet121()
num_features = model.classifier.in_features
model.classifier = torch.nn.Linear(num_features, len(conditions))
model.eval()
# Pick the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = model.to(device)
logger.info(f"Using device: {device}")

# Optionally restore fine-tuned weights; the checkpoint path can be overridden
# with the MODEL_PATH environment variable.
model_path = os.getenv("MODEL_PATH", "xray_model.pth")
if not os.path.exists(model_path):
    logger.info("No pre-trained model found. Initializing with random weights. Training required.")
else:
    # NOTE(review): torch.load unpickles the checkpoint. On torch versions where
    # `weights_only` defaults to False this can execute arbitrary code, so only
    # point MODEL_PATH at trusted files.
    try:
        model.load_state_dict(torch.load(model_path, map_location=device))
        logger.info(f"Loaded model from {model_path}")
    except Exception as e:
        logger.warning(f"Failed to load model from {model_path}: {str(e)}. Using random weights.")
# Preprocessing pipeline shared by every request, built once at import time:
# resize to 224x224, convert to a float tensor, and normalise each of the three
# channels with the ImageNet mean/std used by the pretrained backbone.
_XRAY_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def preprocess_image(image):
    """Turn a PIL image into a normalised 1x3x224x224 tensor on `device`.

    Args:
        image: A ``PIL.Image.Image`` in any mode. Non-RGB images (e.g.
            grayscale "L" or "RGBA") are converted to RGB first, because the
            normalisation step expects exactly three channels.

    Returns:
        A float tensor with a leading batch dimension of 1, moved to ``device``.

    Raises:
        ValueError: If ``image`` is not a PIL image.
    """
    if not isinstance(image, Image.Image):
        logger.error("Invalid image format. Expected PIL Image.")
        raise ValueError("Uploaded file is not a valid image.")
    if image.mode != "RGB":
        # A 1-channel (grayscale) or 4-channel (RGBA) tensor would fail the
        # 3-channel Normalize with a broadcast-shape error.
        image = image.convert("RGB")
    image_tensor = _XRAY_TRANSFORM(image).unsqueeze(0).to(device)
    logger.debug(f"Preprocessed image tensor shape: {image_tensor.shape}")
    return image_tensor
def predict_xray(image):
    """Classify an uploaded X-ray and format the result for the Gradio UI.

    Args:
        image: The PIL image from the ``gr.Image(type="pil")`` input, or
            ``None`` when nothing has been uploaded.

    Returns:
        A ``(summary, detailed_results, additional_feedback)`` tuple of HTML
        fragments for the three ``gr.HTML`` outputs. On any failure the first
        element is an ``"Error: ..."`` message and the other two are empty.
    """
    if image is None:
        return "Error: No image uploaded.", "", ""
    try:
        image_tensor = preprocess_image(image)
        with torch.no_grad():
            outputs = model(image_tensor)
        # Softmax over the class dimension of the single-image batch; move the
        # probabilities to the CPU once rather than once per class.
        probs = torch.nn.functional.softmax(outputs, dim=1)[0].cpu().tolist()
        # Output index i is labelled conditions[i]; values are percentages.
        results = {condition: prob * 100 for condition, prob in zip(conditions, probs)}

        most_likely_condition = max(results, key=results.get)
        confidence = results[most_likely_condition]
        details = condition_details.get(most_likely_condition)
        if details is None:
            # Defensive fallback for a label that has no display text.
            most_likely_condition = "Other"
            confidence = 0.0
            details = condition_details["Other"]

        # The outputs are gr.HTML components, so emphasis must be HTML tags;
        # Markdown such as **bold** would be displayed literally.
        summary = (
            "<b>Summary</b>: Based on the X-ray analysis, the most likely diagnosis is: "
            f"<b>{most_likely_condition}</b> with a confidence of <b>{confidence:.2f}%</b>."
        )
        list_items = "".join(
            f"<li><b>{condition}:</b> {prob:.2f}%</li>" for condition, prob in results.items()
        )
        detailed_results = f"<ul class='result-list'>{list_items}</ul>"
        additional_feedback = (
            "<div class='feedback-box'>"
            f"<b>Description:</b> {details['description']}<br>"
            f"<b>Recommendation:</b> {details['recommendation']}"
            "</div>"
        )
        logger.info(f"Prediction: {most_likely_condition} with confidence {confidence:.2f}%")
        return summary, detailed_results, additional_feedback
    except Exception as e:
        # UI-callback boundary: log with traceback and report the message.
        # Escape it because it is rendered by a gr.HTML component.
        logger.exception(f"Error in predict_xray: {str(e)}")
        return f"Error: {html.escape(str(e))}", "", ""
# Keyword rules used by analyze_report, checked in this order; the first
# keyword found anywhere in the lower-cased report text wins. Each rule is
# (keyword, patient condition, disease, status).
_REPORT_KEYWORD_RULES = (
    ("stroke", "Stroke", "Brain Disorder", "Urgent Care Needed"),
    ("cancer", "Cancer", "Malignant Growth", "Consult Oncologist"),
    ("fracture", "Fracture", "Bone Injury", "Orthopedic Attention Required"),
)
# Number of characters of extracted text shown in the report preview.
_REPORT_PREVIEW_CHARS = 300


def _uploaded_file_path(file):
    """Return the filesystem path behind an upload, or None if there is none.

    Accepts a plain path (str / os.PathLike) as well as an object exposing the
    path as ``.name``.
    """
    if not file:
        return None
    if isinstance(file, (str, os.PathLike)):
        return os.fspath(file)
    return getattr(file, "name", None)


def analyze_report(file):
    """Extract the text of an uploaded PDF report and flag a few keywords.

    Args:
        file: The ``gr.File`` upload: a path (str / os.PathLike) or an object
            whose ``.name`` is the path; ``None`` when nothing was uploaded.

    Returns:
        A plain-text summary (patient condition, disease, status and a preview
        of the extracted text), or an explanatory message when the upload is
        missing, is not a ``.pdf`` file, has no readable text, or cannot be
        parsed.
    """
    path = _uploaded_file_path(file)
    # Case-insensitive so that e.g. "REPORT.PDF" is accepted.
    if not isinstance(path, str) or not path.lower().endswith(".pdf"):
        return "Please upload a valid PDF file."

    try:
        # The context manager closes the document even if extraction fails.
        with fitz.open(path) as document:
            text = "".join(page.get_text("text") for page in document)
    except Exception as e:
        logger.exception(f"Error reading PDF: {str(e)}")
        return f"Error processing PDF: {str(e)}"

    if not text.strip():
        return "No readable text found in the PDF."

    # NOTE: plain substring matching with no negation handling, so text such
    # as "no evidence of cancer" still matches "cancer".
    lowered = text.lower()
    patient_condition, disease, status = "Unclear", "Unknown", "Pending further tests"
    for keyword, rule_condition, rule_disease, rule_status in _REPORT_KEYWORD_RULES:
        if keyword in lowered:
            patient_condition, disease, status = rule_condition, rule_disease, rule_status
            break

    preview = text[:_REPORT_PREVIEW_CHARS]
    if len(text) > _REPORT_PREVIEW_CHARS:
        preview += "..."  # only mark the preview as truncated when it is
    return (
        f"Patient's Condition: {patient_condition}\n"
        f"Disease: {disease}\n"
        f"Condition Status: {status}\n\n"
        f"Report Preview: {preview}"
    )
def create_interface():
    """Assemble the two-tab Gradio Blocks UI: X-ray analysis and PDF report analysis.

    Returns:
        The configured ``gr.Blocks`` app; launching it is left to the caller.
    """
    logger.debug("Initializing Gradio interface")
    # Minimal CSS to avoid styling issues.
    page_css = (
        "\n"
        ".title { font-size: 30px; text-align: center; color: #4C6A92; }\n"
        ".subtitle { text-align: center; color: #666; }\n"
    )
    with gr.Blocks(css=page_css) as demo:
        gr.Markdown("<h1 class='title'>RadiologyScan AI</h1>")
        gr.Markdown("<p class='subtitle'>AI-powered analysis for X-rays and patient reports</p>")
        with gr.Tabs():
            with gr.TabItem("X-ray Analysis"):
                xray_image = gr.Image(label="Upload X-ray", type="pil")
                # Order matches predict_xray's (summary, details, feedback) tuple.
                xray_outputs = [
                    gr.HTML(label="Summary"),
                    gr.HTML(label="Detailed Results"),
                    gr.HTML(label="Additional Feedback"),
                ]
                analyze_xray_button = gr.Button("Analyze X-ray")
                analyze_xray_button.click(
                    fn=predict_xray,
                    inputs=xray_image,
                    outputs=xray_outputs,
                )
            with gr.TabItem("Report Analysis"):
                report_file = gr.File(label="Upload Patient Report (PDF)", file_count="single")
                report_text = gr.Textbox(label="Report Summary", interactive=False)
                analyze_report_button = gr.Button("Analyze Report")
                analyze_report_button.click(
                    fn=analyze_report,
                    inputs=report_file,
                    outputs=report_text,
                )
    logger.debug("Gradio interface initialized")
    return demo
if __name__ == "__main__":
    logger.debug("Starting Gradio application")
    try:
        demo = create_interface()
        # GRADIO_SERVER_PORT (the variable Gradio itself reads) overrides the
        # default port 7860.
        server_port = int(os.getenv("GRADIO_SERVER_PORT", "7860"))
        # launch() blocks the main thread while the server runs, so log before it.
        logger.debug(f"Launching Gradio application on port {server_port}")
        demo.launch(server_port=server_port, ssr_mode=False)  # disable SSR
        logger.debug("Gradio application stopped")
    except Exception as e:
        # Log with traceback, then re-raise so the process exits non-zero
        # instead of looking like a clean shutdown.
        logger.exception(f"Failed to launch Gradio application: {str(e)}")
        raise