VaneshDev commited on
Commit
0227e86
·
verified ·
1 Parent(s): 4fc8365

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +121 -89
app.py CHANGED
@@ -7,10 +7,10 @@ import logging
7
  import os
8
 
9
  # Set up logging
10
- logging.basicConfig(level=logging.DEBUG)
11
  logger = logging.getLogger(__name__)
12
 
13
- # Define conditions based on provided radiology information
14
  conditions = [
15
  "Normal", "Pneumonia", "Cancer", "TB", "Other",
16
  "Coronary Artery Disease", "Aortic Aneurysm", "Stroke", "Peripheral Artery Disease",
@@ -20,28 +20,71 @@ conditions = [
20
  "Appendicitis", "Gallstones", "Kidney Stones", "Infections", "Abdominal Aortic Aneurysm", "Diverticulitis"
21
  ]
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  # Load and configure the model
24
- model = models.densenet121(weights="IMAGENET1K_V1") # DenseNet pre-trained on ImageNet
25
- num_features = model.classifier.in_features
26
- model.classifier = torch.nn.Linear(num_features, len(conditions)) # Output for all 24 conditions
27
- model.eval()
 
 
 
 
 
 
28
 
29
  # Define device
30
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
  model = model.to(device)
 
32
 
33
- # Load model state if available, otherwise initialize
34
- model_path = "xray_model.pth"
35
  if os.path.exists(model_path):
36
- model.load_state_dict(torch.load(model_path))
37
- logger.info(f"Loaded model from {model_path}")
 
 
 
38
  else:
39
  logger.info("No pre-trained model found. Initializing with random weights. Training required.")
40
 
41
  # Define image preprocessing function
42
  def preprocess_image(image):
 
 
 
 
43
  transform = transforms.Compose([
44
- transforms.Resize((224, 224)), # Resize to fit the model input size
45
  transforms.ToTensor(),
46
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
47
  ])
@@ -49,7 +92,7 @@ def preprocess_image(image):
49
  logger.debug(f"Preprocessed image tensor shape: {image_tensor.shape}")
50
  return image_tensor
51
 
52
- # Define prediction function with detailed output and error handling
53
  def predict_xray(image):
54
  try:
55
  if image is None:
@@ -58,40 +101,21 @@ def predict_xray(image):
58
  image_tensor = preprocess_image(image)
59
  with torch.no_grad():
60
  outputs = model(image_tensor)
61
- probs = torch.nn.functional.softmax(outputs, dim=1)[0] # Softmax over all conditions
62
  results = {conditions[i]: float(probs[i].cpu().numpy()) * 100 for i in range(len(conditions))}
63
 
64
- # Handle case where predicted class might not be in our list of conditions
65
  most_likely_condition = max(results, key=results.get)
66
-
67
- # Ensure the predicted condition is in the valid conditions list
68
- if most_likely_condition not in conditions:
69
  most_likely_condition = "Other"
70
  confidence = 0.0
71
- else:
72
- confidence = results[most_likely_condition]
73
 
74
- # Create a detailed summary of results
75
  summary = f"**Summary**: Based on the X-ray analysis, the most likely diagnosis is: <b>{most_likely_condition}</b> with a confidence of <b>{confidence:.2f}%</b>."
76
-
77
- # Enhanced condition details for each disease/condition
78
- condition_details = {
79
- "Normal": {"description": "No abnormal signs detected.", "recommendation": "Routine check-ups recommended."},
80
- "Pneumonia": {"description": "Lung inflammation detected, possibly infectious.", "recommendation": "Seek medical attention for treatment."},
81
- "Cancer": {"description": "Suspicious masses suggest cancer; further imaging needed.", "recommendation": "Consult an oncologist."},
82
- "TB": {"description": "Cavitary lesions indicate tuberculosis.", "recommendation": "Immediate medical evaluation required."},
83
- "Other": {"description": "Unclear abnormality; further investigation needed.", "recommendation": "Consult a radiologist."},
84
- "Fractures": {"description": "Bone break detected.", "recommendation": "Orthopedic evaluation."},
85
- "COPD": {"description": "Lung damage from COPD observed.", "recommendation": "Pulmonary consultation."},
86
- # Add the rest of the conditions as needed...
87
- }
88
-
89
- # Display results in a clear format
90
  detailed_results = "<ul class='result-list'>"
91
  for condition, prob in results.items():
92
  detailed_results += f"<li><b>{condition}:</b> {prob:.2f}%</li>"
93
  detailed_results += "</ul>"
94
-
95
  additional_feedback = f"<div class='feedback-box'><b>Description:</b> {condition_details[most_likely_condition]['description']}<br><b>Recommendation:</b> {condition_details[most_likely_condition]['recommendation']}</div>"
96
 
97
  logger.info(f"Prediction: {most_likely_condition} with confidence {confidence:.2f}%")
@@ -101,66 +125,74 @@ def predict_xray(image):
101
  logger.error(f"Error in predict_xray: {str(e)}")
102
  return f"Error: {str(e)}", "", ""
103
 
104
- # Enhanced function to analyze patient reports (PDFs) using PyMuPDF (fitz)
105
  def analyze_report(file):
 
 
 
106
  text = ""
107
  patient_condition = "Unclear"
108
  disease = "Unknown"
109
  status = "Pending further tests"
110
 
111
- if file and file.name.endswith(".pdf"):
112
- try:
113
- # Open the PDF using PyMuPDF
114
- pdf_reader = fitz.open(file.name)
115
- for page in pdf_reader:
116
- text += page.get_text("text") # Extract text from each page
117
-
118
- # Example: Let's search for conditions in the text
119
- if "stroke" in text.lower():
120
- patient_condition = "Stroke"
121
- disease = "Brain Disorder"
122
- status = "Urgent Care Needed"
123
- elif "cancer" in text.lower():
124
- patient_condition = "Cancer"
125
- disease = "Malignant Growth"
126
- status = "Consult Oncologist"
127
- elif "fracture" in text.lower():
128
- patient_condition = "Fracture"
129
- disease = "Bone Injury"
130
- status = "Orthopedic Attention Required"
131
- # You can add more conditions here based on keyword matching
132
-
133
- report_summary = f"Patient's Condition: {patient_condition}\nDisease: {disease}\nCondition Status: {status}\n\nReport Preview: {text[:300]}..." if text else "No readable text found in the PDF."
134
-
135
- except Exception as e:
136
- logger.error(f"Error reading PDF: {str(e)}")
137
- report_summary = f"Error processing PDF: {str(e)}"
138
- else:
139
- report_summary = "Please upload a valid PDF file."
140
-
141
- return report_summary
142
-
143
- # Gradio Interface with enhanced UI using Tabs
144
  def create_interface():
145
- with gr.Blocks() as demo:
146
- custom_css = """
147
- .gradio-container { background-color: #f4f6f9; border-radius: 15px; box-shadow: 0 4px 15px rgba(0,0,0,0.1); padding: 30px; font-family: 'Segoe UI', sans-serif; }
148
- .title { font-size: 30px; text-align: center; color: #4C6A92; margin-bottom: 20px; }
149
- .gradio-button { background-color: #3B82F6; color: white; border-radius: 10px; padding: 15px 30px; font-size: 16px; transition: background-color 0.3s; }
150
- .gradio-button:hover { background-color: #2563EB; }
151
- .result-box { background-color: #ffffff; border-radius: 10px; padding: 20px; box-shadow: 0 4px 8px rgba(0,0,0,0.1); margin-top: 20px; max-width: 100%; }
152
- .result-list { padding-left: 20px; margin: 10px 0; }
153
- .result-summary { font-size: 18px; color: #2F4F4F; font-weight: 500; }
154
- .feedback-box { background-color: #F0FFF4; padding: 10px; border-left: 4px solid #38A169; border-radius: 5px; margin-top: 10px; }
155
- """
156
-
157
- gr.Markdown("<h1 class='title'>RadiologyScan AI</h1>")
158
  gr.Markdown("<p style='text-align: center; color: #666;'>AI-powered analysis for X-rays and patient reports</p>")
159
-
160
- # Correctly provide interface_list to TabbedInterface
161
- xray_tab = gr.Interface(fn=predict_xray, inputs=gr.Image(label="Upload X-ray", type="pil"), outputs=[gr.HTML(), gr.HTML(), gr.HTML()])
162
- report_tab = gr.Interface(fn=analyze_report, inputs=gr.File(label="Upload Patient Report (PDF)", file_count="single"), outputs=gr.Textbox(label="Report Summary", interactive=False))
163
-
164
- gr.TabbedInterface([xray_tab, report_tab], tab_names=["X-ray Analysis", "Report Analysis"]).launch(share=True)
165
 
166
- demo = create_interface()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  import os
8
 
9
  # Set up logging
10
+ logging.basicConfig(level=logging.DEBUG) # Set to DEBUG for detailed output
11
  logger = logging.getLogger(__name__)
12
 
13
+ # Define conditions
14
  conditions = [
15
  "Normal", "Pneumonia", "Cancer", "TB", "Other",
16
  "Coronary Artery Disease", "Aortic Aneurysm", "Stroke", "Peripheral Artery Disease",
 
20
  "Appendicitis", "Gallstones", "Kidney Stones", "Infections", "Abdominal Aortic Aneurysm", "Diverticulitis"
21
  ]
22
 
23
+ # Define condition details for all conditions
24
+ condition_details = {
25
+ "Normal": {"description": "No abnormal signs detected.", "recommendation": "Routine check-ups recommended."},
26
+ "Pneumonia": {"description": "Lung inflammation detected, possibly infectious.", "recommendation": "Seek medical attention for treatment."},
27
+ "Cancer": {"description": "Suspicious masses suggest cancer; further imaging needed.", "recommendation": "Consult an oncologist."},
28
+ "TB": {"description": "Cavitary lesions indicate tuberculosis.", "recommendation": "Immediate medical evaluation required."},
29
+ "Other": {"description": "Unclear abnormality; further investigation needed.", "recommendation": "Consult a radiologist."},
30
+ "Coronary Artery Disease": {"description": "Narrowing of coronary arteries detected.", "recommendation": "Cardiology consultation required."},
31
+ "Aortic Aneurysm": {"description": "Abnormal enlargement of the aorta.", "recommendation": "Vascular surgery evaluation."},
32
+ "Stroke": {"description": "Signs of brain ischemia or hemorrhage.", "recommendation": "Urgent neurological evaluation."},
33
+ "Peripheral Artery Disease": {"description": "Reduced blood flow in peripheral arteries.", "recommendation": "Vascular specialist consultation."},
34
+ "Brain Tumor":>{"description": "Abnormal mass in the brain detected.", "recommendation": "Consult a neurosurgeon."},
35
+ "Alzheimer's Disease": {"description": "Signs of neurodegenerative changes.", "recommendation": "Neurology consultation."},
36
+ "Multiple Sclerosis": {"description": "Demyelinating lesions in the CNS.", "recommendation": "Neurology consultation."},
37
+ "Epilepsy": {"description": "Signs of seizure activity.", "recommendation": "Neurology consultation."},
38
+ "COPD": {"description": "Lung damage from COPD observed.", "recommendation": "Pulmonary consultation."},
39
+ "Lung Cancer": {"description": "Malignant lung masses detected.", "recommendation": "Oncology consultation."},
40
+ "Pulmonary Embolism": {"description": "Blockage in pulmonary arteries.", "recommendation": "Urgent medical attention."},
41
+ "Fractures": {"description": "Bone break detected.", "recommendation": "Orthopedic evaluation."},
42
+ "Arthritis": {"description": "Joint inflammation detected.", "recommendation": "Rheumatology consultation."},
43
+ "Osteoporosis": {"description": "Reduced bone density observed.", "recommendation": "Bone health specialist consultation."},
44
+ "Appendicitis": {"description": "Inflammation of the appendix.", "recommendation": "Surgical evaluation."},
45
+ "Gallstones": {"description": "Stones in the gallbladder detected.", "recommendation": "Gastroenterology consultation."},
46
+ "Kidney Stones": {"description": "Stones in the kidneys detected.", "recommendation": "Urology consultation."},
47
+ "Infections": {"description": "Signs of infection observed.", "recommendation": "Infectious disease consultation."},
48
+ "Abdominal Aortic Aneurysm": {"description": "Enlargement of the abdominal aorta.", "recommendation": "Vascular surgery evaluation."},
49
+ "Diverticulitis": {"description": "Inflammation of diverticula in the colon.", "recommendation": "Gastroenterology consultation."}
50
+ }
51
+
52
  # Load and configure the model
53
+ try:
54
+ model = models.densenet121(weights="IMAGENET1K_V1")
55
+ num_features = model.classifier.in_features
56
+ model.classifier = torch.nn.Linear(num_features, len(conditions))
57
+ model.eval()
58
+ except AttributeError:
59
+ model = models.densenet121(pretrained=True)
60
+ num_features = model.classifier.in_features
61
+ model.classifier = torch.nn.Linear(num_features, len(conditions))
62
+ model.eval()
63
 
64
  # Define device
65
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
66
  model = model.to(device)
67
+ logger.info(f"Using device: {device}")
68
 
69
+ # Load model state if available
70
+ model_path = os.getenv("MODEL_PATH", "xray_model.pth")
71
  if os.path.exists(model_path):
72
+ try:
73
+ model.load_state_dict(torch.load(model_path, map_location=device))
74
+ logger.info(f"Loaded model from {model_path}")
75
+ except Exception as e:
76
+ logger.warning(f"Failed to load model from {model_path}: {str(e)}. Using random weights.")
77
  else:
78
  logger.info("No pre-trained model found. Initializing with random weights. Training required.")
79
 
80
  # Define image preprocessing function
81
  def preprocess_image(image):
82
+ if not isinstance(image, Image.Image):
83
+ logger.error("Invalid image format. Expected PIL Image.")
84
+ raise ValueError("Uploaded file is not a valid image.")
85
+
86
  transform = transforms.Compose([
87
+ transforms.Resize((224, 224)),
88
  transforms.ToTensor(),
89
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
90
  ])
 
92
  logger.debug(f"Preprocessed image tensor shape: {image_tensor.shape}")
93
  return image_tensor
94
 
95
+ # Define prediction function
96
  def predict_xray(image):
97
  try:
98
  if image is None:
 
101
  image_tensor = preprocess_image(image)
102
  with torch.no_grad():
103
  outputs = model(image_tensor)
104
+ probs = torch.nn.functional.softmax(outputs, dim=1)[0]
105
  results = {conditions[i]: float(probs[i].cpu().numpy()) * 100 for i in range(len(conditions))}
106
 
 
107
  most_likely_condition = max(results, key=results.get)
108
+ confidence = results[most_likely_condition]
109
+
110
+ if most_likely_condition not in condition_details:
111
  most_likely_condition = "Other"
112
  confidence = 0.0
 
 
113
 
 
114
  summary = f"**Summary**: Based on the X-ray analysis, the most likely diagnosis is: <b>{most_likely_condition}</b> with a confidence of <b>{confidence:.2f}%</b>."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  detailed_results = "<ul class='result-list'>"
116
  for condition, prob in results.items():
117
  detailed_results += f"<li><b>{condition}:</b> {prob:.2f}%</li>"
118
  detailed_results += "</ul>"
 
119
  additional_feedback = f"<div class='feedback-box'><b>Description:</b> {condition_details[most_likely_condition]['description']}<br><b>Recommendation:</b> {condition_details[most_likely_condition]['recommendation']}</div>"
120
 
121
  logger.info(f"Prediction: {most_likely_condition} with confidence {confidence:.2f}%")
 
125
  logger.error(f"Error in predict_xray: {str(e)}")
126
  return f"Error: {str(e)}", "", ""
127
 
128
+ # Enhanced function to analyze patient reports (PDFs)
129
  def analyze_report(file):
130
+ if not file or not file.name.endswith(".pdf"):
131
+ return "Please upload a valid PDF file."
132
+
133
  text = ""
134
  patient_condition = "Unclear"
135
  disease = "Unknown"
136
  status = "Pending further tests"
137
 
138
+ try:
139
+ pdf_reader = fitz.open(file.name)
140
+ for page in pdf_reader:
141
+ text += page.get_text("text")
142
+ pdf_reader.close()
143
+
144
+ if "stroke" in text.lower():
145
+ patient_condition = "Stroke"
146
+ disease = "Brain Disorder"
147
+ status = "Urgent Care Needed"
148
+ elif "cancer" in text.lower():
149
+ patient_condition = "Cancer"
150
+ disease = "Malignant Growth"
151
+ status = "Consult Oncologist"
152
+ elif "fracture" in text.lower():
153
+ patient_condition = "Fracture"
154
+ disease = "Bone Injury"
155
+ status = "Orthopedic Attention Required"
156
+
157
+ report_summary = f"Patient's Condition: {patient_condition}\nDisease: {disease}\nCondition Status: {status}\n\nReport Preview: {text[:300]}..." if text else "No readable text found in the PDF."
158
+ return report_summary
159
+
160
+ except Exception as e:
161
+ logger.error(f"Error reading PDF: {str(e)}")
162
+ return f"Error processing PDF: {str(e)}"
163
+
164
+ # Gradio Interface with Tabs
 
 
 
 
 
 
165
  def create_interface():
166
+ logger.debug("Initializing Gradio interface")
167
+ with gr.Blocks() as demo: # Temporarily remove custom_css
168
+ gr.Markdown("<h1>RadiologyScan AI</h1>")
 
 
 
 
 
 
 
 
 
 
169
  gr.Markdown("<p style='text-align: center; color: #666;'>AI-powered analysis for X-rays and patient reports</p>")
 
 
 
 
 
 
170
 
171
+ with gr.Tabs():
172
+ with gr.TabItem("X-ray Analysis"):
173
+ image_input = gr.Image(label="Upload X-ray", type="pil")
174
+ output_summary = gr.HTML(label="Summary")
175
+ output_details = gr.HTML(label="Detailed Results")
176
+ output_feedback = gr.HTML(label="Additional Feedback")
177
+ gr.Button("Analyze X-ray").click(
178
+ fn=predict_xray,
179
+ inputs=image_input,
180
+ outputs=[output_summary, output_details, output_feedback]
181
+ )
182
+
183
+ with gr.TabItem("Report Analysis"):
184
+ file_input = gr.File(label="Upload Patient Report (PDF)", file_count="single")
185
+ output_report = gr.Textbox(label="Report Summary", interactive=False)
186
+ gr.Button("Analyze Report").click(
187
+ fn=analyze_report,
188
+ inputs=file_input,
189
+ outputs=output_report
190
+ )
191
+ logger.debug("Gradio interface initialized")
192
+ return demo
193
+
194
+ if __name__ == "__main__":
195
+ logger.debug("Starting Gradio application")
196
+ demo = create_interface()
197
+ demo.launch(server_port=7860, ssr_mode=False) # Explicit port, disable SSR
198
+ logger.debug("Gradio application launched")