import os
from typing import Any, Dict, List, Tuple, Union

import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from ultralytics import YOLO
| |
|
class LinearClassifier(torch.nn.Module):
    """Single fully-connected layer that maps feature vectors to class logits.

    Args:
        input_dim: Length of each input feature vector (default 384).
        output_dim: Number of output logits / classes (default 7).

    The weight matrix is initialised from N(0, 0.01**2) and the bias to zero.
    """

    def __init__(self, input_dim=384, output_dim=7):
        super().__init__()
        layer = torch.nn.Linear(input_dim, output_dim)
        # Small-std normal init and zero bias.
        torch.nn.init.normal_(layer.weight, mean=0.0, std=0.01)
        torch.nn.init.zeros_(layer.bias)
        self.linear = layer

    def forward(self, x):
        """Apply the linear layer; output's last dimension is ``output_dim``."""
        return self.linear(x)
| |
|
class EndpointHandler:
    """Sea-turtle detector + species classifier used as an inference endpoint.

    ``__call__`` runs YOLOv8 to locate the turtle, expands the box to a square
    crop, embeds the crop with DINOv2 ViT-S/14 and scores it with a linear
    head; scores are reported under Vietnamese species names.
    """

    # Detections whose confidence is below this are not trusted: the whole
    # frame is classified instead and no box is drawn.
    CONF_THRESHOLD = 0.75

    def __init__(self, path=""):
        """Load the models, preprocessing pipeline and label metadata.

        Args:
            path: Directory holding ``yolov8_2023-07-19_yolov8m.pt``,
                ``linear_2023-07-18_v0.2.pt`` and ``labels.txt``.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else "cpu")

        # DINOv2 backbone fetched via torch.hub.
        self.dinov2_vits14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
        self.dinov2_vits14.to(self.device)
        # nn.Module starts in training mode; switch to inference mode explicitly.
        self.dinov2_vits14.eval()
        print('Successfully load dinov2_vits14 model')

        self.yolov8_model = YOLO(os.path.join(path, 'yolov8_2023-07-19_yolov8m.pt'))

        # The head must live on the same device as the backbone's output, and
        # map_location lets weights saved on a GPU load on a CPU-only host.
        self.linear_model = LinearClassifier()
        state_dict = torch.load(
            os.path.join(path, 'linear_2023-07-18_v0.2.pt'), map_location=self.device
        )
        self.linear_model.load_state_dict(state_dict)
        self.linear_model.to(self.device)
        self.linear_model.eval()

        # NOTE(review): Resize(244) + CenterCrop(224) looks like a typo for the
        # usual Resize(256); kept unchanged because the linear head was
        # presumably trained on embeddings produced by this exact pipeline.
        self.transform_image = T.Compose([
            T.ToTensor(),
            T.Resize(244),
            T.CenterCrop(224),
            T.Normalize([0.5], [0.5]),  # maps [0, 1] -> [-1, 1] on every channel
        ])

        # Comma-separated English species keys; position i names output i of
        # the linear head. Whitespace (e.g. a trailing newline) is stripped and
        # empty entries dropped so the ``name_en2vi`` lookup cannot fail on them.
        with open(os.path.join(path, 'labels.txt'), 'r', encoding='utf-8') as f:
            self.labels = [label.strip() for label in f.read().split(',') if label.strip()]

        # English label key -> Vietnamese display name used in the response.
        self.name_en2vi = {
            "loggerhead": "Quản đồng",
            "green": "Vích",
            "leatherback": "Rùa da",
            "hawksbill": "Đồi mồi",
            "kemp_ridley": "Vích Kemp",
            "olive_ridley": "Đồi mồi dứa",
            "flatback": "Rùa lưng phẳng"
        }

    @staticmethod
    def _square_box(x1, y1, x2, y2, W, H):
        """Pad the shorter side of a box towards a square, clipped to the image.

        Half of the side-length difference (floored) is added to each end of
        the shorter side; clipping to ``[0, W]`` / ``[0, H]`` may leave the
        result non-square near the image border.

        Returns:
            The adjusted ``(x1, y1, x2, y2)``.
        """
        h, w = y2 - y1, x2 - x1
        offset = abs(h - w) // 2
        if h > w:
            x1 = max(x1 - offset, 0)
            x2 = min(x2 + offset, W)
        else:
            y1 = max(y1 - offset, 0)
            y2 = min(y2 + offset, H)
        return x1, y1, x2, y2

    def __call__(self, data: Dict[str, Any]) -> Tuple[list, Union[str, Dict[str, float]]]:
        """Detect and classify the turtle in ``data['inputs']``.

        Args:
            data: Request payload; ``data['inputs']`` is passed straight to the
                YOLO model (e.g. a path, ``PIL.Image`` or ``np.ndarray``).

        Returns:
            A 2-tuple ``(image, info)``. ``image`` is the RGB image as nested
            lists, with the detected box drawn when the detection is confident.
            ``info`` maps each Vietnamese species name to its softmax
            probability rounded to 2 decimals, or is a user-facing
            (Vietnamese) message string when no turtle is detected.
        """
        detection = self.yolov8_model(data['inputs'])[0]

        # YOLO keeps the source image in BGR channel order; flip to RGB.
        img = detection.orig_img[:, :, ::-1]
        H, W, _ = img.shape
        annotated = img.copy()

        # Explicit empty-detection check instead of a bare ``except`` that
        # would also hide unrelated errors (e.g. CUDA tensor conversion).
        boxes = detection.boxes
        if boxes is None or len(boxes.xyxy) == 0:
            return img.tolist(), "🤔 Hmm... Mình không thấy bạn rùa nào trong bức ảnh này. Bạn hãy tải lên một bức hình khác nhé."

        # Only the first box is used (presumably the most confident one).
        # ``.cpu()`` is required because YOLO may have run on the GPU.
        x1, y1, x2, y2 = boxes.xyxy[0].cpu().numpy().astype(int)
        if boxes.conf[0].item() < self.CONF_THRESHOLD:
            x1, y1, x2, y2 = 0, 0, W, H
        else:
            # plot() returns a BGR image; flip to RGB to match ``img``.
            annotated = detection.plot(labels=False, conf=False)[:, :, ::-1]

        x1, y1, x2, y2 = self._square_box(x1, y1, x2, y2, W, H)
        cropped = img[y1:y2, x1:x2]

        # ``[:3]`` keeps at most the first three (RGB) channels before batching.
        batch = self.transform_image(Image.fromarray(cropped))[:3].unsqueeze(0).to(self.device)
        with torch.no_grad():
            logits = self.linear_model(self.dinov2_vits14(batch))
        probabilities = torch.softmax(logits, dim=1)[0].cpu().numpy().round(2).tolist()

        scores = {
            self.name_en2vi[label]: probabilities[i]
            for i, label in enumerate(self.labels)
        }
        return annotated.tolist(), scores