Rodiyah commited on
Commit
d7b38d5
·
verified ·
1 Parent(s): ba282ce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -25
app.py CHANGED
@@ -11,16 +11,15 @@ from pytorch_grad_cam.utils.image import show_cam_on_image
11
  import os
12
  import datetime
13
 
14
- # -----------------------
15
  # Setup
16
- # -----------------------
17
  device = torch.device("cpu")
18
  save_dir = "/home/user/app/saved_predictions"
 
 
 
19
  os.makedirs(save_dir, exist_ok=True)
20
 
21
- # -----------------------
22
  # Load model
23
- # -----------------------
24
  model = models.resnet50(weights=None)
25
  model.fc = torch.nn.Linear(model.fc.in_features, 2)
26
  model.load_state_dict(torch.load("resnet50_dr_classifier.pth", map_location=device))
@@ -39,18 +38,8 @@ transform = transforms.Compose([
39
  [0.229, 0.224, 0.225])
40
  ])
41
 
42
- # -----------------------
43
  # Predict and save
44
- # -----------------------
45
  def predict_retinopathy(image):
46
- # Validate image first
47
- if not looks_like_fundus(image):
48
- return (
49
- invalid_img,
50
- "⚠️ The uploaded image does not appear to be a retinal fundus scan.\n\n"
51
- "Please upload a valid ophthalmic retinal image for Diabetic Retinopathy assessment."
52
- )
53
-
54
  timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
55
  img = image.convert("RGB").resize((224, 224))
56
  img_tensor = transform(img).unsqueeze(0).to(device)
@@ -66,10 +55,7 @@ def predict_retinopathy(image):
66
  # Grad-CAM
67
  rgb_img_np = np.array(img).astype(np.float32) / 255.0
68
  rgb_img_np = np.ascontiguousarray(rgb_img_np)
69
- grayscale_cam = cam(
70
- input_tensor=img_tensor,
71
- targets=[ClassifierOutputTarget(pred)]
72
- )[0]
73
  cam_image = show_cam_on_image(rgb_img_np, grayscale_cam, use_rgb=True)
74
  cam_pil = Image.fromarray(cam_image)
75
 
@@ -77,14 +63,9 @@ def predict_retinopathy(image):
77
  filename = f"{timestamp}_{label}_{confidence:.2f}.png"
78
  cam_pil.save(os.path.join(save_dir, filename))
79
 
80
- return (
81
- cam_pil,
82
- f"{label} (Confidence: {confidence:.2f})"
83
- )
84
 
85
- # -----------------------
86
  # Gradio app
87
- # -----------------------
88
  gr.Interface(
89
  fn=predict_retinopathy,
90
  inputs=gr.Image(type="pil", label="Upload Retinal Image"),
@@ -101,4 +82,4 @@ gr.Interface(
101
  "⚕️ **OpthaDetect** is an AI-powered ophthalmic decision-support tool. "
102
  "It highlights retinal risk regions using Grad-CAM for better clinical interpretability."
103
  )
104
- ).launch()
 
11
  import os
12
  import datetime
13
 
 
14
  # Setup
 
15
  device = torch.device("cpu")
16
  save_dir = "/home/user/app/saved_predictions"
17
+ if not os.path.exists(save_dir):
18
+ os.makedirs(save_dir)
19
+ print("📁 Folder created:", save_dir)
20
  os.makedirs(save_dir, exist_ok=True)
21
 
 
22
  # Load model
 
23
  model = models.resnet50(weights=None)
24
  model.fc = torch.nn.Linear(model.fc.in_features, 2)
25
  model.load_state_dict(torch.load("resnet50_dr_classifier.pth", map_location=device))
 
38
  [0.229, 0.224, 0.225])
39
  ])
40
 
 
41
  # Predict and save
 
42
  def predict_retinopathy(image):
 
 
 
 
 
 
 
 
43
  timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
44
  img = image.convert("RGB").resize((224, 224))
45
  img_tensor = transform(img).unsqueeze(0).to(device)
 
55
  # Grad-CAM
56
  rgb_img_np = np.array(img).astype(np.float32) / 255.0
57
  rgb_img_np = np.ascontiguousarray(rgb_img_np)
58
+ grayscale_cam = cam(input_tensor=img_tensor, targets=[ClassifierOutputTarget(pred)])[0]
 
 
 
59
  cam_image = show_cam_on_image(rgb_img_np, grayscale_cam, use_rgb=True)
60
  cam_pil = Image.fromarray(cam_image)
61
 
 
63
  filename = f"{timestamp}_{label}_{confidence:.2f}.png"
64
  cam_pil.save(os.path.join(save_dir, filename))
65
 
66
+ return cam_pil, f"{label} (Confidence: {confidence:.2f})"
 
 
 
67
 
 
68
  # Gradio app
 
69
  gr.Interface(
70
  fn=predict_retinopathy,
71
  inputs=gr.Image(type="pil", label="Upload Retinal Image"),
 
82
  "⚕️ **OpthaDetect** is an AI-powered ophthalmic decision-support tool. "
83
  "It highlights retinal risk regions using Grad-CAM for better clinical interpretability."
84
  )
85
+ ).launch()