VaneshDev commited on
Commit
91f8060
·
verified ·
1 Parent(s): 25239fa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -40
app.py CHANGED
@@ -2,7 +2,6 @@ import gradio as gr
2
  from PIL import Image
3
  import torch
4
  from torchvision import models, transforms
5
- import fitz # PyMuPDF
6
  import logging
7
  import os
8
 
@@ -10,7 +9,7 @@ import os
10
  logging.basicConfig(level=logging.INFO)
11
  logger = logging.getLogger(__name__)
12
 
13
- # Condition list
14
  conditions = [
15
  "Normal", "Pneumonia", "Cancer", "TB", "Other",
16
  "Coronary Artery Disease", "Aortic Aneurysm", "Stroke", "Peripheral Artery Disease",
@@ -20,7 +19,7 @@ conditions = [
20
  "Appendicitis", "Gallstones", "Kidney Stones", "Infections", "Abdominal Aortic Aneurysm", "Diverticulitis"
21
  ]
22
 
23
- # Condition details
24
  condition_details = {
25
  "Normal": {"description": "No abnormal signs detected.", "recommendation": "Routine check-ups recommended."},
26
  "Pneumonia": {"description": "Lung inflammation detected, possibly infectious.", "recommendation": "Seek medical attention for treatment."},
@@ -49,19 +48,19 @@ condition_details = {
49
  "Diverticulitis": {"description": "Inflammation of diverticula in the colon.", "recommendation": "Gastroenterology consultation."}
50
  }
51
 
52
- # Load model (using a specialized X-ray model or pre-trained general model)
53
- model = models.densenet121(pretrained=True) # You can swap to a more specific X-ray model if available
54
  model.classifier = torch.nn.Linear(model.classifier.in_features, len(conditions)) # Adjust the classifier for our condition count
55
  model.eval()
56
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
57
  model.to(device)
58
 
59
- # Image preprocessing
60
  def preprocess_image(image):
61
  transform = transforms.Compose([
62
  transforms.Resize((224, 224)), # Resize to match the model input size
63
  transforms.ToTensor(),
64
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # Standardize based on ImageNet values
65
  ])
66
  return transform(image).unsqueeze(0).to(device)
67
 
@@ -75,9 +74,11 @@ def predict_xray(image):
75
  with torch.no_grad():
76
  output = model(img_tensor)
77
 
 
78
  probs = torch.nn.functional.softmax(output, dim=1)[0]
79
  results = {conditions[i]: probs[i].item() * 100 for i in range(len(conditions))}
80
 
 
81
  top_condition = max(results, key=results.get)
82
  confidence = results[top_condition]
83
 
@@ -92,6 +93,7 @@ def predict_xray(image):
92
  </div>
93
  """
94
 
 
95
  info = condition_details.get(top_condition, condition_details["Other"])
96
  return f"""
97
  <div style="font-family:Arial">
@@ -106,34 +108,10 @@ def predict_xray(image):
106
  logger.error(f"Error in prediction: {e}")
107
  return f"Error: {str(e)}"
108
 
109
- # Analyze PDF report
110
- def analyze_report(file):
111
- if not file or not file.name.endswith(".pdf"):
112
- return "Please upload a valid PDF file."
113
- try:
114
- doc = fitz.open(file.name)
115
- text = "".join(page.get_text() for page in doc)
116
- doc.close()
117
-
118
- condition, disease, status = "Unclear", "Unknown", "Pending"
119
-
120
- if "stroke" in text.lower():
121
- condition, disease, status = "Stroke", "Brain Disorder", "Urgent Care Needed"
122
- elif "cancer" in text.lower():
123
- condition, disease, status = "Cancer", "Malignant Growth", "Consult Oncologist"
124
- elif "fracture" in text.lower():
125
- condition, disease, status = "Fracture", "Bone Injury", "Orthopedic Attention Required"
126
-
127
- preview = text[:300] + "..." if text else "No readable content."
128
- return f"Condition: {condition}\nDisease: {disease}\nStatus: {status}\n\nPreview:\n{preview}"
129
-
130
- except Exception as e:
131
- return f"Failed to process PDF: {str(e)}"
132
-
133
  # Gradio interface
134
  def create_interface():
135
  with gr.Blocks() as demo:
136
- gr.Markdown("<h1 style='text-align:center;'>🩻 RadiologyScan AI</h1><p style='text-align:center;'>AI-powered X-ray and Report Analysis</p>")
137
 
138
  with gr.Tabs():
139
  with gr.TabItem("X-ray Analysis"):
@@ -143,14 +121,7 @@ def create_interface():
143
  gr.Button("Analyze X-ray", elem_id="analyze_button", scale=0.3).click(predict_xray, inputs=img_input, outputs=summary_output)
144
  gr.Button("Clear", elem_id="clear_button", scale=0.3).click(lambda: [None, ""], inputs=None, outputs=[img_input, summary_output])
145
 
146
- with gr.TabItem("Report Analysis"):
147
- pdf_input = gr.File(label="Upload PDF Report", file_types=[".pdf"])
148
- summary_output_report = gr.Textbox(label="Summary Result", lines=5)
149
- with gr.Row():
150
- gr.Button("Analyze Report", elem_id="analyze_button", scale=0.3).click(analyze_report, inputs=pdf_input, outputs=summary_output_report)
151
- gr.Button("Clear", elem_id="clear_button", scale=0.3).click(lambda: [None, ""], inputs=None, outputs=[pdf_input, summary_output_report])
152
-
153
- return demo
154
 
155
  if __name__ == "__main__":
156
  demo = create_interface()
 
2
  from PIL import Image
3
  import torch
4
  from torchvision import models, transforms
 
5
  import logging
6
  import os
7
 
 
9
  logging.basicConfig(level=logging.INFO)
10
  logger = logging.getLogger(__name__)
11
 
12
+ # Condition list (add your specific diseases here)
13
  conditions = [
14
  "Normal", "Pneumonia", "Cancer", "TB", "Other",
15
  "Coronary Artery Disease", "Aortic Aneurysm", "Stroke", "Peripheral Artery Disease",
 
19
  "Appendicitis", "Gallstones", "Kidney Stones", "Infections", "Abdominal Aortic Aneurysm", "Diverticulitis"
20
  ]
21
 
22
+ # Condition details for diagnosis (can be expanded with real data)
23
  condition_details = {
24
  "Normal": {"description": "No abnormal signs detected.", "recommendation": "Routine check-ups recommended."},
25
  "Pneumonia": {"description": "Lung inflammation detected, possibly infectious.", "recommendation": "Seek medical attention for treatment."},
 
48
  "Diverticulitis": {"description": "Inflammation of diverticula in the colon.", "recommendation": "Gastroenterology consultation."}
49
  }
50
 
51
+ # Load a pre-trained DenseNet121 model
52
+ model = models.densenet121(pretrained=True)
53
  model.classifier = torch.nn.Linear(model.classifier.in_features, len(conditions)) # Adjust the classifier for our condition count
54
  model.eval()
55
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
56
  model.to(device)
57
 
58
+ # Image preprocessing function
59
  def preprocess_image(image):
60
  transform = transforms.Compose([
61
  transforms.Resize((224, 224)), # Resize to match the model input size
62
  transforms.ToTensor(),
63
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # Standard ImageNet normalization
64
  ])
65
  return transform(image).unsqueeze(0).to(device)
66
 
 
74
  with torch.no_grad():
75
  output = model(img_tensor)
76
 
77
+ # Get probabilities for each condition
78
  probs = torch.nn.functional.softmax(output, dim=1)[0]
79
  results = {conditions[i]: probs[i].item() * 100 for i in range(len(conditions))}
80
 
81
+ # Identify the condition with the highest probability
82
  top_condition = max(results, key=results.get)
83
  confidence = results[top_condition]
84
 
 
93
  </div>
94
  """
95
 
96
+ # Fetch details for the identified condition
97
  info = condition_details.get(top_condition, condition_details["Other"])
98
  return f"""
99
  <div style="font-family:Arial">
 
108
  logger.error(f"Error in prediction: {e}")
109
  return f"Error: {str(e)}"
110
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  # Gradio interface
112
  def create_interface():
113
  with gr.Blocks() as demo:
114
+ gr.Markdown("<h1 style='text-align:center;'>🩻 RadiologyScan AI</h1><p style='text-align:center;'>AI-powered X-ray Analysis</p>")
115
 
116
  with gr.Tabs():
117
  with gr.TabItem("X-ray Analysis"):
 
121
  gr.Button("Analyze X-ray", elem_id="analyze_button", scale=0.3).click(predict_xray, inputs=img_input, outputs=summary_output)
122
  gr.Button("Clear", elem_id="clear_button", scale=0.3).click(lambda: [None, ""], inputs=None, outputs=[img_input, summary_output])
123
 
124
+ return demo
 
 
 
 
 
 
 
125
 
126
  if __name__ == "__main__":
127
  demo = create_interface()