kmunzwa commited on
Commit
92b3bf5
·
verified ·
1 Parent(s): b744d5d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -28
app.py CHANGED
@@ -20,48 +20,57 @@ from PIL import Image
20
  # ------------------------------------
21
 
22
  # we detect whether a GPU is available and fall back to CPU if not
23
- # hugging face free tier runs on CPU so this will almost always be cpu
24
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
  print(f"Running on: {device}")
26
 
27
  # we recreate the ResNet50 architecture
28
- # weights=None because we will load our own trained weights below
29
  model = models.resnet50(weights=None)
30
 
31
- # the original ResNet50 outputs 1000 classes (ImageNet)
32
- # we replace the final fully connected layer to output 2 classes:
33
  # class 0 = Non-Cervix, class 1 = Cervix
34
  model.fc = torch.nn.Linear(model.fc.in_features, 2)
35
 
36
- # we load the saved weights from the .pth file
37
- # map_location=device ensures it loads correctly even without a GPU
38
  state_dict = torch.load("best_gatekeeper_v2.pth", map_location=device)
39
  model.load_state_dict(state_dict)
40
 
41
- # we move the model to the correct device (CPU or GPU)
42
  model = model.to(device)
43
-
44
- # we set the model to evaluation mode
45
- # this disables dropout and batch normalisation training behaviour
46
  model.eval()
47
 
48
  print("Gatekeeper model loaded successfully")
49
 
50
- # this is the image size ResNet50 expects
51
  INPUT_SIZE = 224
52
 
53
- # these are the standard ImageNet normalisation values
54
- # ResNet50 was pretrained on ImageNet so we use the same values
55
  IMAGENET_MEAN = [0.485, 0.456, 0.406]
56
  IMAGENET_STD = [0.229, 0.224, 0.225]
57
 
58
- # we define the preprocessing pipeline using torchvision transforms
59
  preprocess = transforms.Compose([
60
  transforms.Resize((INPUT_SIZE, INPUT_SIZE)),
61
- transforms.ToTensor(), # converts [0,255] → [0,1]
62
  transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
63
  ])
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  # ------------------------------------
67
  # CLASSIFICATION FUNCTION
@@ -72,32 +81,43 @@ def classify_image(image):
72
  if image is None:
73
  return None, "Please upload an image first"
74
 
75
- # convert the numpy array from gradio to a PIL Image in RGB format
76
- img = Image.fromarray(image).convert("RGB")
 
77
 
78
- # apply the preprocessing pipeline and add a batch dimension
79
- # unsqueeze(0) changes shape from (3, 224, 224) to (1, 3, 224, 224)
 
 
 
 
 
 
 
 
80
  tensor = preprocess(img).unsqueeze(0).to(device)
81
 
82
- # run inference without computing gradients (saves memory and is faster)
83
  with torch.no_grad():
84
- output = model(tensor) # raw logits shape: (1, 2)
85
- probs = torch.softmax(output, dim=1)[0] # convert to probabilities
86
 
87
- # extract individual class probabilities as plain Python floats
88
  prob_non_cervix = float(probs[0])
89
  prob_cervix = float(probs[1])
90
 
91
  print(f"Non-Cervix: {prob_non_cervix:.4f} | Cervix: {prob_cervix:.4f}")
92
 
93
- # determine the final prediction label
94
- # cervix must score at least 0.55 to be accepted as a positive detection
95
- if prob_cervix >= 0.55:
 
96
  prediction_text = "Cervix Detected"
97
- else:
98
  prediction_text = "Non-Cervix"
 
 
 
99
 
100
- # build a dictionary for gradio's Label component (displays as bar chart)
101
  scores = {
102
  "Cervix": round(prob_cervix, 4),
103
  "Non-Cervix": round(prob_non_cervix, 4),
@@ -155,6 +175,12 @@ with gr.Blocks(theme=gr.themes.Soft()) as app:
155
  | 0 | Non-Cervix | Image does NOT contain cervix |
156
  | 1 | Cervix | Image contains cervix |
157
 
 
 
 
 
 
 
158
  ---
159
  Disclaimer: This tool is for research purposes only.
160
  It is not intended for clinical diagnosis or medical use.
 
20
  # ------------------------------------
21
 
22
  # we detect whether a GPU is available and fall back to CPU if not
 
23
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
  print(f"Running on: {device}")
25
 
26
  # we recreate the ResNet50 architecture
 
27
  model = models.resnet50(weights=None)
28
 
29
+ # replace the final fully connected layer to output 2 classes:
 
30
  # class 0 = Non-Cervix, class 1 = Cervix
31
  model.fc = torch.nn.Linear(model.fc.in_features, 2)
32
 
33
+ # load the saved weights from the .pth file
 
34
  state_dict = torch.load("best_gatekeeper_v2.pth", map_location=device)
35
  model.load_state_dict(state_dict)
36
 
37
+ # move the model to the correct device and set to evaluation mode
38
  model = model.to(device)
 
 
 
39
  model.eval()
40
 
41
  print("Gatekeeper model loaded successfully")
42
 
43
+ # image size ResNet50 expects
44
  INPUT_SIZE = 224
45
 
46
+ # standard ImageNet normalisation values used during pretraining
 
47
  IMAGENET_MEAN = [0.485, 0.456, 0.406]
48
  IMAGENET_STD = [0.229, 0.224, 0.225]
49
 
50
+ # preprocessing pipeline
51
  preprocess = transforms.Compose([
52
  transforms.Resize((INPUT_SIZE, INPUT_SIZE)),
53
+ transforms.ToTensor(),
54
  transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
55
  ])
56
 
57
+ # ------------------------------------
58
+ # THRESHOLDS
59
+ # ------------------------------------
60
+
61
+ # minimum probability for cervix to be accepted as a positive detection
62
+ CERVIX_THRESHOLD = 0.55
63
+
64
+ # minimum gap between cervix and non-cervix probabilities
65
+ # if the gap is smaller than this the prediction is too uncertain to trust
66
+ CONFIDENCE_GAP = 0.15
67
+
68
+ # minimum image brightness - images below this are too dark to classify
69
+ MIN_BRIGHTNESS = 30
70
+
71
+ # minimum image contrast - images below this are blank or uniform
72
+ MIN_STD = 20
73
+
74
 
75
  # ------------------------------------
76
  # CLASSIFICATION FUNCTION
 
81
  if image is None:
82
  return None, "Please upload an image first"
83
 
84
+ # Option 3: Basic image sanity checks
85
+ # run these before the model to catch obviously bad images early
86
+ img_array = np.array(image)
87
 
88
+ # reject images that are too dark to analyse reliably
89
+ if img_array.mean() < MIN_BRIGHTNESS:
90
+ return None, "Image is too dark - please upload a clearer photo"
91
+
92
+ # reject images that are blank, uniformly coloured, or plain screenshots
93
+ if img_array.std() < MIN_STD:
94
+ return None, "Image appears blank or uniform - please upload a real photo"
95
+
96
+ # Preprocess
97
+ img = Image.fromarray(image).convert("RGB")
98
  tensor = preprocess(img).unsqueeze(0).to(device)
99
 
100
+ # Run inference
101
  with torch.no_grad():
102
+ output = model(tensor)
103
+ probs = torch.softmax(output, dim=1)[0]
104
 
 
105
  prob_non_cervix = float(probs[0])
106
  prob_cervix = float(probs[1])
107
 
108
  print(f"Non-Cervix: {prob_non_cervix:.4f} | Cervix: {prob_cervix:.4f}")
109
 
110
+ # Option 1: Confidence threshold + gap check
111
+ gap = prob_cervix - prob_non_cervix
112
+
113
+ if prob_cervix >= CERVIX_THRESHOLD and gap >= CONFIDENCE_GAP:
114
  prediction_text = "Cervix Detected"
115
+ elif prob_non_cervix >= CERVIX_THRESHOLD and gap <= -CONFIDENCE_GAP:
116
  prediction_text = "Non-Cervix"
117
+ else:
118
+ # not confident enough either way - temporary misclassification safety net
119
+ prediction_text = "Uncertain - please retake or upload a clearer image"
120
 
 
121
  scores = {
122
  "Cervix": round(prob_cervix, 4),
123
  "Non-Cervix": round(prob_non_cervix, 4),
 
175
  | 0 | Non-Cervix | Image does NOT contain cervix |
176
  | 1 | Cervix | Image contains cervix |
177
 
178
+ ---
179
+ **How predictions work:**
180
+ - **Cervix Detected** - model scored >= 0.55 with a gap of >= 0.15 over Non-Cervix
181
+ - **Non-Cervix** - model scored >= 0.55 with a gap of >= 0.15 over Cervix
182
+ - **Uncertain** - model was not confident enough; retake the image
183
+
184
  ---
185
  Disclaimer: This tool is for research purposes only.
186
  It is not intended for clinical diagnosis or medical use.