Spaces:
Build error
Build error
| import torch | |
| from ultralytics import YOLO | |
class PearDetectionModel:
    """YOLO-based pear detector that reports whether a burn box was found.

    Wraps an ``ultralytics.YOLO`` detection model and reduces its output to a
    0/1 flag plus the boxes that survive confidence filtering.
    """

    # Class name whose presence among the kept boxes sets the flag to 1.
    BURN_LABEL = "burn_bbox"
    # Confidence cutoff used when BURN_LABEL is not one of the configured classes.
    DEFAULT_CONF_THRESHOLD = 0.8
    # Lower cutoff used when BURN_LABEL is a configured class.
    BURN_CONF_THRESHOLD = 0.5

    def __init__(self, config) -> None:
        """Load the detection model.

        Args:
            config: Mapping with ``"model_path"`` (weights passed to ``YOLO``)
                and ``"classes"`` (class names indexable by integer class id,
                e.g. a list, or a dict keyed by id).
        """
        # NOTE(review): self.device is chosen here but never passed to
        # predict(); the model uses whatever device ultralytics selects.
        # Confirm whether it should be forwarded.
        self.device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
        self.model = YOLO(config["model_path"], task="detect")
        self.names = config["classes"]

    def detect(self, img):
        """Run the model on ``img`` and return the first result's boxes as numpy.

        Only ``results[0]`` is used, so for a batched input only the first
        image's boxes are returned.
        """
        results = self.model.predict(img)
        return results[0].boxes.cpu().numpy()

    @staticmethod
    def _class_names(names):
        """Return the class-name strings held by ``names`` (list or {id: name} dict)."""
        # Iterating a dict yields its keys (integer ids), which can never equal
        # a label string, so compare against the values instead.
        return names.values() if isinstance(names, dict) else names

    @classmethod
    def _confidence_threshold(cls, names):
        """Return the minimum box confidence to keep, given the configured class names."""
        if cls.BURN_LABEL in cls._class_names(names):
            return cls.BURN_CONF_THRESHOLD
        return cls.DEFAULT_CONF_THRESHOLD

    def inference(self, img):
        """Detect on ``img`` and report whether any kept box is a burn box.

        Boxes whose confidence is not strictly above a threshold are dropped.
        The threshold depends on the *configured* class names (``self.names``),
        not on what was detected. It is ``BURN_CONF_THRESHOLD`` when
        ``BURN_LABEL`` is a configured class, and ``DEFAULT_CONF_THRESHOLD``
        otherwise.

        Returns:
            Tuple ``(flag, xyxy, conf)``. ``flag`` is 1 if any kept box is
            labelled ``BURN_LABEL`` and 0 otherwise. ``xyxy`` and ``conf`` are
            the kept boxes' coordinates and confidences.
        """
        boxes = self.detect(img)
        # NOTE(review): a previous comment described 0.9/0.8 cutoffs keyed on
        # whether a burn box was *detected*; the code uses 0.8/0.5 keyed on the
        # configured classes. Confirm which behaviour is intended.
        boxes = boxes[boxes.conf > self._confidence_threshold(self.names)]
        # Built eagerly so an out-of-range class id still raises here.
        labels = [self.names[int(cat)] for cat in boxes.cls]
        burned = self.BURN_LABEL in labels
        return (1 if burned else 0), boxes.xyxy, boxes.conf

    def _preporcess(self, img):
        """Placeholder for image preprocessing; currently a no-op returning None.

        NOTE(review): the name is a misspelling of ``_preprocess``; it is kept
        as-is in case anything references it.
        """