VaneshDev commited on
Commit
86c2782
·
verified ·
1 Parent(s): 7646d30

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +115 -39
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
2
  from PIL import Image
3
  import torch
4
  from torchvision import models, transforms
 
5
  import logging
6
  import os
7
 
@@ -9,37 +10,68 @@ import os
9
  logging.basicConfig(level=logging.INFO)
10
  logger = logging.getLogger(__name__)
11
 
12
- # ChestX-ray14 condition labels
13
  conditions = [
14
- "Atelectasis", "Cardiomegaly", "Effusion", "Infiltration", "Mass", "Nodule",
15
- "Pneumonia", "Pneumothorax", "Consolidation", "Edema", "Emphysema", "Fibrosis",
16
- "Pleural Thickening", "Hernia"
 
 
 
17
  ]
18
 
19
- # Load DenseNet121 base model
20
- model = models.densenet121(pretrained=False)
21
- model.classifier = torch.nn.Linear(model.classifier.in_features, len(conditions))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
- model = model.to(device)
 
 
 
 
 
25
  model.eval()
 
 
26
 
27
- # Load CheXNet pre-trained weights
28
- model_path = "xray_model.pth"
29
  if os.path.exists(model_path):
30
  try:
31
- checkpoint = torch.load(model_path, map_location=device)
32
- if "state_dict" in checkpoint:
33
- model.load_state_dict(checkpoint["state_dict"])
34
- else:
35
- model.load_state_dict(checkpoint)
36
- logger.info("✅ Loaded CheXNet model weights.")
37
  except Exception as e:
38
- logger.error(f"Failed to load model weights: {e}")
39
  else:
40
- logger.warning("⚠️ Model file 'xray_model.pth' not found!")
41
 
42
- # Preprocessing for image input
43
  def preprocess_image(image):
44
  transform = transforms.Compose([
45
  transforms.Resize((224, 224)),
@@ -48,7 +80,7 @@ def preprocess_image(image):
48
  ])
49
  return transform(image).unsqueeze(0).to(device)
50
 
51
- # X-ray prediction
52
  def predict_xray(image):
53
  try:
54
  if image is None:
@@ -58,34 +90,78 @@ def predict_xray(image):
58
  with torch.no_grad():
59
  output = model(img_tensor)
60
 
61
- probs = torch.sigmoid(output)[0] # sigmoid for multi-label prediction
62
- result_lines = []
63
 
64
- for i, condition in enumerate(conditions):
65
- confidence = probs[i].item() * 100
66
- if confidence >= 10: # only show confident predictions
67
- result_lines.append(f"<b>{condition}:</b> {confidence:.2f}%")
68
 
69
- if not result_lines:
70
- return "<b>Uncertain:</b> No strong signs of any known condition detected."
 
 
 
 
 
 
 
71
 
72
- return "<br>".join(result_lines)
 
 
 
 
 
 
 
 
73
 
74
  except Exception as e:
75
- logger.error(f"Prediction failed: {e}")
76
- return f"<b>Error:</b> {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
77
 
78
- # Gradio interface setup
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  def create_interface():
80
  with gr.Blocks() as demo:
81
- gr.Markdown("<h1 style='text-align:center;'>🩻 RadiologyScan AI (CheXNet)</h1>")
82
- with gr.Row():
83
- image_input = gr.Image(label="Upload Chest X-ray", type="pil")
84
- result_output = gr.HTML(label="Diagnosis Result")
85
- gr.Button("Analyze X-ray").click(predict_xray, inputs=image_input, outputs=result_output)
 
 
 
 
 
 
 
 
86
  return demo
87
 
88
- # Run the app
89
  if __name__ == "__main__":
90
  demo = create_interface()
91
  demo.launch(server_port=7860, ssr_mode=False)
 
2
  from PIL import Image
3
  import torch
4
  from torchvision import models, transforms
5
+ import fitz # PyMuPDF
6
  import logging
7
  import os
8
 
 
10
  logging.basicConfig(level=logging.INFO)
11
  logger = logging.getLogger(__name__)
12
 
13
+ # Condition list
14
  conditions = [
15
+ "Normal", "Pneumonia", "Cancer", "TB", "Other",
16
+ "Coronary Artery Disease", "Aortic Aneurysm", "Stroke", "Peripheral Artery Disease",
17
+ "Brain Tumor", "Alzheimer's Disease", "Multiple Sclerosis", "Epilepsy",
18
+ "COPD", "Lung Cancer", "Pulmonary Embolism",
19
+ "Fractures", "Arthritis", "Osteoporosis",
20
+ "Appendicitis", "Gallstones", "Kidney Stones", "Infections", "Abdominal Aortic Aneurysm", "Diverticulitis"
21
  ]
22
 
23
+ # Condition details
24
+ condition_details = {
25
+ "Normal": {"description": "No abnormal signs detected.", "recommendation": "Routine check-ups recommended."},
26
+ "Pneumonia": {"description": "Lung inflammation detected, possibly infectious.", "recommendation": "Seek medical attention for treatment."},
27
+ "Cancer": {"description": "Suspicious masses suggest cancer; further imaging needed.", "recommendation": "Consult an oncologist."},
28
+ "TB": {"description": "Cavitary lesions indicate tuberculosis.", "recommendation": "Immediate medical evaluation required."},
29
+ "Other": {"description": "Unclear abnormality; further investigation needed.", "recommendation": "Consult a radiologist."},
30
+ "Coronary Artery Disease": {"description": "Narrowing of coronary arteries detected.", "recommendation": "Cardiology consultation required."},
31
+ "Aortic Aneurysm": {"description": "Abnormal enlargement of the aorta.", "recommendation": "Vascular surgery evaluation."},
32
+ "Stroke": {"description": "Signs of brain ischemia or hemorrhage.", "recommendation": "Urgent neurological evaluation."},
33
+ "Peripheral Artery Disease": {"description": "Reduced blood flow in peripheral arteries.", "recommendation": "Vascular specialist consultation."},
34
+ "Brain Tumor": {"description": "Abnormal mass in the brain detected.", "recommendation": "Consult a neurosurgeon."},
35
+ "Alzheimer's Disease": {"description": "Signs of neurodegenerative changes.", "recommendation": "Neurology consultation."},
36
+ "Multiple Sclerosis": {"description": "Demyelinating lesions in the CNS.", "recommendation": "Neurology consultation."},
37
+ "Epilepsy": {"description": "Signs of seizure activity.", "recommendation": "Neurology consultation."},
38
+ "COPD": {"description": "Lung damage from COPD observed.", "recommendation": "Pulmonary consultation."},
39
+ "Lung Cancer": {"description": "Malignant lung masses detected.", "recommendation": "Oncology consultation."},
40
+ "Pulmonary Embolism": {"description": "Blockage in pulmonary arteries.", "recommendation": "Urgent medical attention."},
41
+ "Fractures": {"description": "Bone break detected.", "recommendation": "Orthopedic evaluation."},
42
+ "Arthritis": {"description": "Joint inflammation detected.", "recommendation": "Rheumatology consultation."},
43
+ "Osteoporosis": {"description": "Reduced bone density observed.", "recommendation": "Bone health specialist consultation."},
44
+ "Appendicitis": {"description": "Inflammation of the appendix.", "recommendation": "Surgical evaluation."},
45
+ "Gallstones": {"description": "Stones in the gallbladder detected.", "recommendation": "Gastroenterology consultation."},
46
+ "Kidney Stones": {"description": "Stones in the kidneys detected.", "recommendation": "Urology consultation."},
47
+ "Infections": {"description": "Signs of infection observed.", "recommendation": "Infectious disease consultation."},
48
+ "Abdominal Aortic Aneurysm": {"description": "Enlargement of the abdominal aorta.", "recommendation": "Vascular surgery evaluation."},
49
+ "Diverticulitis": {"description": "Inflammation of diverticula in the colon.", "recommendation": "Gastroenterology consultation."}
50
+ }
51
 
52
+ # Load model
53
+ try:
54
+ model = models.densenet121(weights="IMAGENET1K_V1")
55
+ except AttributeError:
56
+ model = models.densenet121(pretrained=True)
57
+
58
+ model.classifier = torch.nn.Linear(model.classifier.in_features, len(conditions))
59
  model.eval()
60
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
61
+ model.to(device)
62
 
63
+ # Load model weights if available
64
+ model_path = os.getenv("MODEL_PATH", "xray_model.pth")
65
  if os.path.exists(model_path):
66
  try:
67
+ model.load_state_dict(torch.load(model_path, map_location=device))
68
+ logger.info("Model loaded from file.")
 
 
 
 
69
  except Exception as e:
70
+ logger.warning(f"Failed to load model weights: {e}")
71
  else:
72
+ logger.info("No custom model weights found.")
73
 
74
+ # Image preprocessing
75
  def preprocess_image(image):
76
  transform = transforms.Compose([
77
  transforms.Resize((224, 224)),
 
80
  ])
81
  return transform(image).unsqueeze(0).to(device)
82
 
83
+ # X-ray prediction function with confidence threshold
84
  def predict_xray(image):
85
  try:
86
  if image is None:
 
90
  with torch.no_grad():
91
  output = model(img_tensor)
92
 
93
+ probs = torch.nn.functional.softmax(output, dim=1)[0]
94
+ results = {conditions[i]: probs[i].item() * 100 for i in range(len(conditions))}
95
 
96
+ top_condition = max(results, key=results.get)
97
+ confidence = results[top_condition]
 
 
98
 
99
+ if confidence < 50:
100
+ return f"""
101
+ <div style="font-family:Arial">
102
+ <h3>Prediction: <span style="color:#D62828;">Uncertain</span></h3>
103
+ <p><b>Confidence:</b> {confidence:.2f}%</p>
104
+ <p><b>Note:</b> The model is not confident enough to provide a clear diagnosis.</p>
105
+ <p><b>Recommendation:</b> Please consult a radiologist or upload a better-quality image.</p>
106
+ </div>
107
+ """
108
 
109
+ info = condition_details.get(top_condition, condition_details["Other"])
110
+ return f"""
111
+ <div style="font-family:Arial">
112
+ <h3>Prediction: <span style="color:#2A9D8F;">{top_condition}</span></h3>
113
+ <p><b>Confidence:</b> {confidence:.2f}%</p>
114
+ <p><b>Description:</b> {info['description']}</p>
115
+ <p><b>Recommendation:</b> {info['recommendation']}</p>
116
+ </div>
117
+ """
118
 
119
  except Exception as e:
120
+ logger.error(f"Error in prediction: {e}")
121
+ return f"Error: {str(e)}"
122
+
123
+ # Analyze PDF report
124
+ def analyze_report(file):
125
+ if not file or not file.name.endswith(".pdf"):
126
+ return "Please upload a valid PDF file."
127
+ try:
128
+ doc = fitz.open(file.name)
129
+ text = "".join(page.get_text() for page in doc)
130
+ doc.close()
131
+
132
+ condition, disease, status = "Unclear", "Unknown", "Pending"
133
 
134
+ if "stroke" in text.lower():
135
+ condition, disease, status = "Stroke", "Brain Disorder", "Urgent Care Needed"
136
+ elif "cancer" in text.lower():
137
+ condition, disease, status = "Cancer", "Malignant Growth", "Consult Oncologist"
138
+ elif "fracture" in text.lower():
139
+ condition, disease, status = "Fracture", "Bone Injury", "Orthopedic Attention Required"
140
+
141
+ preview = text[:300] + "..." if text else "No readable content."
142
+ return f"Condition: {condition}\nDisease: {disease}\nStatus: {status}\n\nPreview:\n{preview}"
143
+
144
+ except Exception as e:
145
+ return f"Failed to process PDF: {str(e)}"
146
+
147
+ # Gradio interface
148
  def create_interface():
149
  with gr.Blocks() as demo:
150
+ gr.Markdown("<h1 style='text-align:center;'>🩻 RadiologyScan AI</h1><p style='text-align:center;'>AI-powered X-ray and Report Analysis</p>")
151
+
152
+ with gr.Tabs():
153
+ with gr.TabItem("X-ray Analysis"):
154
+ img_input = gr.Image(label="Upload Chest X-ray", type="pil")
155
+ img_output = gr.HTML()
156
+ gr.Button("Analyze X-ray").click(predict_xray, inputs=img_input, outputs=img_output)
157
+
158
+ with gr.TabItem("Report Analysis"):
159
+ pdf_input = gr.File(label="Upload PDF Report", file_types=[".pdf"])
160
+ pdf_output = gr.Textbox(label="Extracted Summary", lines=10)
161
+ gr.Button("Analyze Report").click(analyze_report, inputs=pdf_input, outputs=pdf_output)
162
+
163
  return demo
164
 
 
165
  if __name__ == "__main__":
166
  demo = create_interface()
167
  demo.launch(server_port=7860, ssr_mode=False)