VaneshDev commited on
Commit
7646d30
·
verified ·
1 Parent(s): 8d06626

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -115
app.py CHANGED
@@ -2,7 +2,6 @@ import gradio as gr
2
  from PIL import Image
3
  import torch
4
  from torchvision import models, transforms
5
- import fitz # PyMuPDF
6
  import logging
7
  import os
8
 
@@ -10,68 +9,37 @@ import os
10
  logging.basicConfig(level=logging.INFO)
11
  logger = logging.getLogger(__name__)
12
 
13
- # Condition list
14
  conditions = [
15
- "Normal", "Pneumonia", "Cancer", "TB", "Other",
16
- "Coronary Artery Disease", "Aortic Aneurysm", "Stroke", "Peripheral Artery Disease",
17
- "Brain Tumor", "Alzheimer's Disease", "Multiple Sclerosis", "Epilepsy",
18
- "COPD", "Lung Cancer", "Pulmonary Embolism",
19
- "Fractures", "Arthritis", "Osteoporosis",
20
- "Appendicitis", "Gallstones", "Kidney Stones", "Infections", "Abdominal Aortic Aneurysm", "Diverticulitis"
21
  ]
22
 
23
- # Condition details
24
- condition_details = {
25
- "Normal": {"description": "No abnormal signs detected.", "recommendation": "Routine check-ups recommended."},
26
- "Pneumonia": {"description": "Lung inflammation detected, possibly infectious.", "recommendation": "Seek medical attention for treatment."},
27
- "Cancer": {"description": "Suspicious masses suggest cancer; further imaging needed.", "recommendation": "Consult an oncologist."},
28
- "TB": {"description": "Cavitary lesions indicate tuberculosis.", "recommendation": "Immediate medical evaluation required."},
29
- "Other": {"description": "Unclear abnormality; further investigation needed.", "recommendation": "Consult a radiologist."},
30
- "Coronary Artery Disease": {"description": "Narrowing of coronary arteries detected.", "recommendation": "Cardiology consultation required."},
31
- "Aortic Aneurysm": {"description": "Abnormal enlargement of the aorta.", "recommendation": "Vascular surgery evaluation."},
32
- "Stroke": {"description": "Signs of brain ischemia or hemorrhage.", "recommendation": "Urgent neurological evaluation."},
33
- "Peripheral Artery Disease": {"description": "Reduced blood flow in peripheral arteries.", "recommendation": "Vascular specialist consultation."},
34
- "Brain Tumor": {"description": "Abnormal mass in the brain detected.", "recommendation": "Consult a neurosurgeon."},
35
- "Alzheimer's Disease": {"description": "Signs of neurodegenerative changes.", "recommendation": "Neurology consultation."},
36
- "Multiple Sclerosis": {"description": "Demyelinating lesions in the CNS.", "recommendation": "Neurology consultation."},
37
- "Epilepsy": {"description": "Signs of seizure activity.", "recommendation": "Neurology consultation."},
38
- "COPD": {"description": "Lung damage from COPD observed.", "recommendation": "Pulmonary consultation."},
39
- "Lung Cancer": {"description": "Malignant lung masses detected.", "recommendation": "Oncology consultation."},
40
- "Pulmonary Embolism": {"description": "Blockage in pulmonary arteries.", "recommendation": "Urgent medical attention."},
41
- "Fractures": {"description": "Bone break detected.", "recommendation": "Orthopedic evaluation."},
42
- "Arthritis": {"description": "Joint inflammation detected.", "recommendation": "Rheumatology consultation."},
43
- "Osteoporosis": {"description": "Reduced bone density observed.", "recommendation": "Bone health specialist consultation."},
44
- "Appendicitis": {"description": "Inflammation of the appendix.", "recommendation": "Surgical evaluation."},
45
- "Gallstones": {"description": "Stones in the gallbladder detected.", "recommendation": "Gastroenterology consultation."},
46
- "Kidney Stones": {"description": "Stones in the kidneys detected.", "recommendation": "Urology consultation."},
47
- "Infections": {"description": "Signs of infection observed.", "recommendation": "Infectious disease consultation."},
48
- "Abdominal Aortic Aneurysm": {"description": "Enlargement of the abdominal aorta.", "recommendation": "Vascular surgery evaluation."},
49
- "Diverticulitis": {"description": "Inflammation of diverticula in the colon.", "recommendation": "Gastroenterology consultation."}
50
- }
51
-
52
- # Load model
53
- try:
54
- model = models.densenet121(weights="IMAGENET1K_V1")
55
- except AttributeError:
56
- model = models.densenet121(pretrained=True)
57
-
58
  model.classifier = torch.nn.Linear(model.classifier.in_features, len(conditions))
59
- model.eval()
60
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
61
- model.to(device)
 
62
 
63
- # Load model weights if available
64
- model_path = os.getenv("MODEL_PATH", "xray_model.pth")
65
  if os.path.exists(model_path):
66
  try:
67
- model.load_state_dict(torch.load(model_path, map_location=device))
68
- logger.info("Model loaded from file.")
 
 
 
 
69
  except Exception as e:
70
- logger.warning(f"Failed to load model weights: {e}")
71
  else:
72
- logger.info("No custom model weights found.")
73
 
74
- # Image preprocessing
75
  def preprocess_image(image):
76
  transform = transforms.Compose([
77
  transforms.Resize((224, 224)),
@@ -80,7 +48,7 @@ def preprocess_image(image):
80
  ])
81
  return transform(image).unsqueeze(0).to(device)
82
 
83
- # X-ray prediction function with confidence threshold
84
  def predict_xray(image):
85
  try:
86
  if image is None:
@@ -90,78 +58,34 @@ def predict_xray(image):
90
  with torch.no_grad():
91
  output = model(img_tensor)
92
 
93
- probs = torch.nn.functional.softmax(output, dim=1)[0]
94
- results = {conditions[i]: probs[i].item() * 100 for i in range(len(conditions))}
95
 
96
- top_condition = max(results, key=results.get)
97
- confidence = results[top_condition]
 
 
98
 
99
- if confidence < 50:
100
- return f"""
101
- <div style="font-family:Arial">
102
- <h3>Prediction: <span style="color:#D62828;">Uncertain</span></h3>
103
- <p><b>Confidence:</b> {confidence:.2f}%</p>
104
- <p><b>Note:</b> The model is not confident enough to provide a clear diagnosis.</p>
105
- <p><b>Recommendation:</b> Please consult a radiologist or upload a better-quality image.</p>
106
- </div>
107
- """
108
 
109
- info = condition_details.get(top_condition, condition_details["Other"])
110
- return f"""
111
- <div style="font-family:Arial">
112
- <h3>Prediction: <span style="color:#2A9D8F;">{top_condition}</span></h3>
113
- <p><b>Confidence:</b> {confidence:.2f}%</p>
114
- <p><b>Description:</b> {info['description']}</p>
115
- <p><b>Recommendation:</b> {info['recommendation']}</p>
116
- </div>
117
- """
118
 
119
  except Exception as e:
120
- logger.error(f"Error in prediction: {e}")
121
- return f"Error: {str(e)}"
122
-
123
- # Analyze PDF report
124
- def analyze_report(file):
125
- if not file or not file.name.endswith(".pdf"):
126
- return "Please upload a valid PDF file."
127
- try:
128
- doc = fitz.open(file.name)
129
- text = "".join(page.get_text() for page in doc)
130
- doc.close()
131
-
132
- condition, disease, status = "Unclear", "Unknown", "Pending"
133
 
134
- if "stroke" in text.lower():
135
- condition, disease, status = "Stroke", "Brain Disorder", "Urgent Care Needed"
136
- elif "cancer" in text.lower():
137
- condition, disease, status = "Cancer", "Malignant Growth", "Consult Oncologist"
138
- elif "fracture" in text.lower():
139
- condition, disease, status = "Fracture", "Bone Injury", "Orthopedic Attention Required"
140
-
141
- preview = text[:300] + "..." if text else "No readable content."
142
- return f"Condition: {condition}\nDisease: {disease}\nStatus: {status}\n\nPreview:\n{preview}"
143
-
144
- except Exception as e:
145
- return f"Failed to process PDF: {str(e)}"
146
-
147
- # Gradio interface
148
  def create_interface():
149
  with gr.Blocks() as demo:
150
- gr.Markdown("<h1 style='text-align:center;'>🩻 RadiologyScan AI</h1><p style='text-align:center;'>AI-powered X-ray and Report Analysis</p>")
151
-
152
- with gr.Tabs():
153
- with gr.TabItem("X-ray Analysis"):
154
- img_input = gr.Image(label="Upload Chest X-ray", type="pil")
155
- img_output = gr.HTML()
156
- gr.Button("Analyze X-ray").click(predict_xray, inputs=img_input, outputs=img_output)
157
-
158
- with gr.TabItem("Report Analysis"):
159
- pdf_input = gr.File(label="Upload PDF Report", file_types=[".pdf"])
160
- pdf_output = gr.Textbox(label="Extracted Summary", lines=10)
161
- gr.Button("Analyze Report").click(analyze_report, inputs=pdf_input, outputs=pdf_output)
162
-
163
  return demo
164
 
 
165
  if __name__ == "__main__":
166
  demo = create_interface()
167
  demo.launch(server_port=7860, ssr_mode=False)
 
2
  from PIL import Image
3
  import torch
4
  from torchvision import models, transforms
 
5
  import logging
6
  import os
7
 
 
9
  logging.basicConfig(level=logging.INFO)
10
  logger = logging.getLogger(__name__)
11
 
12
+ # ChestX-ray14 condition labels
13
  conditions = [
14
+ "Atelectasis", "Cardiomegaly", "Effusion", "Infiltration", "Mass", "Nodule",
15
+ "Pneumonia", "Pneumothorax", "Consolidation", "Edema", "Emphysema", "Fibrosis",
16
+ "Pleural Thickening", "Hernia"
 
 
 
17
  ]
18
 
19
+ # Load DenseNet121 base model
20
+ model = models.densenet121(pretrained=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  model.classifier = torch.nn.Linear(model.classifier.in_features, len(conditions))
22
+
23
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
+ model = model.to(device)
25
+ model.eval()
26
 
27
+ # Load CheXNet pre-trained weights
28
+ model_path = "xray_model.pth"
29
  if os.path.exists(model_path):
30
  try:
31
+ checkpoint = torch.load(model_path, map_location=device)
32
+ if "state_dict" in checkpoint:
33
+ model.load_state_dict(checkpoint["state_dict"])
34
+ else:
35
+ model.load_state_dict(checkpoint)
36
+ logger.info("✅ Loaded CheXNet model weights.")
37
  except Exception as e:
38
+ logger.error(f"Failed to load model weights: {e}")
39
  else:
40
+ logger.warning("⚠️ Model file 'xray_model.pth' not found!")
41
 
42
+ # Preprocessing for image input
43
  def preprocess_image(image):
44
  transform = transforms.Compose([
45
  transforms.Resize((224, 224)),
 
48
  ])
49
  return transform(image).unsqueeze(0).to(device)
50
 
51
+ # X-ray prediction
52
  def predict_xray(image):
53
  try:
54
  if image is None:
 
58
  with torch.no_grad():
59
  output = model(img_tensor)
60
 
61
+ probs = torch.sigmoid(output)[0] # sigmoid for multi-label prediction
62
+ result_lines = []
63
 
64
+ for i, condition in enumerate(conditions):
65
+ confidence = probs[i].item() * 100
66
+ if confidence >= 10: # only show confident predictions
67
+ result_lines.append(f"<b>{condition}:</b> {confidence:.2f}%")
68
 
69
+ if not result_lines:
70
+ return "<b>Uncertain:</b> No strong signs of any known condition detected."
 
 
 
 
 
 
 
71
 
72
+ return "<br>".join(result_lines)
 
 
 
 
 
 
 
 
73
 
74
  except Exception as e:
75
+ logger.error(f"Prediction failed: {e}")
76
+ return f"<b>Error:</b> {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
77
 
78
+ # Gradio interface setup
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  def create_interface():
80
  with gr.Blocks() as demo:
81
+ gr.Markdown("<h1 style='text-align:center;'>🩻 RadiologyScan AI (CheXNet)</h1>")
82
+ with gr.Row():
83
+ image_input = gr.Image(label="Upload Chest X-ray", type="pil")
84
+ result_output = gr.HTML(label="Diagnosis Result")
85
+ gr.Button("Analyze X-ray").click(predict_xray, inputs=image_input, outputs=result_output)
 
 
 
 
 
 
 
 
86
  return demo
87
 
88
+ # Run the app
89
  if __name__ == "__main__":
90
  demo = create_interface()
91
  demo.launch(server_port=7860, ssr_mode=False)