Spaces:
Sleeping
Sleeping
| import argparse | |
| import os | |
| import torch | |
| from model.detector import * | |
| from model.backbone import * | |
| from model.data import Therin | |
| import datetime | |
| from model.detector.fasterRCNN import FasterRCNN | |
| from model.backbone.densenet import DenseNet | |
| from model.utils.engine import * | |
| from torchvision.models.detection.backbone_utils import resnet_fpn_backbone, _resnet_fpn_extractor | |
| from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_V2_Weights | |
| from torchvision import transforms as T | |
| from PIL import Image, ImageDraw | |
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
def label_to_text_en(l):
    """Return the English action name for detector class index ``l``.

    Indices 0-4 map to creeping, crawling, stooping, climbing and other;
    any other key raises ``KeyError``.
    """
    names = dict(enumerate(("creeping", "crawling", "stooping", "climbing", "other")))
    return names[l]
def label_to_text_ja(l):
    """Return the Japanese action name for detector class index ``l``.

    The indices mirror ``label_to_text_en``: 0 creeping, 1 crawling,
    2 stooping, 3 climbing, 4 other. Any other key raises ``KeyError``.
    """
    # These literals were previously mojibake (UTF-8 bytes mis-decoded as
    # TIS-620) and have been restored from the bytes that survived.
    # NOTE(review): label 2 could also read "こごんでいる" from the surviving
    # bytes; "かがんでいる" (stooping) is the most plausible — confirm.
    d = {
        0: "しのびこんでいる",
        1: "這っている",
        2: "かがんでいる",
        3: "よじ登っている",
        4: "その他",
    }
    return d[l]
def show_bb(img, x, y, w, h, text, textcolor, bbcolor):
    """Draw a labelled bounding box onto ``img`` in place.

    Args:
        img: PIL image to draw on; it is modified in place.
        x, y: top-left corner of the box, in pixels.
        w, h: width and height of the box, in pixels.
        text: label drawn on a filled banner at the box's left edge.
        textcolor: colour of the label text.
        bbcolor: colour of the box outline and of the label banner.
    """
    draw = ImageDraw.Draw(img)
    # ImageDraw.textsize was removed in Pillow 10. textbbox (Pillow >= 8.0)
    # gives the text extents relative to the draw origin; the banner starts
    # at that origin, so its right/bottom edges are the size to cover.
    if hasattr(draw, "textbbox"):
        _, _, text_w, text_h = draw.textbbox((0, 0), text)
    else:
        text_w, text_h = draw.textsize(text)
    # Put the label banner just above the box. If there is no room above
    # (the box is within text_h of the top edge), put it inside the box.
    label_y = y if y <= text_h else y - text_h
    # The box stays at (x, y); only the label is moved. Previously the box
    # was drawn from label_y and so was shifted up along with the label.
    draw.rectangle((x, y, x + w, y + h), outline=bbcolor)
    draw.rectangle(
        (x, label_y, x + text_w, label_y + text_h), outline=bbcolor, fill=bbcolor
    )
    draw.text((x, label_y), text, fill=textcolor)
def postprocess(true_image, o):
    """Draw the single highest-scoring detection and report it.

    Args:
        true_image: original PIL image. It is not modified; drawing happens
            on a copy.
        o: detector output. ``o[0]`` is indexed with ``"boxes"``,
            ``"labels"`` and ``"scores"``, each supporting ``.tolist()``.

    Returns:
        ``(selected_labels, selected_scores, selected_indices)``. Each list
        holds at most one entry: the Japanese label text, the score formatted
        to three decimals, and the detection index. All three lists are empty
        when there are no detections.

    Side effects:
        Calls ``show()`` on the annotated image and saves it to
        ``img/detected.png``, creating ``img/`` if needed.
    """
    # Drawing RGB colour tuples on single-band images (e.g. mode "L") raises
    # TypeError, so draw on an RGB copy for those. convert() already returns
    # a new image, so the caller's image is never touched.
    if true_image.mode in ("RGB", "RGBA"):
        copy_im = true_image.copy()
    else:
        copy_im = true_image.convert("RGB")

    data = o[0]
    boxes = data["boxes"].tolist()
    labels = data["labels"].tolist()
    scores = data["scores"].tolist()

    selected_labels = []
    selected_scores = []
    selected_indices = []
    if boxes and scores:
        # Keep only the top-1 detection. Computed once rather than per box.
        # max() returns the first maximal index on ties, matching the
        # previous scores.index(max(scores)) selection.
        best = max(range(len(scores)), key=scores.__getitem__)
        box = boxes[best]
        # NOTE(review): show_bb expects (x, y, w, h), but torchvision-style
        # Faster R-CNN returns boxes as (x1, y1, x2, y2). Confirm the project
        # detector's box format; if it is xyxy, pass box[2] - box[0] and
        # box[3] - box[1] as the width and height here.
        show_bb(
            copy_im,
            box[0], box[1], box[2], box[3],
            label_to_text_en(labels[best]),
            (255, 255, 255),
            (255, 0, 0),
        )
        selected_labels.append(label_to_text_ja(labels[best]))
        selected_scores.append("{:.3f}".format(scores[best]))
        selected_indices.append(best)

    copy_im.show()
    # Image.save does not create missing parent directories.
    os.makedirs("img", exist_ok=True)
    copy_im.save("img/detected.png")
    return selected_labels, selected_scores, selected_indices
# Holds the detector once it has been built, so the weights file is read only
# once per process instead of on every request.
_MODEL_CACHE = {}


def _load_model():
    """Build the 5-class detector, load its checkpoint, and cache it."""
    model = _MODEL_CACHE.get("model")
    if model is None:
        num_classes = 5
        # Keyword form of the torchvision >= 0.13 API (this file already
        # imports FasterRCNN_ResNet50_FPN_V2_Weights, which needs 0.13+);
        # weights=None replaces the deprecated positional pretrained=False.
        backbone = resnet_fpn_backbone(backbone_name="resnet18", weights=None)
        model = FasterRCNN(backbone, num_classes)
        # NOTE(review): the checkpoint name mentions "densenet" while the
        # backbone is resnet18 — confirm this is the intended weights file.
        state_dict = torch.load(
            "model/model/densenet-model-9-mAp--1.0.pth", map_location=device
        )
        model.load_state_dict(state_dict["model"])
        model.eval()
        _MODEL_CACHE["model"] = model
    return model


def inference(image_pil):
    """Run the detector on one PIL image and post-process the result.

    Args:
        image_pil: input PIL image. It is converted to RGB for the model; the
            original image is passed to ``postprocess`` for drawing.

    Returns:
        ``(output, res)``: ``output`` is the raw model output for the batch of
        one image, and ``res`` is the ``(labels, scores, indices)`` tuple
        returned by ``postprocess``.
    """
    model = _load_model()
    image = T.ToTensor()(image_pil.convert("RGB"))
    with torch.no_grad():
        output = model([image])
    res = postprocess(image_pil, output)
    return output, res