VaneshDev commited on
Commit
29b04e5
·
verified ·
1 Parent(s): 897c5ee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -111
app.py CHANGED
@@ -2,15 +2,15 @@ import gradio as gr
2
  from PIL import Image
3
  import torch
4
  from torchvision import models, transforms
5
- import fitz # PyMuPDF for better PDF parsing
6
  import logging
7
  import os
8
 
9
- # Set up logging
10
- logging.basicConfig(level=logging.DEBUG) # Set to DEBUG for detailed output
11
  logger = logging.getLogger(__name__)
12
 
13
- # Define conditions
14
  conditions = [
15
  "Normal", "Pneumonia", "Cancer", "TB", "Other",
16
  "Coronary Artery Disease", "Aortic Aneurysm", "Stroke", "Peripheral Artery Disease",
@@ -20,7 +20,7 @@ conditions = [
20
  "Appendicitis", "Gallstones", "Kidney Stones", "Infections", "Abdominal Aortic Aneurysm", "Diverticulitis"
21
  ]
22
 
23
- # Define condition details for all conditions
24
  condition_details = {
25
  "Normal": {"description": "No abnormal signs detected.", "recommendation": "Routine check-ups recommended."},
26
  "Pneumonia": {"description": "Lung inflammation detected, possibly infectious.", "recommendation": "Seek medical attention for treatment."},
@@ -52,155 +52,105 @@ condition_details = {
52
  # Load and configure the model
53
  try:
54
  model = models.densenet121(weights="IMAGENET1K_V1")
55
- num_features = model.classifier.in_features
56
- model.classifier = torch.nn.Linear(num_features, len(conditions))
57
- model.eval()
58
  except AttributeError:
59
  model = models.densenet121(pretrained=True)
60
- num_features = model.classifier.in_features
61
- model.classifier = torch.nn.Linear(num_features, len(conditions))
62
- model.eval()
63
 
64
- # Define device
 
 
65
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
66
- model = model.to(device)
67
- logger.info(f"Using device: {device}")
68
 
69
- # Load model state if available
70
  model_path = os.getenv("MODEL_PATH", "xray_model.pth")
71
  if os.path.exists(model_path):
72
  try:
73
  model.load_state_dict(torch.load(model_path, map_location=device))
74
- logger.info(f"Loaded model from {model_path}")
75
  except Exception as e:
76
- logger.warning(f"Failed to load model from {model_path}: {str(e)}. Using random weights.")
77
  else:
78
- logger.info("No pre-trained model found. Initializing with random weights. Training required.")
79
 
80
- # Define image preprocessing function
81
  def preprocess_image(image):
82
- if not isinstance(image, Image.Image):
83
- logger.error("Invalid image format. Expected PIL Image.")
84
- raise ValueError("Uploaded file is not a valid image.")
85
-
86
  transform = transforms.Compose([
87
  transforms.Resize((224, 224)),
88
  transforms.ToTensor(),
89
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
90
  ])
91
- image_tensor = transform(image).unsqueeze(0).to(device)
92
- logger.debug(f"Preprocessed image tensor shape: {image_tensor.shape}")
93
- return image_tensor
94
 
95
- # Define prediction function
96
  def predict_xray(image):
97
  try:
98
  if image is None:
99
- return "Error: No image uploaded.", "", ""
100
 
101
- image_tensor = preprocess_image(image)
102
  with torch.no_grad():
103
- outputs = model(image_tensor)
104
- probs = torch.nn.functional.softmax(outputs, dim=1)[0]
105
- results = {conditions[i]: float(probs[i].cpu().numpy()) * 100 for i in range(len(conditions))}
106
-
107
- most_likely_condition = max(results, key=results.get)
108
- confidence = results[most_likely_condition]
109
-
110
- if most_likely_condition not in condition_details:
111
- most_likely_condition = "Other"
112
- confidence = 0.0
113
-
114
- summary = f"**Summary**: Based on the X-ray analysis, the most likely diagnosis is: <b>{most_likely_condition}</b> with a confidence of <b>{confidence:.2f}%</b>."
115
- detailed_results = "<ul class='result-list'>"
116
- for condition, prob in results.items():
117
- detailed_results += f"<li><b>{condition}:</b> {prob:.2f}%</li>"
118
- detailed_results += "</ul>"
119
- additional_feedback = f"<div class='feedback-box'><b>Description:</b> {condition_details[most_likely_condition]['description']}<br><b>Recommendation:</b> {condition_details[most_likely_condition]['recommendation']}</div>"
120
-
121
- logger.info(f"Prediction: {most_likely_condition} with confidence {confidence:.2f}%")
122
- return summary, detailed_results, additional_feedback
123
-
124
  except Exception as e:
125
- logger.error(f"Error in predict_xray: {str(e)}")
126
- return f"Error: {str(e)}", "", ""
127
 
128
- # Enhanced function to analyze patient reports (PDFs)
129
  def analyze_report(file):
130
  if not file or not file.name.endswith(".pdf"):
131
- return "Please upload a valid PDF file."
132
-
133
- text = ""
134
- patient_condition = "Unclear"
135
- disease = "Unknown"
136
- status = "Pending further tests"
137
 
138
  try:
139
- pdf_reader = fitz.open(file.name)
140
- for page in pdf_reader:
141
- text += page.get_text("text")
142
- pdf_reader.close()
 
143
 
144
  if "stroke" in text.lower():
145
- patient_condition = "Stroke"
146
- disease = "Brain Disorder"
147
- status = "Urgent Care Needed"
148
  elif "cancer" in text.lower():
149
- patient_condition = "Cancer"
150
- disease = "Malignant Growth"
151
- status = "Consult Oncologist"
152
  elif "fracture" in text.lower():
153
- patient_condition = "Fracture"
154
- disease = "Bone Injury"
155
- status = "Orthopedic Attention Required"
156
 
157
- report_summary = f"Patient's Condition: {patient_condition}\nDisease: {disease}\nCondition Status: {status}\n\nReport Preview: {text[:300]}..." if text else "No readable text found in the PDF."
158
- return report_summary
159
 
160
  except Exception as e:
161
- logger.error(f"Error reading PDF: {str(e)}")
162
- return f"Error processing PDF: {str(e)}"
163
 
164
- # Gradio Interface with Tabs
165
  def create_interface():
166
- logger.debug("Initializing Gradio interface")
167
- # Minimal CSS to avoid styling issues
168
- custom_css = """
169
- .title { font-size: 30px; text-align: center; color: #4C6A92; }
170
- .subtitle { text-align: center; color: #666; }
171
- """
172
- with gr.Blocks(css=custom_css) as demo:
173
- gr.Markdown("<h1 class='title'>RadiologyScan AI</h1>")
174
- gr.Markdown("<p class='subtitle'>AI-powered analysis for X-rays and patient reports</p>")
175
 
176
  with gr.Tabs():
177
  with gr.TabItem("X-ray Analysis"):
178
- image_input = gr.Image(label="Upload X-ray", type="pil")
179
- output_summary = gr.HTML(label="Summary")
180
- output_details = gr.HTML(label="Detailed Results")
181
- output_feedback = gr.HTML(label="Additional Feedback")
182
- gr.Button("Analyze X-ray").click(
183
- fn=predict_xray,
184
- inputs=image_input,
185
- outputs=[output_summary, output_details, output_feedback]
186
- )
187
 
188
  with gr.TabItem("Report Analysis"):
189
- file_input = gr.File(label="Upload Patient Report (PDF)", file_count="single")
190
- output_report = gr.Textbox(label="Report Summary", interactive=False)
191
- gr.Button("Analyze Report").click(
192
- fn=analyze_report,
193
- inputs=file_input,
194
- outputs=output_report
195
- )
196
- logger.debug("Gradio interface initialized")
197
  return demo
198
 
 
199
  if __name__ == "__main__":
200
- logger.debug("Starting Gradio application")
201
- try:
202
- demo = create_interface()
203
- demo.launch(server_port=7860, ssr_mode=False) # Explicit port, disable SSR
204
- logger.debug("Gradio application launched")
205
- except Exception as e:
206
- logger.error(f"Failed to launch Gradio application: {str(e)}")
 
2
  from PIL import Image
3
  import torch
4
  from torchvision import models, transforms
5
+ import fitz # PyMuPDF
6
  import logging
7
  import os
8
 
9
+ # Logging setup
10
+ logging.basicConfig(level=logging.INFO)
11
  logger = logging.getLogger(__name__)
12
 
13
+ # List of possible conditions
14
  conditions = [
15
  "Normal", "Pneumonia", "Cancer", "TB", "Other",
16
  "Coronary Artery Disease", "Aortic Aneurysm", "Stroke", "Peripheral Artery Disease",
 
20
  "Appendicitis", "Gallstones", "Kidney Stones", "Infections", "Abdominal Aortic Aneurysm", "Diverticulitis"
21
  ]
22
 
23
+ # Details for each condition
24
  condition_details = {
25
  "Normal": {"description": "No abnormal signs detected.", "recommendation": "Routine check-ups recommended."},
26
  "Pneumonia": {"description": "Lung inflammation detected, possibly infectious.", "recommendation": "Seek medical attention for treatment."},
 
52
  # Load and configure the model
53
  try:
54
  model = models.densenet121(weights="IMAGENET1K_V1")
 
 
 
55
  except AttributeError:
56
  model = models.densenet121(pretrained=True)
 
 
 
57
 
58
+ model.classifier = torch.nn.Linear(model.classifier.in_features, len(conditions))
59
+ model.eval()
60
+
61
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
62
+ model.to(device)
 
63
 
64
+ # Load trained weights if available
65
  model_path = os.getenv("MODEL_PATH", "xray_model.pth")
66
  if os.path.exists(model_path):
67
  try:
68
  model.load_state_dict(torch.load(model_path, map_location=device))
69
+ logger.info("Loaded custom model weights.")
70
  except Exception as e:
71
+ logger.warning(f"Failed to load model weights: {e}")
72
  else:
73
+ logger.info("No model weights found. Using random weights.")
74
 
75
+ # Preprocess uploaded image
76
  def preprocess_image(image):
 
 
 
 
77
  transform = transforms.Compose([
78
  transforms.Resize((224, 224)),
79
  transforms.ToTensor(),
80
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
81
  ])
82
+ return transform(image).unsqueeze(0).to(device)
 
 
83
 
84
+ # Predict X-ray condition
85
  def predict_xray(image):
86
  try:
87
  if image is None:
88
+ return "Please upload an image."
89
 
90
+ img_tensor = preprocess_image(image)
91
  with torch.no_grad():
92
+ output = model(img_tensor)
93
+ probs = torch.nn.functional.softmax(output, dim=1)[0]
94
+ results = {conditions[i]: probs[i].item() * 100 for i in range(len(conditions))}
95
+
96
+ top_condition = max(results, key=results.get)
97
+ confidence = results[top_condition]
98
+
99
+ info = condition_details.get(top_condition, condition_details["Other"])
100
+ return f"""
101
+ <div style="font-family:Arial">
102
+ <h3>Prediction: <span style="color:#2A9D8F;">{top_condition}</span></h3>
103
+ <p><b>Confidence:</b> {confidence:.2f}%</p>
104
+ <p><b>Description:</b> {info['description']}</p>
105
+ <p><b>Recommendation:</b> {info['recommendation']}</p>
106
+ </div>
107
+ """
 
 
 
 
 
108
  except Exception as e:
109
+ return f"Prediction failed: {str(e)}"
 
110
 
111
+ # Analyze PDF medical report
112
  def analyze_report(file):
113
  if not file or not file.name.endswith(".pdf"):
114
+ return "Please upload a valid PDF report."
 
 
 
 
 
115
 
116
  try:
117
+ doc = fitz.open(file.name)
118
+ text = "".join(page.get_text() for page in doc)
119
+ doc.close()
120
+
121
+ condition, disease, status = "Unclear", "Unknown", "Pending evaluation"
122
 
123
  if "stroke" in text.lower():
124
+ condition, disease, status = "Stroke", "Brain Disorder", "Urgent Care Needed"
 
 
125
  elif "cancer" in text.lower():
126
+ condition, disease, status = "Cancer", "Malignant Growth", "Consult Oncologist"
 
 
127
  elif "fracture" in text.lower():
128
+ condition, disease, status = "Fracture", "Bone Injury", "Orthopedic Attention Required"
 
 
129
 
130
+ return f"Condition: {condition}\nDisease: {disease}\nStatus: {status}\n\nPreview:\n{text[:300]}..."
 
131
 
132
  except Exception as e:
133
+ return f"Failed to analyze PDF: {str(e)}"
 
134
 
135
+ # Gradio interface
136
  def create_interface():
137
+ with gr.Blocks() as demo:
138
+ gr.Markdown("<h1 style='text-align:center;'>🩻 RadiologyScan AI</h1><p style='text-align:center;'>AI-powered X-ray and PDF report analysis</p>")
 
 
 
 
 
 
 
139
 
140
  with gr.Tabs():
141
  with gr.TabItem("X-ray Analysis"):
142
+ xray_input = gr.Image(label="Upload Chest X-ray", type="pil")
143
+ xray_output = gr.HTML()
144
+ gr.Button("Analyze X-ray").click(predict_xray, inputs=xray_input, outputs=xray_output)
 
 
 
 
 
 
145
 
146
  with gr.TabItem("Report Analysis"):
147
+ pdf_input = gr.File(label="Upload Medical Report (PDF)", file_types=[".pdf"])
148
+ pdf_output = gr.Textbox(label="Report Summary", lines=10)
149
+ gr.Button("Analyze Report").click(analyze_report, inputs=pdf_input, outputs=pdf_output)
150
+
 
 
 
 
151
  return demo
152
 
153
+ # Launch app
154
  if __name__ == "__main__":
155
+ demo = create_interface()
156
+ demo.launch(server_port=7860, ssr_mode=False)