Spaces:
Sleeping
Sleeping
import os
import threading

import numpy as np
import torch
from fastai.vision.all import load_learner
from huggingface_hub import hf_hub_download

from constants import (CLAS_FILENAME, CLAS_FILEPATH, CLAS_THRESHOLD,
                       DET_FILENAME, DET_FILEPATH, DET_NAME, DET_THRESHOLD,
                       DEVICE, HF_CLAS_REPO_NAME, HF_DET_REPO_NAME,
                       MODELS_PATH)
from efficientdet.efficientdet import get_transforms, rescale_bboxes, set_model
# Detector model, built lazily on first use and reused by later calls so the
# checkpoint is not reloaded for every image.
_detector = None


def _load_detector():
    """Return the cached detector, building it on first use.

    If the checkpoint file at ``DET_FILEPATH`` does not exist, it is first
    downloaded from the Hugging Face Hub repo ``HF_DET_REPO_NAME`` into
    ``MODELS_PATH``.
    """
    global _detector
    if _detector is None:
        # detector, if checkpoint doesn't exist then download from hf
        if not os.path.exists(DET_FILEPATH):
            hf_hub_download(HF_DET_REPO_NAME, DET_FILENAME, local_dir=MODELS_PATH)
        detector = set_model(DET_NAME, 1, DET_FILEPATH, DEVICE)
        detector.eval()
        _detector = detector
    return _detector


def localize_trash(im):
    """Run the detector on ``im`` and keep detections above ``DET_THRESHOLD``.

    Args:
        im: the input image. Its ``.size`` is passed to ``rescale_bboxes`` as
            the target scale (presumably a PIL ``Image`` -- TODO confirm).

    Returns:
        ``(probas, bboxes_scaled)``: ``probas`` holds columns 4 onward of
        each kept detection row (column 4 is the score compared against
        ``DET_THRESHOLD``); ``bboxes_scaled`` holds columns 0-3 of the same
        rows, converted by ``rescale_bboxes`` to image scale.
    """
    detector = _load_detector()
    # mean-std normalize the input image (batch-size: 1)
    img = get_transforms(im)
    # Inference only: never build an autograd graph, even when this function
    # is called directly rather than through detect_trash.
    with torch.no_grad():
        # propagate through the model
        outputs = detector(img.to(DEVICE))
        # keep only predictions of the single batch item above set confidence
        bboxes_keep = outputs[0, outputs[0, :, 4] > DET_THRESHOLD]
        probas = bboxes_keep[:, 4:]
        # convert boxes to image scales
        bboxes_scaled = rescale_bboxes(bboxes_keep[:, :4], im.size, tuple(img.size()[2:]))
    return probas, bboxes_scaled
# Classifier learner, loaded lazily on first use and reused by later calls so
# the exported learner is not re-unpickled for every image.
_classifier = None
# NOTE(review): fastai's Learner.predict is not documented as thread-safe (it
# swaps internal dataloader/callback state), so calls on the shared learner are
# serialized. The same lock also guards the one-time load.
_classifier_lock = threading.Lock()


def _load_classifier():
    """Return the cached classifier, loading it on first use.

    Must be called with ``_classifier_lock`` held. If the file at
    ``CLAS_FILEPATH`` does not exist, it is first downloaded from the Hugging
    Face Hub repo ``HF_CLAS_REPO_NAME`` into ``MODELS_PATH``.
    """
    global _classifier
    if _classifier is None:
        # classifier, if checkpoint doesn't exist then download from hf
        if not os.path.exists(CLAS_FILEPATH):
            hf_hub_download(HF_CLAS_REPO_NAME, CLAS_FILENAME, local_dir=MODELS_PATH)
        _classifier = load_learner(CLAS_FILEPATH)
    return _classifier


def classify_trash(im, probas, bboxes_scaled):
    """Classify each detected region and keep those above ``CLAS_THRESHOLD``.

    Args:
        im: the input image; each box is cut out with ``im.crop``.
        probas: per-detection rows, iterated in lockstep with ``bboxes_scaled``.
            They are not modified; each kept row is a copy.
        bboxes_scaled: tensor of ``(xmin, ymin, xmax, ymax)`` boxes.

    Returns:
        ``(cls_prob, bboxes_final)``. For every kept box, ``cls_prob`` has a
        copy of its ``probas`` row where element 0 is the classifier's top
        probability as a truncated percentage and element 1 is the index of
        that top class. ``bboxes_final`` has the matching box tuple. A box is
        kept when that percentage is ``>= CLAS_THRESHOLD * 100``.
    """
    bboxes_final = []
    cls_prob = []
    with _classifier_lock:
        classifier = _load_classifier()
        for p, (xmin, ymin, xmax, ymax) in zip(probas, bboxes_scaled.tolist()):
            crop = im.crop((xmin, ymin, xmax, ymax))
            # third element of fastai's predict() result: per-class probabilities
            class_probs = classifier.predict(crop)[2]
            # copy the row: p is a view into the caller's probas tensor
            row = p.clone()
            row[1] = class_probs.argmax().item()
            row[0] = torch.trunc(class_probs * 100).max().item()
            # compare the value as stored in the row's dtype, like before
            if row[0] >= CLAS_THRESHOLD * 100:
                bboxes_final.append((xmin, ymin, xmax, ymax))
                cls_prob.append(row)
    return cls_prob, bboxes_final
def detect_trash(img):
    """Localize trash in ``img`` and classify each detected region.

    Args:
        img: the input image, passed unchanged to ``localize_trash`` and
            ``classify_trash``.

    Returns:
        ``(cls_prob, bboxes_final)`` as returned by ``classify_trash``.
    """
    # Inference only. A scoped no_grad replaces torch.set_grad_enabled(False),
    # which would leave gradients disabled for the calling thread afterwards.
    with torch.no_grad():
        # 1) Localize
        probas, bboxes_scaled = localize_trash(img)
        # 2) Classify
        cls_prob, bboxes_final = classify_trash(img, probas, bboxes_scaled)
    return cls_prob, bboxes_final