VaneshDev commited on
Commit
7303d3d
·
verified ·
1 Parent(s): ef9c28d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -130
app.py CHANGED
@@ -7,10 +7,10 @@ import logging
7
  import os
8
 
9
  # Set up logging
10
- logging.basicConfig(level=logging.DEBUG) # Set to DEBUG for detailed output
11
  logger = logging.getLogger(__name__)
12
 
13
- # Define conditions
14
  conditions = [
15
  "Normal", "Pneumonia", "Cancer", "TB", "Other",
16
  "Coronary Artery Disease", "Aortic Aneurysm", "Stroke", "Peripheral Artery Disease",
@@ -20,71 +20,28 @@ conditions = [
20
  "Appendicitis", "Gallstones", "Kidney Stones", "Infections", "Abdominal Aortic Aneurysm", "Diverticulitis"
21
  ]
22
 
23
- # Define condition details for all conditions
24
- condition_details = {
25
- "Normal": {"description": "No abnormal signs detected.", "recommendation": "Routine check-ups recommended."},
26
- "Pneumonia": {"description": "Lung inflammation detected, possibly infectious.", "recommendation": "Seek medical attention for treatment."},
27
- "Cancer": {"description": "Suspicious masses suggest cancer; further imaging needed.", "recommendation": "Consult an oncologist."},
28
- "TB": {"description": "Cavitary lesions indicate tuberculosis.", "recommendation": "Immediate medical evaluation required."},
29
- "Other": {"description": "Unclear abnormality; further investigation needed.", "recommendation": "Consult a radiologist."},
30
- "Coronary Artery Disease": {"description": "Narrowing of coronary arteries detected.", "recommendation": "Cardiology consultation required."},
31
- "Aortic Aneurysm": {"description": "Abnormal enlargement of the aorta.", "recommendation": "Vascular surgery evaluation."},
32
- "Stroke": {"description": "Signs of brain ischemia or hemorrhage.", "recommendation": "Urgent neurological evaluation."},
33
- "Peripheral Artery Disease": {"description": "Reduced blood flow in peripheral arteries.", "recommendation": "Vascular specialist consultation."},
34
- "Brain Tumor": {"description": "Abnormal mass in the brain detected.", "recommendation": "Consult a neurosurgeon."},
35
- "Alzheimer's Disease": {"description": "Signs of neurodegenerative changes.", "recommendation": "Neurology consultation."},
36
- "Multiple Sclerosis": {"description": "Demyelinating lesions in the CNS.", "recommendation": "Neurology consultation."},
37
- "Epilepsy": {"description": "Signs of seizure activity.", "recommendation": "Neurology consultation."},
38
- "COPD": {"description": "Lung damage from COPD observed.", "recommendation": "Pulmonary consultation."},
39
- "Lung Cancer": {"description": "Malignant lung masses detected.", "recommendation": "Oncology consultation."},
40
- "Pulmonary Embolism": {"description": "Blockage in pulmonary arteries.", "recommendation": "Urgent medical attention."},
41
- "Fractures": {"description": "Bone break detected.", "recommendation": "Orthopedic evaluation."},
42
- "Arthritis": {"description": "Joint inflammation detected.", "recommendation": "Rheumatology consultation."},
43
- "Osteoporosis": {"description": "Reduced bone density observed.", "recommendation": "Bone health specialist consultation."},
44
- "Appendicitis": {"description": "Inflammation of the appendix.", "recommendation": "Surgical evaluation."},
45
- "Gallstones": {"description": "Stones in the gallbladder detected.", "recommendation": "Gastroenterology consultation."},
46
- "Kidney Stones": {"description": "Stones in the kidneys detected.", "recommendation": "Urology consultation."},
47
- "Infections": {"description": "Signs of infection observed.", "recommendation": "Infectious disease consultation."},
48
- "Abdominal Aortic Aneurysm": {"description": "Enlargement of the abdominal aorta.", "recommendation": "Vascular surgery evaluation."},
49
- "Diverticulitis": {"description": "Inflammation of diverticula in the colon.", "recommendation": "Gastroenterology consultation."}
50
- }
51
-
52
  # Load and configure the model
53
- try:
54
- model = models.densenet121(weights="IMAGENET1K_V1")
55
- num_features = model.classifier.in_features
56
- model.classifier = torch.nn.Linear(num_features, len(conditions))
57
- model.eval()
58
- except AttributeError:
59
- model = models.densenet121(pretrained=True)
60
- num_features = model.classifier.in_features
61
- model.classifier = torch.nn.Linear(num_features, len(conditions))
62
- model.eval()
63
 
64
  # Define device
65
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
66
  model = model.to(device)
67
- logger.info(f"Using device: {device}")
68
 
69
- # Load model state if available
70
- model_path = os.getenv("MODEL_PATH", "xray_model.pth")
71
  if os.path.exists(model_path):
72
- try:
73
- model.load_state_dict(torch.load(model_path, map_location=device))
74
- logger.info(f"Loaded model from {model_path}")
75
- except Exception as e:
76
- logger.warning(f"Failed to load model from {model_path}: {str(e)}. Using random weights.")
77
  else:
78
  logger.info("No pre-trained model found. Initializing with random weights. Training required.")
79
 
80
  # Define image preprocessing function
81
  def preprocess_image(image):
82
- if not isinstance(image, Image.Image):
83
- logger.error("Invalid image format. Expected PIL Image.")
84
- raise ValueError("Uploaded file is not a valid image.")
85
-
86
  transform = transforms.Compose([
87
- transforms.Resize((224, 224)),
88
  transforms.ToTensor(),
89
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
90
  ])
@@ -92,7 +49,7 @@ def preprocess_image(image):
92
  logger.debug(f"Preprocessed image tensor shape: {image_tensor.shape}")
93
  return image_tensor
94
 
95
- # Define prediction function
96
  def predict_xray(image):
97
  try:
98
  if image is None:
@@ -101,21 +58,40 @@ def predict_xray(image):
101
  image_tensor = preprocess_image(image)
102
  with torch.no_grad():
103
  outputs = model(image_tensor)
104
- probs = torch.nn.functional.softmax(outputs, dim=1)[0]
105
  results = {conditions[i]: float(probs[i].cpu().numpy()) * 100 for i in range(len(conditions))}
106
 
 
107
  most_likely_condition = max(results, key=results.get)
108
- confidence = results[most_likely_condition]
109
-
110
- if most_likely_condition not in condition_details:
111
  most_likely_condition = "Other"
112
  confidence = 0.0
 
 
113
 
 
114
  summary = f"**Summary**: Based on the X-ray analysis, the most likely diagnosis is: <b>{most_likely_condition}</b> with a confidence of <b>{confidence:.2f}%</b>."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  detailed_results = "<ul class='result-list'>"
116
  for condition, prob in results.items():
117
  detailed_results += f"<li><b>{condition}:</b> {prob:.2f}%</li>"
118
  detailed_results += "</ul>"
 
119
  additional_feedback = f"<div class='feedback-box'><b>Description:</b> {condition_details[most_likely_condition]['description']}<br><b>Recommendation:</b> {condition_details[most_likely_condition]['recommendation']}</div>"
120
 
121
  logger.info(f"Prediction: {most_likely_condition} with confidence {confidence:.2f}%")
@@ -125,82 +101,66 @@ def predict_xray(image):
125
  logger.error(f"Error in predict_xray: {str(e)}")
126
  return f"Error: {str(e)}", "", ""
127
 
128
- # Enhanced function to analyze patient reports (PDFs)
129
  def analyze_report(file):
130
- if not file or not file.name.endswith(".pdf"):
131
- return "Please upload a valid PDF file."
132
-
133
  text = ""
134
  patient_condition = "Unclear"
135
  disease = "Unknown"
136
  status = "Pending further tests"
137
 
138
- try:
139
- pdf_reader = fitz.open(file.name)
140
- for page in pdf_reader:
141
- text += page.get_text("text")
142
- pdf_reader.close()
143
-
144
- if "stroke" in text.lower():
145
- patient_condition = "Stroke"
146
- disease = "Brain Disorder"
147
- status = "Urgent Care Needed"
148
- elif "cancer" in text.lower():
149
- patient_condition = "Cancer"
150
- disease = "Malignant Growth"
151
- status = "Consult Oncologist"
152
- elif "fracture" in text.lower():
153
- patient_condition = "Fracture"
154
- disease = "Bone Injury"
155
- status = "Orthopedic Attention Required"
156
-
157
- report_summary = f"Patient's Condition: {patient_condition}\nDisease: {disease}\nCondition Status: {status}\n\nReport Preview: {text[:300]}..." if text else "No readable text found in the PDF."
158
- return report_summary
159
-
160
- except Exception as e:
161
- logger.error(f"Error reading PDF: {str(e)}")
162
- return f"Error processing PDF: {str(e)}"
163
-
164
- # Gradio Interface with Tabs
 
 
 
 
 
 
165
  def create_interface():
166
- logger.debug("Initializing Gradio interface")
167
- # Minimal CSS to avoid styling issues
168
- custom_css = """
169
- .title { font-size: 30px; text-align: center; color: #4C6A92; }
170
- .subtitle { text-align: center; color: #666; }
171
- """
172
- with gr.Blocks(css=custom_css) as demo:
 
 
 
 
 
173
  gr.Markdown("<h1 class='title'>RadiologyScan AI</h1>")
174
- gr.Markdown("<p class='subtitle'>AI-powered analysis for X-rays and patient reports</p>")
175
-
176
- with gr.Tabs():
177
- with gr.TabItem("X-ray Analysis"):
178
- image_input = gr.Image(label="Upload X-ray", type="pil")
179
- output_summary = gr.HTML(label="Summary")
180
- output_details = gr.HTML(label="Detailed Results")
181
- output_feedback = gr.HTML(label="Additional Feedback")
182
- gr.Button("Analyze X-ray").click(
183
- fn=predict_xray,
184
- inputs=image_input,
185
- outputs=[output_summary, output_details, output_feedback]
186
- )
187
-
188
- with gr.TabItem("Report Analysis"):
189
- file_input = gr.File(label="Upload Patient Report (PDF)", file_count="single")
190
- output_report = gr.Textbox(label="Report Summary", interactive=False)
191
- gr.Button("Analyze Report").click(
192
- fn=analyze_report,
193
- inputs=file_input,
194
- outputs=output_report
195
- )
196
- logger.debug("Gradio interface initialized")
197
- return demo
198
-
199
- if __name__ == "__main__":
200
- logger.debug("Starting Gradio application")
201
- try:
202
- demo = create_interface()
203
- demo.launch(server_port=7860, ssr_mode=False) # Explicit port, disable SSR
204
- logger.debug("Gradio application launched")
205
- except Exception as e:
206
- logger.error(f"Failed to launch Gradio application: {str(e)}")
 
7
  import os
8
 
9
  # Set up logging
10
+ logging.basicConfig(level=logging.DEBUG)
11
  logger = logging.getLogger(__name__)
12
 
13
+ # Define conditions based on provided radiology information
14
  conditions = [
15
  "Normal", "Pneumonia", "Cancer", "TB", "Other",
16
  "Coronary Artery Disease", "Aortic Aneurysm", "Stroke", "Peripheral Artery Disease",
 
20
  "Appendicitis", "Gallstones", "Kidney Stones", "Infections", "Abdominal Aortic Aneurysm", "Diverticulitis"
21
  ]
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  # Load and configure the model
24
+ model = models.densenet121(weights="IMAGENET1K_V1") # DenseNet pre-trained on ImageNet
25
+ num_features = model.classifier.in_features
26
+ model.classifier = torch.nn.Linear(num_features, len(conditions)) # Output for all 24 conditions
27
+ model.eval()
 
 
 
 
 
 
28
 
29
  # Define device
30
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
  model = model.to(device)
 
32
 
33
+ # Load model state if available, otherwise initialize
34
+ model_path = "xray_model.pth"
35
  if os.path.exists(model_path):
36
+ model.load_state_dict(torch.load(model_path))
37
+ logger.info(f"Loaded model from {model_path}")
 
 
 
38
  else:
39
  logger.info("No pre-trained model found. Initializing with random weights. Training required.")
40
 
41
  # Define image preprocessing function
42
  def preprocess_image(image):
 
 
 
 
43
  transform = transforms.Compose([
44
+ transforms.Resize((224, 224)), # Resize to fit the model input size
45
  transforms.ToTensor(),
46
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
47
  ])
 
49
  logger.debug(f"Preprocessed image tensor shape: {image_tensor.shape}")
50
  return image_tensor
51
 
52
+ # Define prediction function with detailed output and error handling
53
  def predict_xray(image):
54
  try:
55
  if image is None:
 
58
  image_tensor = preprocess_image(image)
59
  with torch.no_grad():
60
  outputs = model(image_tensor)
61
+ probs = torch.nn.functional.softmax(outputs, dim=1)[0] # Softmax over all conditions
62
  results = {conditions[i]: float(probs[i].cpu().numpy()) * 100 for i in range(len(conditions))}
63
 
64
+ # Handle case where predicted class might not be in our list of conditions
65
  most_likely_condition = max(results, key=results.get)
66
+
67
+ # Ensure the predicted condition is in the valid conditions list
68
+ if most_likely_condition not in conditions:
69
  most_likely_condition = "Other"
70
  confidence = 0.0
71
+ else:
72
+ confidence = results[most_likely_condition]
73
 
74
+ # Create a detailed summary of results
75
  summary = f"**Summary**: Based on the X-ray analysis, the most likely diagnosis is: <b>{most_likely_condition}</b> with a confidence of <b>{confidence:.2f}%</b>."
76
+
77
+ # Enhanced condition details for each disease/condition
78
+ condition_details = {
79
+ "Normal": {"description": "No abnormal signs detected.", "recommendation": "Routine check-ups recommended."},
80
+ "Pneumonia": {"description": "Lung inflammation detected, possibly infectious.", "recommendation": "Seek medical attention for treatment."},
81
+ "Cancer": {"description": "Suspicious masses suggest cancer; further imaging needed.", "recommendation": "Consult an oncologist."},
82
+ "TB": {"description": "Cavitary lesions indicate tuberculosis.", "recommendation": "Immediate medical evaluation required."},
83
+ "Other": {"description": "Unclear abnormality; further investigation needed.", "recommendation": "Consult a radiologist."},
84
+ "Fractures": {"description": "Bone break detected.", "recommendation": "Orthopedic evaluation."},
85
+ "COPD": {"description": "Lung damage from COPD observed.", "recommendation": "Pulmonary consultation."},
86
+ # Add the rest of the conditions as needed...
87
+ }
88
+
89
+ # Display results in a clear format
90
  detailed_results = "<ul class='result-list'>"
91
  for condition, prob in results.items():
92
  detailed_results += f"<li><b>{condition}:</b> {prob:.2f}%</li>"
93
  detailed_results += "</ul>"
94
+
95
  additional_feedback = f"<div class='feedback-box'><b>Description:</b> {condition_details[most_likely_condition]['description']}<br><b>Recommendation:</b> {condition_details[most_likely_condition]['recommendation']}</div>"
96
 
97
  logger.info(f"Prediction: {most_likely_condition} with confidence {confidence:.2f}%")
 
101
  logger.error(f"Error in predict_xray: {str(e)}")
102
  return f"Error: {str(e)}", "", ""
103
 
104
+ # Enhanced function to analyze patient reports (PDFs) using PyMuPDF (fitz)
105
  def analyze_report(file):
 
 
 
106
  text = ""
107
  patient_condition = "Unclear"
108
  disease = "Unknown"
109
  status = "Pending further tests"
110
 
111
+ if file and file.name.endswith(".pdf"):
112
+ try:
113
+ # Open the PDF using PyMuPDF
114
+ pdf_reader = fitz.open(file.name)
115
+ for page in pdf_reader:
116
+ text += page.get_text("text") # Extract text from each page
117
+
118
+ # Example: Let's search for conditions in the text
119
+ if "stroke" in text.lower():
120
+ patient_condition = "Stroke"
121
+ disease = "Brain Disorder"
122
+ status = "Urgent Care Needed"
123
+ elif "cancer" in text.lower():
124
+ patient_condition = "Cancer"
125
+ disease = "Malignant Growth"
126
+ status = "Consult Oncologist"
127
+ elif "fracture" in text.lower():
128
+ patient_condition = "Fracture"
129
+ disease = "Bone Injury"
130
+ status = "Orthopedic Attention Required"
131
+ # You can add more conditions here based on keyword matching
132
+
133
+ report_summary = f"Patient's Condition: {patient_condition}\nDisease: {disease}\nCondition Status: {status}\n\nReport Preview: {text[:300]}..." if text else "No readable text found in the PDF."
134
+
135
+ except Exception as e:
136
+ logger.error(f"Error reading PDF: {str(e)}")
137
+ report_summary = f"Error processing PDF: {str(e)}"
138
+ else:
139
+ report_summary = "Please upload a valid PDF file."
140
+
141
+ return report_summary
142
+
143
+ # Gradio Interface with enhanced UI using Tabs
144
  def create_interface():
145
+ with gr.Blocks() as demo:
146
+ custom_css = """
147
+ .gradio-container { background-color: #f4f6f9; border-radius: 15px; box-shadow: 0 4px 15px rgba(0,0,0,0.1); padding: 30px; font-family: 'Segoe UI', sans-serif; }
148
+ .title { font-size: 30px; text-align: center; color: #4C6A92; margin-bottom: 20px; }
149
+ .gradio-button { background-color: #3B82F6; color: white; border-radius: 10px; padding: 15px 30px; font-size: 16px; transition: background-color 0.3s; }
150
+ .gradio-button:hover { background-color: #2563EB; }
151
+ .result-box { background-color: #ffffff; border-radius: 10px; padding: 20px; box-shadow: 0 4px 8px rgba(0,0,0,0.1); margin-top: 20px; max-width: 100%; }
152
+ .result-list { padding-left: 20px; margin: 10px 0; }
153
+ .result-summary { font-size: 18px; color: #2F4F4F; font-weight: 500; }
154
+ .feedback-box { background-color: #F0FFF4; padding: 10px; border-left: 4px solid #38A169; border-radius: 5px; margin-top: 10px; }
155
+ """
156
+
157
  gr.Markdown("<h1 class='title'>RadiologyScan AI</h1>")
158
+ gr.Markdown("<p style='text-align: center; color: #666;'>AI-powered analysis for X-rays and patient reports</p>")
159
+
160
+ # Correctly provide interface_list to TabbedInterface
161
+ xray_tab = gr.Interface(fn=predict_xray, inputs=gr.Image(label="Upload X-ray", type="pil"), outputs=[gr.HTML(), gr.HTML(), gr.HTML()])
162
+ report_tab = gr.Interface(fn=analyze_report, inputs=gr.File(label="Upload Patient Report (PDF)", file_count="single"), outputs=gr.Textbox(label="Report Summary", interactive=False))
163
+
164
+ gr.TabbedInterface([xray_tab, report_tab], tab_names=["X-ray Analysis", "Report Analysis"]).launch(share=True)
165
+
166
+ demo = create_interface()