VaneshDev commited on
Commit
2644f95
·
verified ·
1 Parent(s): b606e91

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -29
app.py CHANGED
@@ -49,28 +49,13 @@ condition_details = {
49
  "Diverticulitis": {"description": "Inflammation of diverticula in the colon.", "recommendation": "Gastroenterology consultation."}
50
  }
51
 
52
- # Load model
53
- try:
54
- model = models.densenet121(weights="IMAGENET1K_V1")
55
- except AttributeError:
56
- model = models.densenet121(pretrained=True)
57
-
58
- model.classifier = torch.nn.Linear(model.classifier.in_features, len(conditions))
59
  model.eval()
60
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
61
  model.to(device)
62
 
63
- # Load model weights if available
64
- model_path = os.getenv("MODEL_PATH", "xray_model.pth")
65
- if os.path.exists(model_path):
66
- try:
67
- model.load_state_dict(torch.load(model_path, map_location=device))
68
- logger.info("Model loaded from file.")
69
- except Exception as e:
70
- logger.warning(f"Failed to load model weights: {e}")
71
- else:
72
- logger.info("No custom model weights found.")
73
-
74
  # Image preprocessing
75
  def preprocess_image(image):
76
  transform = transforms.Compose([
@@ -80,7 +65,7 @@ def preprocess_image(image):
80
  ])
81
  return transform(image).unsqueeze(0).to(device)
82
 
83
- # X-ray prediction function with confidence threshold
84
  def predict_xray(image):
85
  try:
86
  if image is None:
@@ -96,23 +81,24 @@ def predict_xray(image):
96
  top_condition = max(results, key=results.get)
97
  confidence = results[top_condition]
98
 
 
99
  if confidence < 50:
100
  return f"""
101
  <div style="font-family:Arial">
102
- <h3>Prediction: <span style="color:#D62828;">Uncertain</span></h3>
103
- <p><b>Confidence:</b> {confidence:.2f}%</p>
104
- <p><b>Note:</b> The model is not confident enough to provide a clear diagnosis.</p>
105
- <p><b>Recommendation:</b> Please consult a radiologist or upload a better-quality image.</p>
106
  </div>
107
  """
108
 
109
  info = condition_details.get(top_condition, condition_details["Other"])
110
  return f"""
111
  <div style="font-family:Arial">
112
- <h3>Prediction: <span style="color:#2A9D8F;">{top_condition}</span></h3>
113
- <p><b>Confidence:</b> {confidence:.2f}%</p>
114
- <p><b>Description:</b> {info['description']}</p>
115
- <p><b>Recommendation:</b> {info['recommendation']}</p>
116
  </div>
117
  """
118
 
@@ -153,12 +139,12 @@ def create_interface():
153
  with gr.TabItem("X-ray Analysis"):
154
  img_input = gr.Image(label="Upload Chest X-ray", type="pil")
155
  summary_output = gr.HTML(label="Summary Result")
156
- gr.Button("Analyze X-ray", elem_id="analyze_button", scale=0.5).click(predict_xray, inputs=img_input, outputs=summary_output)
157
 
158
  with gr.TabItem("Report Analysis"):
159
  pdf_input = gr.File(label="Upload PDF Report", file_types=[".pdf"])
160
  summary_output_report = gr.Textbox(label="Summary Result", lines=5)
161
- gr.Button("Analyze Report", elem_id="analyze_button", scale=0.5).click(analyze_report, inputs=pdf_input, outputs=summary_output_report)
162
 
163
  return demo
164
 
 
49
  "Diverticulitis": {"description": "Inflammation of diverticula in the colon.", "recommendation": "Gastroenterology consultation."}
50
  }
51
 
52
+ # Load model (using a smaller model like MobileNetV2 for faster inference)
53
+ model = models.mobilenet_v2(pretrained=True)
54
+ model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, len(conditions)) # Adjust the classifier for our condition count
 
 
 
 
55
  model.eval()
56
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
57
  model.to(device)
58
 
 
 
 
 
 
 
 
 
 
 
 
59
  # Image preprocessing
60
  def preprocess_image(image):
61
  transform = transforms.Compose([
 
65
  ])
66
  return transform(image).unsqueeze(0).to(device)
67
 
68
+ # X-ray prediction function with summary output
69
  def predict_xray(image):
70
  try:
71
  if image is None:
 
81
  top_condition = max(results, key=results.get)
82
  confidence = results[top_condition]
83
 
84
+ # Construct a summary based on prediction
85
  if confidence < 50:
86
  return f"""
87
  <div style="font-family:Arial">
88
+ <h3>Summary</h3>
89
+ <p><b>Disease Identified:</b> Uncertain</p>
90
+ <p><b>Cause/Status:</b> The model is not confident enough to provide a clear diagnosis.</p>
91
+ <p><b>Treatment/Recommendation:</b> Please consult a radiologist or upload a better-quality image for better accuracy.</p>
92
  </div>
93
  """
94
 
95
  info = condition_details.get(top_condition, condition_details["Other"])
96
  return f"""
97
  <div style="font-family:Arial">
98
+ <h3>Summary</h3>
99
+ <p><b>Disease Identified:</b> {top_condition}</p>
100
+ <p><b>Cause/Status:</b> {info['description']}</p>
101
+ <p><b>Treatment/Recommendation:</b> {info['recommendation']}</p>
102
  </div>
103
  """
104
 
 
139
  with gr.TabItem("X-ray Analysis"):
140
  img_input = gr.Image(label="Upload Chest X-ray", type="pil")
141
  summary_output = gr.HTML(label="Summary Result")
142
+ gr.Button("Analyze X-ray", elem_id="analyze_button", scale=0.3).click(predict_xray, inputs=img_input, outputs=summary_output)
143
 
144
  with gr.TabItem("Report Analysis"):
145
  pdf_input = gr.File(label="Upload PDF Report", file_types=[".pdf"])
146
  summary_output_report = gr.Textbox(label="Summary Result", lines=5)
147
+ gr.Button("Analyze Report", elem_id="analyze_button", scale=0.3).click(analyze_report, inputs=pdf_input, outputs=summary_output_report)
148
 
149
  return demo
150