VaneshDev commited on
Commit
286c23d
·
verified ·
1 Parent(s): 04f8061

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -98
app.py CHANGED
@@ -4,33 +4,52 @@ import torch
4
  from torchvision import models, transforms
5
  import PyPDF2
6
  import logging
 
7
 
8
  # Set up logging
9
  logging.basicConfig(level=logging.DEBUG)
10
  logger = logging.getLogger(__name__)
11
 
12
- # Load the pre-trained model (DenseNet121, suitable for medical imaging)
13
- model = models.densenet121(pretrained=True) # Using ImageNet weights as a starting point
 
 
 
 
 
 
 
 
 
 
14
  num_features = model.classifier.in_features
15
- model.classifier = torch.nn.Linear(num_features, 5) # Output layer for 5 conditions
16
  model.eval()
17
 
18
- # Define device (CPU or GPU)
19
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
  model = model.to(device)
21
 
22
- # Define image preprocessing function with normalization
 
 
 
 
 
 
 
 
23
  def preprocess_image(image):
24
  transform = transforms.Compose([
25
  transforms.Resize((224, 224)),
26
  transforms.ToTensor(),
27
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # ImageNet normalization
28
  ])
29
  image_tensor = transform(image).unsqueeze(0).to(device)
30
  logger.debug(f"Preprocessed image tensor shape: {image_tensor.shape}")
31
  return image_tensor
32
 
33
- # Define a prediction function for X-ray images with detailed output
34
  def predict_xray(image):
35
  try:
36
  if image is None:
@@ -39,51 +58,49 @@ def predict_xray(image):
39
  image_tensor = preprocess_image(image)
40
  with torch.no_grad():
41
  outputs = model(image_tensor)
42
- probs = torch.nn.functional.softmax(outputs, dim=1)[0] # Softmax over the 5 classes
 
43
 
44
- # Define the conditions
45
- conditions = ["Normal", "Pneumonia", "Cancer", "TB", "Other"]
46
- results = {conditions[i]: float(probs[i].cpu().numpy()) * 100 for i in range(5)}
47
-
48
- # Determine the most likely condition and confidence
49
  most_likely_condition = max(results, key=results.get)
50
  confidence = results[most_likely_condition]
51
 
52
- # Generate summary
53
  summary = f"**Summary**: Based on the X-ray analysis, the most likely diagnosis is: <b>{most_likely_condition}</b> with a confidence of <b>{confidence:.2f}%</b>."
54
 
55
- # Condition details with enhanced descriptions
56
  condition_details = {
57
- "Normal": {
58
- "description": "The X-ray shows no abnormal signs, indicating healthy lung tissue with clear structures.",
59
- "recommendation": "No immediate action required. Schedule routine check-ups to monitor lung health."
60
- },
61
- "Pneumonia": {
62
- "description": "Pneumonia is detected, showing lung inflammation, possibly due to bacterial or viral infection, with visible opacities.",
63
- "recommendation": "Seek medical attention promptly; treatment may include antibiotics or antiviral medication."
64
- },
65
- "Cancer": {
66
- "description": "Suspicious masses or nodules suggest lung cancer, requiring advanced imaging (e.g., CT) for confirmation.",
67
- "recommendation": "Urgently consult an oncologist for a biopsy and personalized treatment plan."
68
- },
69
- "TB": {
70
- "description": "Tuberculosis is indicated by cavitary lesions or consolidation, a contagious bacterial infection.",
71
- "recommendation": "Contact a healthcare provider immediately for a treatment regimen, likely involving multiple antibiotics."
72
- },
73
- "Other": {
74
- "description": "Unclear abnormalities detected; could indicate conditions like fibrosis or heart-related issues.",
75
- "recommendation": "Refer to a radiologist for specialized imaging and diagnosis."
76
- }
 
 
 
 
 
77
  }
78
 
79
- # Detailed results in a structured format
80
  detailed_results = "<ul class='result-list'>"
81
  for condition, prob in results.items():
82
  detailed_results += f"<li><b>{condition}:</b> {prob:.2f}%</li>"
83
  detailed_results += "</ul>"
84
 
85
- # Additional feedback based on the condition
86
- additional_feedback = condition_details.get(most_likely_condition, "Please consult a medical professional for a detailed evaluation.")
87
 
88
  logger.info(f"Prediction: {most_likely_condition} with confidence {confidence:.2f}%")
89
  return summary, detailed_results, additional_feedback
@@ -92,7 +109,7 @@ def predict_xray(image):
92
  logger.error(f"Error in predict_xray: {str(e)}")
93
  return f"Error: {str(e)}", "", ""
94
 
95
- # Define a function to read and analyze patient reports (PDFs)
96
  def analyze_report(file):
97
  text = ""
98
  if file and file.name.endswith(".pdf"):
@@ -111,82 +128,36 @@ def analyze_report(file):
111
  # Gradio Interface with enhanced UI
112
  def create_interface():
113
  with gr.Blocks() as demo:
114
- # Custom CSS for UI enhancement
115
  custom_css = """
116
- .gradio-container {
117
- background-color: #f4f6f9;
118
- border-radius: 15px;
119
- box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1);
120
- padding: 30px;
121
- font-family: 'Segoe UI', sans-serif;
122
- }
123
- .title {
124
- font-size: 30px;
125
- text-align: center;
126
- color: #4C6A92;
127
- margin-bottom: 20px;
128
- }
129
- .gradio-button {
130
- background-color: #3B82F6;
131
- color: white;
132
- border-radius: 10px;
133
- padding: 15px 30px;
134
- font-size: 16px;
135
- transition: background-color 0.3s;
136
- }
137
- .gradio-button:hover {
138
- background-color: #2563EB;
139
- }
140
- .result-box {
141
- background-color: #ffffff;
142
- border-radius: 10px;
143
- padding: 20px;
144
- box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
145
- margin-top: 20px;
146
- max-width: 100%;
147
- }
148
- .result-list {
149
- padding-left: 20px;
150
- margin: 10px 0;
151
- }
152
- .result-summary {
153
- font-size: 18px;
154
- color: #2F4F4F;
155
- font-weight: 500;
156
- }
157
- .feedback-box {
158
- background-color: #F0FFF4;
159
- padding: 10px;
160
- border-left: 4px solid #38A169;
161
- border-radius: 5px;
162
- margin-top: 10px;
163
- }
164
  """
165
 
166
- # Title and description
167
  gr.Markdown("<h1 class='title'>RadiologyScan AI</h1>")
168
- gr.Markdown("<p style='text-align: center; color: #666;'>Advanced X-ray and patient report analysis powered by AI</p>")
169
 
170
- # Upload section with layout
171
  with gr.Row():
172
  with gr.Column(scale=1):
173
- xray_input = gr.Image(label="Upload Chest X-ray", type="pil", elem_id="xray-input")
174
  with gr.Column(scale=1):
175
  report_input = gr.File(label="Upload Patient Report (PDF)", file_count="single", elem_id="report-input")
176
 
177
- # Buttons for analysis
178
  with gr.Row():
179
  predict_button = gr.Button("Analyze X-ray", elem_classes="gradio-button")
180
  report_button = gr.Button("Analyze Report", elem_classes="gradio-button")
181
 
182
- # Results section
183
  with gr.Column():
184
  xray_output = gr.HTML(label="X-ray Diagnosis Summary", elem_classes="result-box")
185
  xray_result = gr.HTML(label="Detailed X-ray Results", elem_classes="result-box")
186
  additional_feedback = gr.HTML(label="Additional Feedback", elem_classes="result-box feedback-box")
187
  report_output = gr.Textbox(label="Report Summary", interactive=False, elem_classes="result-box")
188
 
189
- # Event handlers
190
  predict_button.click(
191
  fn=predict_xray,
192
  inputs=xray_input,
@@ -195,11 +166,10 @@ def create_interface():
195
  report_button.click(
196
  fn=analyze_report,
197
  inputs=report_input,
198
- outputs=report_output
199
- )
200
 
201
  return demo
202
 
203
- # Launch the Gradio interface
204
  demo = create_interface()
205
  demo.launch(share=True)
 
4
  from torchvision import models, transforms
5
  import PyPDF2
6
  import logging
7
+ import os
8
 
9
  # Set up logging
10
  logging.basicConfig(level=logging.DEBUG)
11
  logger = logging.getLogger(__name__)
12
 
13
+ # Define conditions based on provided radiology information
14
+ conditions = [
15
+ "Normal", "Pneumonia", "Cancer", "TB", "Other",
16
+ "Coronary Artery Disease", "Aortic Aneurysm", "Stroke", "Peripheral Artery Disease",
17
+ "Brain Tumor", "Alzheimer's Disease", "Multiple Sclerosis", "Epilepsy",
18
+ "COPD", "Lung Cancer", "Pulmonary Embolism",
19
+ "Fractures", "Arthritis", "Osteoporosis",
20
+ "Appendicitis", "Gallstones", "Kidney Stones", "Infections", "Abdominal Aortic Aneurysm", "Diverticulitis"
21
+ ]
22
+
23
+ # Load and configure the model
24
+ model = models.densenet121(pretrained=False) # Start without pre-trained weights for custom training
25
  num_features = model.classifier.in_features
26
+ model.classifier = torch.nn.Linear(num_features, len(conditions)) # Output for all 16 conditions
27
  model.eval()
28
 
29
+ # Define device
30
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
  model = model.to(device)
32
 
33
+ # Load model state if available, otherwise initialize
34
+ model_path = "xray_model.pth"
35
+ if os.path.exists(model_path):
36
+ model.load_state_dict(torch.load(model_path))
37
+ logger.info(f"Loaded model from {model_path}")
38
+ else:
39
+ logger.info("No pre-trained model found. Initializing with random weights. Training required.")
40
+
41
+ # Define image preprocessing function
42
  def preprocess_image(image):
43
  transform = transforms.Compose([
44
  transforms.Resize((224, 224)),
45
  transforms.ToTensor(),
46
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
47
  ])
48
  image_tensor = transform(image).unsqueeze(0).to(device)
49
  logger.debug(f"Preprocessed image tensor shape: {image_tensor.shape}")
50
  return image_tensor
51
 
52
+ # Define prediction function with detailed output
53
  def predict_xray(image):
54
  try:
55
  if image is None:
 
58
  image_tensor = preprocess_image(image)
59
  with torch.no_grad():
60
  outputs = model(image_tensor)
61
+ probs = torch.nn.functional.softmax(outputs, dim=1)[0] # Softmax over all conditions
62
+ results = {conditions[i]: float(probs[i].cpu().numpy()) * 100 for i in range(len(conditions))}
63
 
 
 
 
 
 
64
  most_likely_condition = max(results, key=results.get)
65
  confidence = results[most_likely_condition]
66
 
 
67
  summary = f"**Summary**: Based on the X-ray analysis, the most likely diagnosis is: <b>{most_likely_condition}</b> with a confidence of <b>{confidence:.2f}%</b>."
68
 
69
+ # Enhanced condition details
70
  condition_details = {
71
+ "Normal": {"description": "No abnormal signs detected.", "recommendation": "Routine check-ups recommended."},
72
+ "Pneumonia": {"description": "Lung inflammation detected, possibly infectious.", "recommendation": "Seek medical attention for treatment."},
73
+ "Cancer": {"description": "Suspicious masses suggest cancer; further imaging needed.", "recommendation": "Consult an oncologist."},
74
+ "TB": {"description": "Cavitary lesions indicate tuberculosis.", "recommendation": "Immediate medical evaluation required."},
75
+ "Other": {"description": "Unclear abnormality; further investigation needed.", "recommendation": "Consult a radiologist."},
76
+ "Coronary Artery Disease": {"description": "Blockages in heart arteries detected.", "recommendation": "Cardiology consultation advised."},
77
+ "Aortic Aneurysm": {"description": "Aortic dilation observed.", "recommendation": "Monitor with imaging; surgical consult if large."},
78
+ "Stroke": {"description": "Brain damage from stroke detected.", "recommendation": "Urgent neurological care needed."},
79
+ "Peripheral Artery Disease": {"description": "Reduced limb blood flow observed.", "recommendation": "Vascular specialist consultation."},
80
+ "Brain Tumor": {"description": "Abnormal growth in brain detected.", "recommendation": "Neurological evaluation required."},
81
+ "Alzheimer's Disease": {"description": "Brain atrophy suggestive of Alzheimer's.", "recommendation": "Neurological assessment."},
82
+ "Multiple Sclerosis": {"description": "Lesions in brain/spinal cord detected.", "recommendation": "Consult neurologist."},
83
+ "Epilepsy": {"description": "Seizure source possibly identified.", "recommendation": "Neurological workup needed."},
84
+ "COPD": {"description": "Lung damage from COPD observed.", "recommendation": "Pulmonary consultation."},
85
+ "Lung Cancer": {"description": "Nodules suggest lung cancer.", "recommendation": "Oncologist referral."},
86
+ "Pulmonary Embolism": {"description": "Blood clot in lungs detected.", "recommendation": "Emergency care required."},
87
+ "Fractures": {"description": "Bone break detected.", "recommendation": "Orthopedic evaluation."},
88
+ "Arthritis": {"description": "Joint damage observed.", "recommendation": "Rheumatology consult."},
89
+ "Osteoporosis": {"description": "Reduced bone density detected.", "recommendation": "Bone health assessment."},
90
+ "Appendicitis": {"description": "Inflammation of appendix observed.", "recommendation": "Surgical evaluation."},
91
+ "Gallstones": {"description": "Stones in gallbladder detected.", "recommendation": "Gastroenterology consult."},
92
+ "Kidney Stones": {"description": "Stones in kidneys observed.", "recommendation": "Urology evaluation."},
93
+ "Infections": {"description": "Signs of infection detected.", "recommendation": "Infectious disease consult."},
94
+ "Abdominal Aortic Aneurysm": {"description": "Abdominal aortic dilation observed.", "recommendation": "Vascular surgery consult."},
95
+ "Diverticulitis": {"description": "Digestive tract inflammation detected.", "recommendation": "Gastroenterology evaluation."}
96
  }
97
 
 
98
  detailed_results = "<ul class='result-list'>"
99
  for condition, prob in results.items():
100
  detailed_results += f"<li><b>{condition}:</b> {prob:.2f}%</li>"
101
  detailed_results += "</ul>"
102
 
103
+ additional_feedback = f"<div class='feedback-box'><b>Description:</b> {condition_details[most_likely_condition]['description']}<br><b>Recommendation:</b> {condition_details[most_likely_condition]['recommendation']}</div>"
 
104
 
105
  logger.info(f"Prediction: {most_likely_condition} with confidence {confidence:.2f}%")
106
  return summary, detailed_results, additional_feedback
 
109
  logger.error(f"Error in predict_xray: {str(e)}")
110
  return f"Error: {str(e)}", "", ""
111
 
112
+ # Define function to read and analyze patient reports (PDFs)
113
  def analyze_report(file):
114
  text = ""
115
  if file and file.name.endswith(".pdf"):
 
128
  # Gradio Interface with enhanced UI
129
  def create_interface():
130
  with gr.Blocks() as demo:
 
131
  custom_css = """
132
+ .gradio-container { background-color: #f4f6f9; border-radius: 15px; box-shadow: 0 4px 15px rgba(0,0,0,0.1); padding: 30px; font-family: 'Segoe UI', sans-serif; }
133
+ .title { font-size: 30px; text-align: center; color: #4C6A92; margin-bottom: 20px; }
134
+ .gradio-button { background-color: #3B82F6; color: white; border-radius: 10px; padding: 15px 30px; font-size: 16px; transition: background-color 0.3s; }
135
+ .gradio-button:hover { background-color: #2563EB; }
136
+ .result-box { background-color: #ffffff; border-radius: 10px; padding: 20px; box-shadow: 0 4px 8px rgba(0,0,0,0.1); margin-top: 20px; max-width: 100%; }
137
+ .result-list { padding-left: 20px; margin: 10px 0; }
138
+ .result-summary { font-size: 18px; color: #2F4F4F; font-weight: 500; }
139
+ .feedback-box { background-color: #F0FFF4; padding: 10px; border-left: 4px solid #38A169; border-radius: 5px; margin-top: 10px; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  """
141
 
 
142
  gr.Markdown("<h1 class='title'>RadiologyScan AI</h1>")
143
+ gr.Markdown("<p style='text-align: center; color: #666;'>AI-powered analysis for X-rays and patient reports</p>")
144
 
 
145
  with gr.Row():
146
  with gr.Column(scale=1):
147
+ xray_input = gr.Image(label="Upload X-ray", type="pil", elem_id="xray-input")
148
  with gr.Column(scale=1):
149
  report_input = gr.File(label="Upload Patient Report (PDF)", file_count="single", elem_id="report-input")
150
 
 
151
  with gr.Row():
152
  predict_button = gr.Button("Analyze X-ray", elem_classes="gradio-button")
153
  report_button = gr.Button("Analyze Report", elem_classes="gradio-button")
154
 
 
155
  with gr.Column():
156
  xray_output = gr.HTML(label="X-ray Diagnosis Summary", elem_classes="result-box")
157
  xray_result = gr.HTML(label="Detailed X-ray Results", elem_classes="result-box")
158
  additional_feedback = gr.HTML(label="Additional Feedback", elem_classes="result-box feedback-box")
159
  report_output = gr.Textbox(label="Report Summary", interactive=False, elem_classes="result-box")
160
 
 
161
  predict_button.click(
162
  fn=predict_xray,
163
  inputs=xray_input,
 
166
  report_button.click(
167
  fn=analyze_report,
168
  inputs=report_input,
169
+ outputs=report_output)
 
170
 
171
  return demo
172
 
173
+ # Launch the interface and save model after training (to be implemented by user)
174
  demo = create_interface()
175
  demo.launch(share=True)