ChristopherMarais commited on
Commit
315a1bb
·
1 Parent(s): 6a04cfd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -4
app.py CHANGED
@@ -10,11 +10,14 @@ import torch
10
  from torchvision.ops import box_convert
11
  from torchvision.transforms.functional import to_tensor
12
  from torchvision.transforms import GaussianBlur
 
13
 
14
  from Ambrosia import pre_process_image
15
- import time
16
 
17
 
 
 
 
18
  # Define a custom transform for Gaussian blur
19
  def gaussian_blur(x, p=0.5, kernel_size_min=3, kernel_size_max=20, sigma_min=0.1, sigma_max=3):
20
  if x.ndim == 4:
@@ -75,7 +78,7 @@ def load_image(image_source):
75
  od_model = load_model(
76
  model_checkpoint_path="groundingdino_swint_ogc.pth",
77
  model_config_path="GroundingDINO_SwinT_OGC.cfg.py",
78
- device="cpu")
79
  print("Object detection model loaded")
80
 
81
  def detect_objects(og_image, model=od_model, prompt="bug . insect", device="cpu"):
@@ -119,6 +122,7 @@ def detect_objects(og_image, model=od_model, prompt="bug . insect", device="cpu"
119
  # load beetle classifier model
120
  repo_id="ChristopherMarais/beetle-model-mini"
121
  bc_model = from_pretrained_fastai(repo_id)
 
122
  # get class names
123
  labels = np.append(np.array(bc_model.dls.vocab), "Unknown")
124
  print("Classification model loaded")
@@ -127,7 +131,7 @@ def predict_beetle(img):
127
  print("Detecting & classifying beetles...")
128
  start_time = time.perf_counter() # record how long it processes
129
  # Split image into smaller images of detected objects
130
- image_lst = detect_objects(og_image=img, model=od_model, prompt="bug . insect", device="cpu")
131
 
132
  # pre_process = pre_process_image(manual_thresh_buffer=0.15, image = img) # use image_dir if directory of image used
133
  # pre_process.segment(cluster_num=2,
@@ -143,7 +147,7 @@ def predict_beetle(img):
143
  output_lst = []
144
  img_cnt = len(image_lst)
145
  for i in range(0,img_cnt):
146
- prob_ar = np.array(bc_model.predict(image_lst[i])[2])
147
  unkown_prob = unkown_prob_calc(probs=prob_ar, wedge_threshold=0.85, wedge_magnitude=5, wedge='dynamic')
148
  prob_ar = np.append(prob_ar, unkown_prob)
149
  prob_ar = np.around(prob_ar*100, decimals=1)
 
10
  from torchvision.ops import box_convert
11
  from torchvision.transforms.functional import to_tensor
12
  from torchvision.transforms import GaussianBlur
13
+ import time
14
 
15
  from Ambrosia import pre_process_image
 
16
 
17
 
18
+
19
+ DEVICE = "cuda" # cpu or cuda
20
+
21
  # Define a custom transform for Gaussian blur
22
  def gaussian_blur(x, p=0.5, kernel_size_min=3, kernel_size_max=20, sigma_min=0.1, sigma_max=3):
23
  if x.ndim == 4:
 
78
  od_model = load_model(
79
  model_checkpoint_path="groundingdino_swint_ogc.pth",
80
  model_config_path="GroundingDINO_SwinT_OGC.cfg.py",
81
+ device=DEVICE)
82
  print("Object detection model loaded")
83
 
84
  def detect_objects(og_image, model=od_model, prompt="bug . insect", device="cpu"):
 
122
  # load beetle classifier model
123
  repo_id="ChristopherMarais/beetle-model-mini"
124
  bc_model = from_pretrained_fastai(repo_id)
125
+ bc_model.to(DEVICE)
126
  # get class names
127
  labels = np.append(np.array(bc_model.dls.vocab), "Unknown")
128
  print("Classification model loaded")
 
131
  print("Detecting & classifying beetles...")
132
  start_time = time.perf_counter() # record how long it processes
133
  # Split image into smaller images of detected objects
134
+ image_lst = detect_objects(og_image=img, model=od_model, prompt="bug . insect", device=DEVICE)
135
 
136
  # pre_process = pre_process_image(manual_thresh_buffer=0.15, image = img) # use image_dir if directory of image used
137
  # pre_process.segment(cluster_num=2,
 
147
  output_lst = []
148
  img_cnt = len(image_lst)
149
  for i in range(0,img_cnt):
150
+ prob_ar = np.array(bc_model.predict(image_lst[i])[2].to(DEVICE).cpu())
151
  unkown_prob = unkown_prob_calc(probs=prob_ar, wedge_threshold=0.85, wedge_magnitude=5, wedge='dynamic')
152
  prob_ar = np.append(prob_ar, unkown_prob)
153
  prob_ar = np.around(prob_ar*100, decimals=1)