ChristopherMarais committed on
Commit 28bdf2c · verified · 1 Parent(s): e4ca14d

Update app.py

Files changed (1)
  1. app.py +273 -128
app.py CHANGED
@@ -2,9 +2,8 @@ import random
2
  import os
3
  import numpy as np
4
  import gradio as gr
5
- from huggingface_hub import from_pretrained_fastai
6
  from PIL import Image
7
- from groundingdino.util.inference import load_model
8
  from groundingdino.util.inference import predict as grounding_dino_predict
9
  import groundingdino.datasets.transforms as T
10
  import torch
@@ -13,9 +12,124 @@ from torchvision.transforms.functional import to_tensor
13
  from torchvision.transforms import GaussianBlur
14
  import time
15
 
16
- from Ambrosia import pre_process_image
17

18

19
 
20
  # Automatically set device based on availability
21
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@@ -23,8 +137,7 @@ print(f"Using device: {DEVICE}")
23
 
24
  PROMPT = "bug"
25
 
26
-
27
- # Define a custom transform for Gaussian blur
28
  def gaussian_blur(x, p=0.5, kernel_size_min=3, kernel_size_max=20, sigma_min=0.1, sigma_max=3):
29
  if x.ndim == 4:
30
  for i in range(x.shape[0]):
@@ -34,39 +147,13 @@ def gaussian_blur(x, p=0.5, kernel_size_min=3, kernel_size_max=20, sigma_min=0.1
34
  x[i] = GaussianBlur(kernel_size=kernel_size, sigma=sigma)(x[i])
35
  return x
36
 
37
- # Custom Label Function
38
  def custom_label_func(fpath):
39
  # this directs the labels to be 2 levels up from the image folder
40
  label = fpath.parents[2].name
41
  return label
42
 
43
- # this function only describes how much a single value in a list stands out.
44
- # if all values in the list are high or low, this is 1
45
- # the smaller the proportion of dissimilar values relative to the more similar values, the lower this number
46
- # the larger the gap between the dissimilar numbers and the similar numbers, the smaller this number
47
- # only able to interpret probabilities or values between 0 and 1
48
- # this function outputs an estimate of the inverse of the classification confidence based on the probabilities of all the classes.
49
- # the wedge threshold splits the data on a threshold with a magnitude of a positive int to force a ledge/peak in the data
50
- def unkown_prob_calc(probs, wedge_threshold, wedge_magnitude=1, wedge='strict'):
51
- if wedge =='strict':
52
- increase_var = (1/(wedge_magnitude))
53
- decrease_var = (wedge_magnitude)
54
- if wedge =='dynamic': # this allows points that are further from the threshold to be moved less and points closer to be moved more
55
- increase_var = (1/(wedge_magnitude*((1-np.abs(probs-wedge_threshold)))))
56
- decrease_var = (wedge_magnitude*((1-np.abs(probs-wedge_threshold))))
57
- else:
58
- print("Error: use 'strict' (default) or 'dynamic' as options for the wedge parameter!")
59
- probs = np.where(probs>=wedge_threshold , probs**increase_var, probs)
60
- probs = np.where(probs<=wedge_threshold , probs**decrease_var, probs)
61
- diff_matrix = np.abs(probs[:, np.newaxis] - probs)
62
- diff_matrix_sum = np.sum(diff_matrix)
63
- probs_sum = np.sum(probs)
64
- class_val = (diff_matrix_sum/probs_sum)
65
- max_class_val = ((len(probs)-1)*2)
66
- kown_prob = class_val/max_class_val
67
- unknown_prob = 1-kown_prob
68
- return(unknown_prob)
69
-
70
  def load_image(image_source):
71
  transform = T.Compose(
72
  [
@@ -80,17 +167,47 @@ def load_image(image_source):
80
  image_transformed, _ = transform(image_source, None)
81
  return image_transformed
82
 
83
- # load object detection model
84
- od_model = load_model(
85
  model_checkpoint_path="groundingdino_swint_ogc.pth",
86
  model_config_path="GroundingDINO_SwinT_OGC.cfg.py",
87
  device=DEVICE)
88
  print("Object detection model loaded")
89

90
  def detect_objects(og_image, model=od_model, prompt="bug . insect", device="cpu"):
91
  TEXT_PROMPT = prompt
92
- BOX_TRESHOLD = 0.35
93
- TEXT_TRESHOLD = 0.25
94
  DEVICE = device # cuda or cpu
95
 
96
  # Convert numpy array to PIL Image if needed
@@ -102,18 +219,18 @@ def detect_objects(og_image, model=od_model, prompt="bug . insect", device="cpu"
102
  # Transform the image
103
  image_transformed = load_image(image_source = og_image_obj)
104
 
105
- # Your model prediction code here...
106
  boxes, logits, phrases = grounding_dino_predict(
107
  model=model,
108
  image=image_transformed,
109
  caption=TEXT_PROMPT,
110
- box_threshold=BOX_TRESHOLD,
111
- text_threshold=TEXT_TRESHOLD,
112
  device=DEVICE)
113
 
114
  # Use og_image_obj directly for further processing
115
- height, width = og_image_obj.size
116
- boxes_norm = boxes * torch.Tensor([height, width, height, width])
117
  xyxy = box_convert(
118
  boxes=boxes_norm,
119
  in_fmt="cxcywh",
@@ -122,111 +239,139 @@ def detect_objects(og_image, model=od_model, prompt="bug . insect", device="cpu"
122
  for i in range(len(boxes_norm)):
123
  crop_img = og_image_obj.crop((xyxy[i]))
124
  img_lst.append(crop_img)
125
- return (img_lst)
126
-
127
-
128
- # load beetle classifier model
129
- repo_id="ChristopherMarais/beetle-model-mini"
130
- bc_model = from_pretrained_fastai(repo_id)
131
- bc_model.to(DEVICE)
132
- # get class names
133
- labels = np.append(np.array(bc_model.dls.vocab), "Unknown")
134
- # Replace some names used in the classifier
135
- # Check if the element was found to prevent errors
136
- # The target value you're looking for
137
- target = "Scolotodes_schwarzi"
138
- # Finding the index using np.where
139
- indices = np.where(labels == target)
140
- # Extracting the first occurrence, if found
141
- if indices[0].size > 0:
142
- idx = indices[0][0]
143
- print(f"Index of {target}: {idx}")
144
- else:
145
- print(f"{target} not found in the array.")
146
- # Replace occurrence
147
- if idx != -1:
148
- labels[idx] = "Scolytodes_glaber"
149
- print("Classification model loaded")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
 
 
151
  def predict_beetle(img):
152
- print("Detecting & classifying beetles...")
153
- start_time = time.perf_counter() # record how long it processes
154
- # Split image into smaller images of detected objects
 
155
  image_lst = detect_objects(og_image=img, model=od_model, prompt=PROMPT, device=DEVICE)
156
 
157
- # pre_process = pre_process_image(manual_thresh_buffer=0.15, image = img) # use image_dir if directory of image used
158
- # pre_process.segment(cluster_num=2,
159
- # image_edge_buffer=50)
160
- # image_lst = pre_process.col_image_lst
161
 
162
- print("Objects detected")
163
- end_time = time.perf_counter()
164
- processing_time = end_time - start_time
165
- print(f"Processing duration: {processing_time} seconds")
166
- # get predictions for all segments
167
- conf_dict_lst = []
168
  output_lst = []
169
  img_cnt = len(image_lst)
170
- for i in range(0,img_cnt):
171
- prob_ar = np.array(bc_model.predict(image_lst[i])[2].to(DEVICE).cpu())
172
- unkown_prob = unkown_prob_calc(probs=prob_ar, wedge_threshold=0.85, wedge_magnitude=5, wedge='dynamic')
173
- prob_ar = np.append(prob_ar, unkown_prob)
174
- prob_ar = np.around(prob_ar*100, decimals=1)
175
- # only show the top 5 predictions
176
- # Sorting the dictionary by value in descending order and taking the top items
177
- top_num = 3
178
- conf_dict = {labels[i]: float(prob_ar[i]) for i in range(len(prob_ar))}
179
- print(conf_dict)
180
- conf_dict = dict(sorted(conf_dict.items(), key=lambda item: item[1], reverse=True)[:top_num])
181
- conf_dict_lst.append(str(conf_dict)[1:-1]) # remove dictionary brackets
182
- result = list(zip(image_lst, conf_dict_lst))
183
- print(f"Beetle classified - {i}")
184
- # record how long classification takes
185
- end_time = time.perf_counter()
186
- processing_time = end_time - start_time
187
- print(f"Processing duration: {processing_time} seconds")
188
- return(result)
189
-
190
-
191
- # gradio app
192
- # css = """
193
- # button {
194
- # width: auto; /* Set your desired width */
195
- # }
196
- # """
197
 
198
  sample_images_dir = "example_images"
199
 
200
  # Sample images with labels
201
  example_images = [
202
- [os.path.join(sample_images_dir, "example1.jpg")],
203
- [os.path.join(sample_images_dir, "example2.jpg")],
204
- [os.path.join(sample_images_dir, "example3.jpg")],
205
- [os.path.join(sample_images_dir, "mixed.jpg")]
206
-
207
  ]
208
  # Corresponding labels for the example images
209
- example_labels = ["Example Beetle 1", "Example Beetle 2", "Example Beetle 3", "Example Beetles Mixed"]
210
 
211
- with gr.Blocks() as demo: # css=css in the brackets
212
  gr.Markdown("<h1><center>Bark Beetle Classification</center></h1>")
213
- # gr.Markdown("<h3><center>Note this instance of the classifier is for demonstration only and runs on CPU, not on GPU. If you are interested in testing the model, contact us, and we will switch it to its full capacity in an instant.<h3><center>")
214
  with gr.Column(variant="panel"):
215
  with gr.Row(variant="compact"):
216
  inputs = gr.Image(label="Input Image")
217
- # Use the `full_width` parameter directly
218
- btn = gr.Button("Classify")
219
 
220
  # Set the gallery layout and height directly in the constructor
221
- gallery = gr.Gallery(label="Show images", show_label=True, elem_id="gallery", columns=8, height="auto")
222
-
223
- # Add examples with labels
224
- gr.Examples(
225
- examples=example_images,
226
- inputs=inputs,
227
- examples_per_page=4,
228
- example_labels=example_labels
229
- )
230
 
231
- btn.click(predict_beetle, inputs, gallery)
232
- demo.launch(debug=True, show_error=True)

2
  import os
3
  import numpy as np
4
  import gradio as gr
 
5
  from PIL import Image
6
+ from groundingdino.util.inference import load_model as load_groundingdino_model
7
  from groundingdino.util.inference import predict as grounding_dino_predict
8
  import groundingdino.datasets.transforms as T
9
  import torch
 
12
  from torchvision.transforms import GaussianBlur
13
  import time
14
 
15
+ # ----------------------------
16
+ # DINOv2 Classifier Imports
17
+ # ----------------------------
18
+ import torch.nn as nn
19
+ from torchvision import transforms
20
+ import pandas as pd
21
+ from typing import List, Tuple
22
+ import copy
23
+ import matplotlib.pyplot as plt
24
 
25
+ # ----------------------------
26
+ # DINOv2 Classifier Definitions
27
+ # ----------------------------
28
 
29
+ # 1. PadToSquare Class
30
+ class PadToSquare:
31
+ """
32
+ Pads an image to make it square by adding padding to the shorter side.
33
+ """
34
+ def __init__(self, fill=0):
35
+ self.fill = fill
36
+
37
+ def __call__(self, img):
38
+ w, h = img.size
39
+ max_wh = max(w, h)
40
+ hp = (max_wh - w) // 2
41
+ vp = (max_wh - h) // 2
42
+ padding = (hp, vp, max_wh - w - hp, max_wh - h - vp)
43
+ return transforms.functional.pad(img, padding, fill=self.fill, padding_mode='constant')
44
+
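# Illustrative sketch of PadToSquare (the 300x180 size below is an arbitrary example,
# not from app.py): a landscape image gains equal top/bottom padding so that the
# later Resize((224, 224)) does not distort the aspect ratio.
_pad_example = PadToSquare(fill=0)(Image.new("RGB", (300, 180)))
assert _pad_example.size == (300, 300)  # padded to max(width, height) on both sides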
45
+ # 2. DinoVisionTransformerClassifier Class (Modified to include entropy-based approach)
46
+ class DinoVisionTransformerClassifier(nn.Module):
47
+ """
48
+ DINOv2 Vision Transformer-based classifier with entropy-based "Unknown" class handling.
49
+ """
50
+ def __init__(self, num_classes, hidden_size=256, dropout_p=0.5, negative_slope=0.01):
51
+ super(DinoVisionTransformerClassifier, self).__init__()
52
+ # Load DINOv2 model from torch.hub
53
+ self.transformer = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14', pretrained=True)
54
+ self.transformer.norm = nn.Identity() # Remove existing normalization if necessary
55
+
56
+ # Batch Normalization after transformer
57
+ self.batch_norm1 = nn.BatchNorm1d(384) # 384 is the embedding size
58
+
59
+ # Classification head
60
+ self.classifier = nn.Sequential(
61
+ nn.Linear(384, hidden_size),
62
+ nn.BatchNorm1d(hidden_size),
63
+ nn.LeakyReLU(negative_slope=negative_slope, inplace=True),
64
+ nn.Dropout(p=dropout_p),
65
+ nn.Linear(hidden_size, num_classes)
66
+ )
67
+
68
+ # Initialize weights
69
+ self._initialize_weights()
70
+
71
+ def forward(self, x):
72
+ features = self.transformer(x) # Forward pass through the transformer
73
+ features = self.batch_norm1(features) # Apply Batch Normalization
74
+ logits = self.classifier(features) # Forward pass through the classification head
75
+ return logits, features # Return both logits and features
76
+
77
+ def _initialize_weights(self):
78
+ # Initialize weights of the classifier layers
79
+ for m in self.classifier.modules():
80
+ if isinstance(m, nn.Linear):
81
+ nn.init.kaiming_normal_(m.weight, a=0.01, mode='fan_in', nonlinearity='leaky_relu')
82
+ if m.bias is not None:
83
+ nn.init.zeros_(m.bias)
84
+ elif isinstance(m, nn.BatchNorm1d):
85
+ nn.init.ones_(m.weight)
86
+ nn.init.zeros_(m.bias)
87
+
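# Shape sketch for the classifier head (illustrative only; instantiating it downloads
# dinov2_vits14 via torch.hub, and num_classes=3 is an arbitrary example value).
_sketch = DinoVisionTransformerClassifier(num_classes=3).eval()  # eval(): BatchNorm uses running stats
with torch.no_grad():
    _logits, _features = _sketch(torch.randn(2, 3, 224, 224))
assert _logits.shape == (2, 3) and _features.shape == (2, 384)  # per-class logits + 384-dim DINOv2 embeddings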
88
+ # 3. Model Loading Function (Updated for Entropy-Based Classifier)
89
+ def load_model(model_path, device):
90
+ """
91
+ Loads the trained model and class information from the saved checkpoint.
92
+
93
+ Args:
94
+ model_path (str): Path to the saved .pth model file.
95
+ device (torch.device): Device to load the model onto.
96
+
97
+ Returns:
98
+ model (nn.Module): The loaded PyTorch model.
99
+ class_names (List[str]): List of class names.
100
+ """
101
+ if not os.path.exists(model_path):
102
+ raise FileNotFoundError(f"Model file '{model_path}' does not exist.")
103
+
104
+ checkpoint = torch.load(model_path, map_location=device)
105
+ class_names = checkpoint['class_names']
106
+ num_classes = len(class_names)
107
+
108
+ # Initialize the model architecture
109
+ model = DinoVisionTransformerClassifier(num_classes=num_classes)
110
+ model.load_state_dict(checkpoint['model_state_dict'])
111
+ model.to(device)
112
+ model.eval() # Set to evaluation mode
113
+
114
+ return model, class_names
115
+
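# The loader above assumes the checkpoint is a dict holding the state dict and the
# class list, e.g. produced during training (outside this file) by something like:
#   torch.save({"model_state_dict": model.state_dict(),
#               "class_names": class_names},
#              "dinov2_classifier_with_vos_unsure.pth")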
116
+ # 4. Image Preprocessing Function (Updated to accept PIL Image directly)
117
+ def preprocess_image_pil(pil_image: Image.Image, transform: transforms.Compose) -> torch.Tensor:
118
+ """
119
+ Applies the transformation pipeline to a PIL image.
120
+
121
+ Args:
122
+ pil_image (PIL.Image.Image): The image to preprocess.
123
+ transform (transforms.Compose): The transformation pipeline.
124
+
125
+ Returns:
126
+ torch.Tensor: The preprocessed image tensor.
127
+ """
128
+ return transform(pil_image)
129
+
130
+ # ----------------------------
131
+ # Gradio App Definitions
132
+ # ----------------------------
133
 
134
  # Automatically set device based on availability
135
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
137
 
138
  PROMPT = "bug"
139
 
140
+ # Define a custom transform for Gaussian blur (Unused in current context)
 
141
  def gaussian_blur(x, p=0.5, kernel_size_min=3, kernel_size_max=20, sigma_min=0.1, sigma_max=3):
142
  if x.ndim == 4:
143
  for i in range(x.shape[0]):
 
147
  x[i] = GaussianBlur(kernel_size=kernel_size, sigma=sigma)(x[i])
148
  return x
149
 
150
+ # Custom Label Function (Unused in current context)
151
  def custom_label_func(fpath):
152
  # this directs the labels to be 2 levels up from the image folder
153
  label = fpath.parents[2].name
154
  return label
155
 
156
+ # Image loading function for GroundingDINO
157
  def load_image(image_source):
158
  transform = T.Compose(
159
  [
 
167
  image_transformed, _ = transform(image_source, None)
168
  return image_transformed
169
 
170
+ # Load GroundingDINO object detection model
171
+ od_model = load_groundingdino_model(
172
  model_checkpoint_path="groundingdino_swint_ogc.pth",
173
  model_config_path="GroundingDINO_SwinT_OGC.cfg.py",
174
  device=DEVICE)
175
  print("Object detection model loaded")
176
 
177
+ # Load DINOv2 classifier model (Updated to use the entropy-based classifier)
178
+ # Update MODEL_PATH to the path where your DINOv2 model checkpoint is stored
179
+ MODEL_PATH = 'dinov2_classifier_with_vos_unsure.pth' # Updated model path
180
+ dinov2_model, class_names = load_model(MODEL_PATH, torch.device(DEVICE))
181
+ print(f"DINOv2 Classification model loaded with {len(class_names)} classes.")
182
+
183
+ # Optionally, append "Unknown" to class names if needed
184
+ # Removed the line that appends "Unknown" as the model handles it via thresholding
185
+
186
+ # Replace specific class names if necessary
187
+ # Example: Replace "Scolotodes_schwarzi" with "Scolytodes_glaber"
188
+ target = "Scolotodes_schwarzi"
189
+ if target in class_names:
190
+ idx = class_names.index(target)
191
+ class_names[idx] = "Scolytodes_glaber"
192
+ print(f"Replaced '{target}' with 'Scolytodes_glaber' in class names.")
193
+ else:
194
+ print(f"'{target}' not found in class names. No replacement made.")
195
+
196
+ # Define the transformation pipeline for DINOv2 model
197
+ dinov2_transform = transforms.Compose([
198
+ transforms.Resize(224), # Resize smaller edge to 224
199
+ PadToSquare(), # Pad to make the image square
200
+ transforms.Resize((224, 224)), # Resize to 224x224
201
+ transforms.ToTensor(),
202
+ transforms.Normalize([0.485, 0.456, 0.406], # Normalize with ImageNet mean
203
+ [0.229, 0.224, 0.225]) # Normalize with ImageNet std
204
+ ])
205
+
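# Preprocessing sketch (illustrative; the 640x480 input size is arbitrary): any RGB
# image ends up as a normalized 3x224x224 tensor, and unsqueeze(0) later adds the
# batch dimension the classifier expects.
_probe = preprocess_image_pil(Image.new("RGB", (640, 480)), dinov2_transform)
assert tuple(_probe.shape) == (3, 224, 224)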
206
+ # Object Detection Function
207
  def detect_objects(og_image, model=od_model, prompt="bug . insect", device="cpu"):
208
  TEXT_PROMPT = prompt
209
+ BOX_THRESHOLD = 0.35 # Adjusted back to original value
210
+ TEXT_THRESHOLD = 0.25 # Adjusted back to original value
211
  DEVICE = device # cuda or cpu
212
 
213
  # Convert numpy array to PIL Image if needed
 
219
  # Transform the image
220
  image_transformed = load_image(image_source = og_image_obj)
221
 
222
+ # Model prediction
223
  boxes, logits, phrases = grounding_dino_predict(
224
  model=model,
225
  image=image_transformed,
226
  caption=TEXT_PROMPT,
227
+ box_threshold=BOX_THRESHOLD,
228
+ text_threshold=TEXT_THRESHOLD,
229
  device=DEVICE)
230
 
231
  # Use og_image_obj directly for further processing
232
+ width, height = og_image_obj.size # Corrected to (width, height)
233
+ boxes_norm = boxes * torch.Tensor([width, height, width, height])
234
  xyxy = box_convert(
235
  boxes=boxes_norm,
236
  in_fmt="cxcywh",
 
239
  for i in range(len(boxes_norm)):
240
  crop_img = og_image_obj.crop((xyxy[i]))
241
  img_lst.append(crop_img)
242
+ print(f"Detected {len(img_lst)} objects.")
243
+ return img_lst
244
+
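# Box handling above, with illustrative numbers: GroundingDINO returns normalized
# (cx, cy, w, h) boxes, so for a 1000x800 image a box of (0.5, 0.5, 0.2, 0.25)
# scales to (500, 400, 200, 200), and box_convert(..., "cxcywh", "xyxy") gives the
# pixel rectangle (400, 300, 600, 500) that is passed to Image.crop().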
245
+ # Inference/Class Prediction Function using the Entropy-Based DINOv2 Classifier
246
+ def classify_beetle(img: Image.Image, threshold=75.0):
247
+ """
248
+ Classifies the input image using the DINOv2 classifier with entropy-based "Unknown" class.
249
+
250
+ Args:
251
+ img (PIL.Image.Image): The image to classify.
252
+ threshold (float): Confidence threshold to assign "Unknown".
253
+
254
+ Returns:
255
+ dict: The top 3 class labels with their corresponding confidence scores and "Unknown" if applicable.
256
+ """
257
+ # Preprocess the image
258
+ input_tensor = preprocess_image_pil(img, dinov2_transform).unsqueeze(0).to(torch.device(DEVICE))
259
+ print(f"Input tensor shape: {input_tensor.shape}")
260
+
261
+ with torch.no_grad():
262
+ outputs, _ = dinov2_model(input_tensor)
263
+ print(f"Model outputs: {outputs}")
264
+ probabilities = torch.softmax(outputs, dim=1).cpu().numpy()[0] # p(x) in [0,1]
265
+ print(f"Probabilities (0-1 scale): {probabilities}")
266
+
267
+ # Calculate entropy
268
+ # Adding a small epsilon to avoid log(0)
269
+ epsilon = 1e-12
270
+ entropy = -np.sum(probabilities * np.log(probabilities + epsilon))
271
+ # Maximum entropy for uniform distribution
272
+ max_entropy = np.log(len(probabilities))  # entropy of a uniform distribution over N classes
273
+ normalized_entropy = entropy / max_entropy # Normalize between 0 and 1
274
+ unknown_prob = normalized_entropy
275
+ print(f"Entropy: {entropy}, Normalized Entropy: {normalized_entropy}, Unknown Probability: {unknown_prob}")
276
+
277
+ # Convert probabilities to percentage for display
278
+ probabilities_percent = np.around(probabilities * 100, decimals=1)
279
+ print(f"Probabilities (Percentage): {probabilities_percent}")
280
+
281
+ # Get top 3 classes
282
+ top_indices = np.argsort(probabilities_percent)[-3:][::-1] # Indices of top 3 classes
283
+ top_probs = probabilities_percent[top_indices]
284
+ top_classes = [class_names[i] for i in top_indices]
285
+
286
+ # Initialize conf_dict with top 3 classes
287
+ conf_dict = {top_classes[i]: float(top_probs[i]) for i in range(len(top_classes))}
288
+
289
+ # Assign "Unknown" based on entropy and threshold
290
+ if top_probs[0] < threshold:
291
+ conf_dict["Unknown"] = float(np.around(unknown_prob * 100, decimals=1))  # as a percentage, to match the class scores
292
+
293
+ print(f"Conf_dict: {conf_dict}")
294
+
295
+ return conf_dict
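# Entropy intuition for the "Unknown" score, with illustrative numbers: for three
# classes, a confident p = [0.98, 0.01, 0.01] has entropy ~0.11 nats versus the
# uniform maximum ln(3) ~ 1.10, so the normalized entropy (unknown_prob) is ~0.10;
# a flat p = [1/3, 1/3, 1/3] pushes it to 1.0.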
296
 
297
+ # Main Prediction Function for Gradio
298
  def predict_beetle(img):
299
+ print("Detecting objects in the image...")
300
+ start_time = time.perf_counter() # Start timing
301
+
302
+ # Detect objects in the image
303
  image_lst = detect_objects(og_image=img, model=od_model, prompt=PROMPT, device=DEVICE)
304
 
305
+ print(f"Detected {len(image_lst)} objects.")
306

307
+ # Initialize lists to hold results
308
  output_lst = []
309
  img_cnt = len(image_lst)
310
+
311
+ for i in range(img_cnt):
312
+ print(f"Classifying object {i+1}/{img_cnt}...")
313
+ conf_dict = classify_beetle(image_lst[i])
314
+ output_lst.append([image_lst[i], conf_dict])
315
+ print(f"Object {i+1} classified.")
316
+
317
+ end_time = time.perf_counter()
318
+ processing_time = end_time - start_time
319
+ print(f"Total processing duration: {processing_time:.2f} seconds")
320
+
321
+ return output_lst
322
+
323
+ # ----------------------------
324
+ # Gradio Interface Setup
325
+ # ----------------------------
326
 
327
  sample_images_dir = "example_images"
328
 
329
  # Sample images with labels
330
  example_images = [
331
+ os.path.join(sample_images_dir, "example1.jpg"),
332
+ os.path.join(sample_images_dir, "example2.jpg"),
333
+ os.path.join(sample_images_dir, "example3.jpg"),
334
+ os.path.join(sample_images_dir, "mixed.jpg")
 
335
  ]
336
  # Corresponding labels for the example images
337
+ example_labels = ["Example Beetles 1", "Example Beetles 2", "Example Beetles 3", "Example Beetles 4"]
338
 
339
+ with gr.Blocks() as demo:
340
  gr.Markdown("<h1><center>Bark Beetle Classification</center></h1>")
341
+
342
  with gr.Column(variant="panel"):
343
  with gr.Row(variant="compact"):
344
  inputs = gr.Image(label="Input Image")
345
+ # Add examples with labels
346
+ gr.Examples(
347
+ label="Select an example below if you have no images to upload.",
348
+ examples=example_images,
349
+ inputs=inputs,
350
+ examples_per_page=4,
351
+ example_labels=example_labels
352
+ )
353
+
354
+ btn = gr.Button("Classify", variant="primary")
355
 
356
  # Set the gallery layout and height directly in the constructor
357
+ gallery = gr.Gallery(label="Classified Objects", show_label=True, elem_id="gallery", columns=4, height="auto")
358
 
359
+ # Define the output format for the gallery
360
+ def format_gallery(results):
361
+ formatted = []
362
+ for img, conf in results:
363
+ # Create a label string from the confidence dictionary
364
+ label_str = ", ".join([f"{k}: {v:.1f}%" for k, v in conf.items()])
365
+ # Append the image and label as a tuple
366
+ formatted.append((img, label_str))
367
+ return formatted
368
+
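# format_gallery yields the (image, caption) pairs gr.Gallery accepts, e.g.
# [(<PIL.Image>, "Scolytodes_glaber: 91.2%, Hypothenemus_hampei: 5.1%")], where the
# class names shown here are only examples, not necessarily in the loaded checkpoint.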
369
+ # Modify the click event to format the gallery
370
+ btn.click(
371
+ lambda img: format_gallery(predict_beetle(img)),
372
+ inputs,
373
+ gallery
374
+ )
375
+
376
+ # Launch the Gradio app
377
+ demo.launch(share=True, inline=True, debug=True, show_error=True)