VaneshDev commited on
Commit
897c5ee
·
verified ·
1 Parent(s): 085f757

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +141 -111
app.py CHANGED
@@ -1,16 +1,16 @@
1
  import gradio as gr
 
2
  import torch
3
  from torchvision import models, transforms
4
- import os
5
- import time
6
- import logging
7
  import fitz # PyMuPDF for better PDF parsing
 
 
8
 
9
- # Set up logging (optional for debugging)
10
- logging.basicConfig(level=logging.DEBUG)
11
  logger = logging.getLogger(__name__)
12
 
13
- # Define conditions based on provided radiology information
14
  conditions = [
15
  "Normal", "Pneumonia", "Cancer", "TB", "Other",
16
  "Coronary Artery Disease", "Aortic Aneurysm", "Stroke", "Peripheral Artery Disease",
@@ -20,38 +20,71 @@ conditions = [
20
  "Appendicitis", "Gallstones", "Kidney Stones", "Infections", "Abdominal Aortic Aneurysm", "Diverticulitis"
21
  ]
22
 
23
- # Define path to store the model manually
24
- model_path = "/home/user/.cache/torch/hub/checkpoints/densenet121-a639ec97.pth" # Adjusted to a valid path
25
-
26
- # Ensure the parent directory exists
27
- parent_dir = os.path.dirname(model_path)
28
- if not os.path.exists(parent_dir):
29
- os.makedirs(parent_dir) # Create the parent directory if it doesn't exist
30
-
31
- # Function to load the model efficiently
32
- def load_model():
33
- if os.path.exists(model_path):
34
- model = models.densenet121()
35
- model.load_state_dict(torch.load(model_path)) # Load from cached path
36
- model.eval() # Set to evaluation mode
37
- logger.info("Loaded model from cache.")
38
- else:
39
- model = models.densenet121(weights="IMAGENET1K_V1") # If not cached, download model
40
- torch.save(model.state_dict(), model_path) # Cache the model locally
41
- logger.info("Downloaded and cached the model.")
42
- return model
43
-
44
- # Load the model at the beginning (this will take time but only happens once)
45
- model = load_model()
46
-
47
- # Define device for model inference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
49
  model = model.to(device)
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  # Define image preprocessing function
52
  def preprocess_image(image):
 
 
 
 
53
  transform = transforms.Compose([
54
- transforms.Resize((224, 224)), # Resize to fit the model input size
55
  transforms.ToTensor(),
56
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
57
  ])
@@ -59,7 +92,7 @@ def preprocess_image(image):
59
  logger.debug(f"Preprocessed image tensor shape: {image_tensor.shape}")
60
  return image_tensor
61
 
62
- # Define prediction function with detailed output and error handling
63
  def predict_xray(image):
64
  try:
65
  if image is None:
@@ -68,40 +101,21 @@ def predict_xray(image):
68
  image_tensor = preprocess_image(image)
69
  with torch.no_grad():
70
  outputs = model(image_tensor)
71
- probs = torch.nn.functional.softmax(outputs, dim=1)[0] # Softmax over all conditions
72
  results = {conditions[i]: float(probs[i].cpu().numpy()) * 100 for i in range(len(conditions))}
73
 
74
- # Handle case where predicted class might not be in our list of conditions
75
  most_likely_condition = max(results, key=results.get)
76
-
77
- # Ensure the predicted condition is in the valid conditions list
78
- if most_likely_condition not in conditions:
79
  most_likely_condition = "Other"
80
  confidence = 0.0
81
- else:
82
- confidence = results[most_likely_condition]
83
 
84
- # Create a detailed summary of results
85
  summary = f"**Summary**: Based on the X-ray analysis, the most likely diagnosis is: <b>{most_likely_condition}</b> with a confidence of <b>{confidence:.2f}%</b>."
86
-
87
- # Enhanced condition details for each disease/condition
88
- condition_details = {
89
- "Normal": {"description": "No abnormal signs detected.", "recommendation": "Routine check-ups recommended."},
90
- "Pneumonia": {"description": "Lung inflammation detected, possibly infectious.", "recommendation": "Seek medical attention for treatment."},
91
- "Cancer": {"description": "Suspicious masses suggest cancer; further imaging needed.", "recommendation": "Consult an oncologist."},
92
- "TB": {"description": "Cavitary lesions indicate tuberculosis.", "recommendation": "Immediate medical evaluation required."},
93
- "Other": {"description": "Unclear abnormality; further investigation needed.", "recommendation": "Consult a radiologist."},
94
- "Fractures": {"description": "Bone break detected.", "recommendation": "Orthopedic evaluation."},
95
- "COPD": {"description": "Lung damage from COPD observed.", "recommendation": "Pulmonary consultation."},
96
- # Add the rest of the conditions as needed...
97
- }
98
-
99
- # Display results in a clear format
100
  detailed_results = "<ul class='result-list'>"
101
  for condition, prob in results.items():
102
  detailed_results += f"<li><b>{condition}:</b> {prob:.2f}%</li>"
103
  detailed_results += "</ul>"
104
-
105
  additional_feedback = f"<div class='feedback-box'><b>Description:</b> {condition_details[most_likely_condition]['description']}<br><b>Recommendation:</b> {condition_details[most_likely_condition]['recommendation']}</div>"
106
 
107
  logger.info(f"Prediction: {most_likely_condition} with confidence {confidence:.2f}%")
@@ -111,66 +125,82 @@ def predict_xray(image):
111
  logger.error(f"Error in predict_xray: {str(e)}")
112
  return f"Error: {str(e)}", "", ""
113
 
114
- # Enhanced function to analyze patient reports (PDFs) using PyMuPDF (fitz)
115
  def analyze_report(file):
 
 
 
116
  text = ""
117
  patient_condition = "Unclear"
118
  disease = "Unknown"
119
  status = "Pending further tests"
120
 
121
- if file and file.name.endswith(".pdf"):
122
- try:
123
- # Open the PDF using PyMuPDF
124
- pdf_reader = fitz.open(file.name)
125
- for page in pdf_reader:
126
- text += page.get_text("text") # Extract text from each page
127
-
128
- # Example: Let's search for conditions in the text
129
- if "stroke" in text.lower():
130
- patient_condition = "Stroke"
131
- disease = "Brain Disorder"
132
- status = "Urgent Care Needed"
133
- elif "cancer" in text.lower():
134
- patient_condition = "Cancer"
135
- disease = "Malignant Growth"
136
- status = "Consult Oncologist"
137
- elif "fracture" in text.lower():
138
- patient_condition = "Fracture"
139
- disease = "Bone Injury"
140
- status = "Orthopedic Attention Required"
141
- # You can add more conditions here based on keyword matching
142
-
143
- report_summary = f"Patient's Condition: {patient_condition}\nDisease: {disease}\nCondition Status: {status}\n\nReport Preview: {text[:300]}..." if text else "No readable text found in the PDF."
144
-
145
- except Exception as e:
146
- logger.error(f"Error reading PDF: {str(e)}")
147
- report_summary = f"Error processing PDF: {str(e)}"
148
- else:
149
- report_summary = "Please upload a valid PDF file."
150
-
151
- return report_summary
152
-
153
- # Gradio Interface with enhanced UI using Tabs
154
- def create_interface():
155
- with gr.Blocks() as demo:
156
- custom_css = """
157
- .gradio-container { background-color: #f4f6f9; border-radius: 15px; box-shadow: 0 4px 15px rgba(0,0,0,0.1); padding: 30px; font-family: 'Segoe UI', sans-serif; }
158
- .title { font-size: 30px; text-align: center; color: #4C6A92; margin-bottom: 20px; }
159
- .gradio-button { background-color: #3B82F6; color: white; border-radius: 10px; padding: 15px 30px; font-size: 16px; transition: background-color 0.3s; }
160
- .gradio-button:hover { background-color: #2563EB; }
161
- .result-box { background-color: #ffffff; border-radius: 10px; padding: 20px; box-shadow: 0 4px 8px rgba(0,0,0,0.1); margin-top: 20px; max-width: 100%; }
162
- .result-list { padding-left: 20px; margin: 10px 0; }
163
- .result-summary { font-size: 18px; color: #2F4F4F; font-weight: 500; }
164
- .feedback-box { background-color: #F0FFF4; padding: 10px; border-left: 4px solid #38A169; border-radius: 5px; margin-top: 10px; }
165
- """
166
-
167
- gr.Markdown("<h1 class='title'>RadiologyScan AI</h1>")
168
- gr.Markdown("<p style='text-align: center; color: #666;'>AI-powered analysis for X-rays and patient reports</p>")
169
-
170
- # Correctly provide interface_list to TabbedInterface
171
- xray_tab = gr.Interface(fn=predict_xray, inputs=gr.Image(label="Upload X-ray", type="pil"), outputs=[gr.HTML(), gr.HTML(), gr.HTML()])
172
- report_tab = gr.Interface(fn=analyze_report, inputs=gr.File(label="Upload Patient Report (PDF)", file_count="single"), outputs=gr.Textbox(label="Report Summary", interactive=False))
173
 
174
- gr.TabbedInterface([xray_tab, report_tab], tab_names=["X-ray Analysis", "Report Analysis"]).launch(share=False)
 
 
175
 
176
- demo = create_interface()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from PIL import Image
3
  import torch
4
  from torchvision import models, transforms
 
 
 
5
  import fitz # PyMuPDF for better PDF parsing
6
+ import logging
7
+ import os
8
 
9
+ # Set up logging
10
+ logging.basicConfig(level=logging.DEBUG) # Set to DEBUG for detailed output
11
  logger = logging.getLogger(__name__)
12
 
13
+ # Define conditions
14
  conditions = [
15
  "Normal", "Pneumonia", "Cancer", "TB", "Other",
16
  "Coronary Artery Disease", "Aortic Aneurysm", "Stroke", "Peripheral Artery Disease",
 
20
  "Appendicitis", "Gallstones", "Kidney Stones", "Infections", "Abdominal Aortic Aneurysm", "Diverticulitis"
21
  ]
22
 
23
+ # Define condition details for all conditions
24
+ condition_details = {
25
+ "Normal": {"description": "No abnormal signs detected.", "recommendation": "Routine check-ups recommended."},
26
+ "Pneumonia": {"description": "Lung inflammation detected, possibly infectious.", "recommendation": "Seek medical attention for treatment."},
27
+ "Cancer": {"description": "Suspicious masses suggest cancer; further imaging needed.", "recommendation": "Consult an oncologist."},
28
+ "TB": {"description": "Cavitary lesions indicate tuberculosis.", "recommendation": "Immediate medical evaluation required."},
29
+ "Other": {"description": "Unclear abnormality; further investigation needed.", "recommendation": "Consult a radiologist."},
30
+ "Coronary Artery Disease": {"description": "Narrowing of coronary arteries detected.", "recommendation": "Cardiology consultation required."},
31
+ "Aortic Aneurysm": {"description": "Abnormal enlargement of the aorta.", "recommendation": "Vascular surgery evaluation."},
32
+ "Stroke": {"description": "Signs of brain ischemia or hemorrhage.", "recommendation": "Urgent neurological evaluation."},
33
+ "Peripheral Artery Disease": {"description": "Reduced blood flow in peripheral arteries.", "recommendation": "Vascular specialist consultation."},
34
+ "Brain Tumor": {"description": "Abnormal mass in the brain detected.", "recommendation": "Consult a neurosurgeon."},
35
+ "Alzheimer's Disease": {"description": "Signs of neurodegenerative changes.", "recommendation": "Neurology consultation."},
36
+ "Multiple Sclerosis": {"description": "Demyelinating lesions in the CNS.", "recommendation": "Neurology consultation."},
37
+ "Epilepsy": {"description": "Signs of seizure activity.", "recommendation": "Neurology consultation."},
38
+ "COPD": {"description": "Lung damage from COPD observed.", "recommendation": "Pulmonary consultation."},
39
+ "Lung Cancer": {"description": "Malignant lung masses detected.", "recommendation": "Oncology consultation."},
40
+ "Pulmonary Embolism": {"description": "Blockage in pulmonary arteries.", "recommendation": "Urgent medical attention."},
41
+ "Fractures": {"description": "Bone break detected.", "recommendation": "Orthopedic evaluation."},
42
+ "Arthritis": {"description": "Joint inflammation detected.", "recommendation": "Rheumatology consultation."},
43
+ "Osteoporosis": {"description": "Reduced bone density observed.", "recommendation": "Bone health specialist consultation."},
44
+ "Appendicitis": {"description": "Inflammation of the appendix.", "recommendation": "Surgical evaluation."},
45
+ "Gallstones": {"description": "Stones in the gallbladder detected.", "recommendation": "Gastroenterology consultation."},
46
+ "Kidney Stones": {"description": "Stones in the kidneys detected.", "recommendation": "Urology consultation."},
47
+ "Infections": {"description": "Signs of infection observed.", "recommendation": "Infectious disease consultation."},
48
+ "Abdominal Aortic Aneurysm": {"description": "Enlargement of the abdominal aorta.", "recommendation": "Vascular surgery evaluation."},
49
+ "Diverticulitis": {"description": "Inflammation of diverticula in the colon.", "recommendation": "Gastroenterology consultation."}
50
+ }
51
+
52
+ # Load and configure the model
53
+ try:
54
+ model = models.densenet121(weights="IMAGENET1K_V1")
55
+ num_features = model.classifier.in_features
56
+ model.classifier = torch.nn.Linear(num_features, len(conditions))
57
+ model.eval()
58
+ except AttributeError:
59
+ model = models.densenet121(pretrained=True)
60
+ num_features = model.classifier.in_features
61
+ model.classifier = torch.nn.Linear(num_features, len(conditions))
62
+ model.eval()
63
+
64
+ # Define device
65
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
66
  model = model.to(device)
67
+ logger.info(f"Using device: {device}")
68
+
69
+ # Load model state if available
70
+ model_path = os.getenv("MODEL_PATH", "xray_model.pth")
71
+ if os.path.exists(model_path):
72
+ try:
73
+ model.load_state_dict(torch.load(model_path, map_location=device))
74
+ logger.info(f"Loaded model from {model_path}")
75
+ except Exception as e:
76
+ logger.warning(f"Failed to load model from {model_path}: {str(e)}. Using random weights.")
77
+ else:
78
+ logger.info("No pre-trained model found. Initializing with random weights. Training required.")
79
 
80
  # Define image preprocessing function
81
  def preprocess_image(image):
82
+ if not isinstance(image, Image.Image):
83
+ logger.error("Invalid image format. Expected PIL Image.")
84
+ raise ValueError("Uploaded file is not a valid image.")
85
+
86
  transform = transforms.Compose([
87
+ transforms.Resize((224, 224)),
88
  transforms.ToTensor(),
89
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
90
  ])
 
92
  logger.debug(f"Preprocessed image tensor shape: {image_tensor.shape}")
93
  return image_tensor
94
 
95
+ # Define prediction function
96
  def predict_xray(image):
97
  try:
98
  if image is None:
 
101
  image_tensor = preprocess_image(image)
102
  with torch.no_grad():
103
  outputs = model(image_tensor)
104
+ probs = torch.nn.functional.softmax(outputs, dim=1)[0]
105
  results = {conditions[i]: float(probs[i].cpu().numpy()) * 100 for i in range(len(conditions))}
106
 
 
107
  most_likely_condition = max(results, key=results.get)
108
+ confidence = results[most_likely_condition]
109
+
110
+ if most_likely_condition not in condition_details:
111
  most_likely_condition = "Other"
112
  confidence = 0.0
 
 
113
 
 
114
  summary = f"**Summary**: Based on the X-ray analysis, the most likely diagnosis is: <b>{most_likely_condition}</b> with a confidence of <b>{confidence:.2f}%</b>."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  detailed_results = "<ul class='result-list'>"
116
  for condition, prob in results.items():
117
  detailed_results += f"<li><b>{condition}:</b> {prob:.2f}%</li>"
118
  detailed_results += "</ul>"
 
119
  additional_feedback = f"<div class='feedback-box'><b>Description:</b> {condition_details[most_likely_condition]['description']}<br><b>Recommendation:</b> {condition_details[most_likely_condition]['recommendation']}</div>"
120
 
121
  logger.info(f"Prediction: {most_likely_condition} with confidence {confidence:.2f}%")
 
125
  logger.error(f"Error in predict_xray: {str(e)}")
126
  return f"Error: {str(e)}", "", ""
127
 
128
+ # Enhanced function to analyze patient reports (PDFs)
129
  def analyze_report(file):
130
+ if not file or not file.name.endswith(".pdf"):
131
+ return "Please upload a valid PDF file."
132
+
133
  text = ""
134
  patient_condition = "Unclear"
135
  disease = "Unknown"
136
  status = "Pending further tests"
137
 
138
+ try:
139
+ pdf_reader = fitz.open(file.name)
140
+ for page in pdf_reader:
141
+ text += page.get_text("text")
142
+ pdf_reader.close()
143
+
144
+ if "stroke" in text.lower():
145
+ patient_condition = "Stroke"
146
+ disease = "Brain Disorder"
147
+ status = "Urgent Care Needed"
148
+ elif "cancer" in text.lower():
149
+ patient_condition = "Cancer"
150
+ disease = "Malignant Growth"
151
+ status = "Consult Oncologist"
152
+ elif "fracture" in text.lower():
153
+ patient_condition = "Fracture"
154
+ disease = "Bone Injury"
155
+ status = "Orthopedic Attention Required"
156
+
157
+ report_summary = f"Patient's Condition: {patient_condition}\nDisease: {disease}\nCondition Status: {status}\n\nReport Preview: {text[:300]}..." if text else "No readable text found in the PDF."
158
+ return report_summary
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
+ except Exception as e:
161
+ logger.error(f"Error reading PDF: {str(e)}")
162
+ return f"Error processing PDF: {str(e)}"
163
 
164
+ # Gradio Interface with Tabs
165
+ def create_interface():
166
+ logger.debug("Initializing Gradio interface")
167
+ # Minimal CSS to avoid styling issues
168
+ custom_css = """
169
+ .title { font-size: 30px; text-align: center; color: #4C6A92; }
170
+ .subtitle { text-align: center; color: #666; }
171
+ """
172
+ with gr.Blocks(css=custom_css) as demo:
173
+ gr.Markdown("<h1 class='title'>RadiologyScan AI</h1>")
174
+ gr.Markdown("<p class='subtitle'>AI-powered analysis for X-rays and patient reports</p>")
175
+
176
+ with gr.Tabs():
177
+ with gr.TabItem("X-ray Analysis"):
178
+ image_input = gr.Image(label="Upload X-ray", type="pil")
179
+ output_summary = gr.HTML(label="Summary")
180
+ output_details = gr.HTML(label="Detailed Results")
181
+ output_feedback = gr.HTML(label="Additional Feedback")
182
+ gr.Button("Analyze X-ray").click(
183
+ fn=predict_xray,
184
+ inputs=image_input,
185
+ outputs=[output_summary, output_details, output_feedback]
186
+ )
187
+
188
+ with gr.TabItem("Report Analysis"):
189
+ file_input = gr.File(label="Upload Patient Report (PDF)", file_count="single")
190
+ output_report = gr.Textbox(label="Report Summary", interactive=False)
191
+ gr.Button("Analyze Report").click(
192
+ fn=analyze_report,
193
+ inputs=file_input,
194
+ outputs=output_report
195
+ )
196
+ logger.debug("Gradio interface initialized")
197
+ return demo
198
+
199
+ if __name__ == "__main__":
200
+ logger.debug("Starting Gradio application")
201
+ try:
202
+ demo = create_interface()
203
+ demo.launch(server_port=7860, ssr_mode=False) # Explicit port, disable SSR
204
+ logger.debug("Gradio application launched")
205
+ except Exception as e:
206
+ logger.error(f"Failed to launch Gradio application: {str(e)}")