panda1835 committed on
Commit
adbd0a7
·
1 Parent(s): e7dec10

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +46 -3
handler.py CHANGED
@@ -1,13 +1,46 @@
1
  from typing import Dict, List, Any
2
  from ultralytics import YOLO
3
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  class EndpointHandler():
6
    def __init__(self, path=""):
        """Load the YOLOv8 detector shipped alongside the handler.

        Args:
            path: Directory containing 'yolov8_2023-07-19_yolov8m.pt'.
        """
        # Preload all the elements you are going to need at inference.
        self.model = YOLO(os.path.join(path, 'yolov8_2023-07-19_yolov8m.pt'))
 
 
 
 
 
10
 
 
 
 
11
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
12
  """
13
  data args:
@@ -41,5 +74,15 @@ class EndpointHandler():
41
  y1 = max(y1 - offset, 0)
42
  y2 = min(y2 + offset, H)
43
  new_image = img[y1:y2, x1:x2]
44
- # Return the annotated original image with the square cropped
45
- return annotated.tolist(), new_image.tolist()
 
 
 
 
 
 
 
 
 
 
 
1
  from typing import Dict, List, Any
2
  from ultralytics import YOLO
3
  import os
4
+ import torch
5
+ import torch.nn as nn
6
+ import torchvision.transforms as T
7
+ from PIL import Image
8
+
9
class LinearClassifier(nn.Module):
    """Single linear head mapping a DINOv2 embedding to class logits.

    Args:
        input_dim: Size of the incoming embedding (DINOv2 ViT-S/14 -> 384).
        output_dim: Number of target classes (7 species labels).
    """

    def __init__(self, input_dim=384, output_dim=7):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        # Same initialization as the original: small Gaussian weights, zero bias.
        nn.init.normal_(self.linear.weight, mean=0.0, std=0.01)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x):
        # Pure affine projection; no activation — softmax is applied by the caller.
        return self.linear(x)
19
 
20
  class EndpointHandler():
21
  def __init__(self, path=""):
22
  # Preload all the elements you are going to need at inference.
23
+ self.dinov2_vits14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
24
+ device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
25
+ self.dinov2_vits14.to(device)
26
+ print('Successfully load dinov2_vits14 model')
27
+
28
+ self.yolov8_model = YOLO(os.path.join(path, 'yolov8_2023-07-19_yolov8m.pt'))
29
+
30
+ self.linear_model = LinearClassifier()
31
+ self.linear_model.load_state_dict(torch.load(os.path.join(path, 'linear_2023-07-18_v0.2.pt')))
32
+ self.linear_model.eval()
33
 
34
+ self.transform_image = T.Compose([
35
+ T.ToTensor(),
36
+ T.Resize(244),
37
+ T.CenterCrop(224),
38
+ T.Normalize([0.5], [0.5])
39
+ ])
40
 
41
+ with open(os.path.join(path, 'labels.txt'), 'r') as f:
42
+ self.labels = f.read().split(',') # loggerhead,green,leatherback...
43
+
44
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
45
  """
46
  data args:
 
74
  y1 = max(y1 - offset, 0)
75
  y2 = min(y2 + offset, H)
76
  new_image = img[y1:y2, x1:x2]
77
+
78
+ new_image = self.transform_image(Image.fromarray(cropped))[:3].unsqueeze(0)
79
+ embedding = self.dinov2_vits14(new_image.to(device))
80
+ prediction = self.linear_model(embedding)
81
+ percentage = nn.Softmax(dim=1)(prediction).detach().numpy().round(2)[0].tolist()
82
+ result = {}
83
+
84
+ for i in range(len(self.labels)):
85
+ result[name_en2vi[self.labels[i]]] = percentage[i]
86
+
87
+ # Return the annotated original image with the square cropped and result dict
88
+ return annotated.tolist(), result