kmunzwa commited on
Commit
c5ad6cb
·
verified ·
1 Parent(s): 26f6ba0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -44
app.py CHANGED
@@ -4,12 +4,8 @@ import gradio as gr
4
  # numpy is used for numerical operations
5
  import numpy as np
6
 
7
- # torch is the core PyTorch library used to run the model
8
- import torch
9
-
10
- # torchvision provides the ResNet50 architecture and image transforms
11
- import torchvision.transforms as transforms
12
- from torchvision import models
13
 
14
  # PIL is used for image loading and conversion
15
  from PIL import Image
@@ -19,40 +15,21 @@ from PIL import Image
19
  # LOAD THE MODEL
20
  # ------------------------------------
21
 
22
- # we detect whether a GPU is available and fall back to CPU if not
23
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
- print(f"Running on: {device}")
25
-
26
- # we recreate the ResNet50 architecture
27
- model = models.resnet50(weights=None)
28
-
29
- # replace the final fully connected layer to output 2 classes:
30
- # class 0 = Non-Cervix, class 1 = Cervix
31
- model.fc = torch.nn.Linear(model.fc.in_features, 2)
32
-
33
- # load the saved weights from the .pth file
34
- state_dict = torch.load("best_gatekeeper_v2.pth", map_location=device)
35
- model.load_state_dict(state_dict)
36
 
37
- # move the model to the correct device and set to evaluation mode
38
- model = model.to(device)
39
- model.eval()
40
 
41
- print("Gatekeeper model loaded successfully")
 
 
42
 
43
  # image size ResNet50 expects
44
- INPUT_SIZE = 224
45
 
46
- # standard ImageNet normalisation values used during pretraining
47
- IMAGENET_MEAN = [0.485, 0.456, 0.406]
48
- IMAGENET_STD = [0.229, 0.224, 0.225]
49
 
50
- # preprocessing pipeline
51
- preprocess = transforms.Compose([
52
- transforms.Resize((INPUT_SIZE, INPUT_SIZE)),
53
- transforms.ToTensor(),
54
- transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
55
- ])
56
 
57
  # ------------------------------------
58
  # THRESHOLDS
@@ -72,6 +49,23 @@ MIN_BRIGHTNESS = 30
72
  MIN_STD = 20
73
 
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  # ------------------------------------
76
  # CLASSIFICATION FUNCTION
77
  # ------------------------------------
@@ -82,7 +76,6 @@ def classify_image(image):
82
  return None, "Please upload an image first"
83
 
84
  # Option 3: Basic image sanity checks
85
- # run these before the model to catch obviously bad images early
86
  img_array = np.array(image)
87
 
88
  # reject images that are too dark to analyse reliably
@@ -93,17 +86,22 @@ def classify_image(image):
93
  if img_array.std() < MIN_STD:
94
  return None, "Image appears blank or uniform - please upload a real photo"
95
 
96
- # Preprocess
97
- img = Image.fromarray(image).convert("RGB")
98
- tensor = preprocess(img).unsqueeze(0).to(device)
 
 
 
 
 
99
 
100
- # Run inference
101
- with torch.no_grad():
102
- output = model(tensor)
103
- probs = torch.softmax(output, dim=1)[0]
104
 
105
- prob_non_cervix = float(probs[0])
106
- prob_cervix = float(probs[1])
 
107
 
108
  print(f"Non-Cervix: {prob_non_cervix:.4f} | Cervix: {prob_cervix:.4f}")
109
 
 
4
  # numpy is used for numerical operations
5
  import numpy as np
6
 
7
+ # ai_edge_litert is Google's official TFLite runtime
8
+ from ai_edge_litert.interpreter import Interpreter
 
 
 
 
9
 
10
  # PIL is used for image loading and conversion
11
  from PIL import Image
 
15
  # LOAD THE MODEL
16
  # ------------------------------------
17
 
18
+ # load the float32 TFLite model
19
+ interpreter = Interpreter(model_path="resnet50_float32.tflite")
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
+ # allocate memory for the model's input and output tensors
22
+ interpreter.allocate_tensors()
 
23
 
24
+ # get input and output tensor details
25
+ input_details = interpreter.get_input_details()
26
+ output_details = interpreter.get_output_details()
27
 
28
  # image size ResNet50 expects
29
+ INPUT_SIZE = (224, 224)
30
 
31
+ print("Gatekeeper model loaded successfully")
 
 
32
 
 
 
 
 
 
 
33
 
34
  # ------------------------------------
35
  # THRESHOLDS
 
49
  MIN_STD = 20
50
 
51
 
52
+ # ------------------------------------
53
+ # IMAGE PREPROCESSING FUNCTION
54
+ # ------------------------------------
55
+
56
+ def preprocess_image(image):
57
+ # convert numpy array to PIL Image in RGB format and resize
58
+ img = Image.fromarray(image).convert("RGB").resize(INPUT_SIZE)
59
+
60
+ # convert to float32 numpy array and normalise to [0, 1]
61
+ img = np.array(img, dtype=np.float32) / 255.0
62
+
63
+ # add batch dimension: (224, 224, 3) → (1, 224, 224, 3)
64
+ img = np.expand_dims(img, axis=0)
65
+
66
+ return img
67
+
68
+
69
  # ------------------------------------
70
  # CLASSIFICATION FUNCTION
71
  # ------------------------------------
 
76
  return None, "Please upload an image first"
77
 
78
  # Option 3: Basic image sanity checks
 
79
  img_array = np.array(image)
80
 
81
  # reject images that are too dark to analyse reliably
 
86
  if img_array.std() < MIN_STD:
87
  return None, "Image appears blank or uniform - please upload a real photo"
88
 
89
+ # preprocess the image
90
+ processed = preprocess_image(image)
91
+
92
+ # load the preprocessed image into the model's input tensor
93
+ interpreter.set_tensor(input_details[0]['index'], processed)
94
+
95
+ # run inference
96
+ interpreter.invoke()
97
 
98
+ # read the output tensor
99
+ output = interpreter.get_tensor(output_details[0]['index'])
100
+ print(f"Raw model output: {output}")
 
101
 
102
+ # extract individual class probabilities
103
+ prob_non_cervix = float(output[0][0])
104
+ prob_cervix = float(output[0][1])
105
 
106
  print(f"Non-Cervix: {prob_non_cervix:.4f} | Cervix: {prob_cervix:.4f}")
107