| | |
| |
|
| | import torch |
| |
|
| | from ultralytics.yolo.engine.predictor import BasePredictor |
| | from ultralytics.yolo.engine.results import Results |
| | from ultralytics.yolo.utils import ops |
| | from ultralytics.yolo.utils.ops import xyxy2xywh |
| |
|
| |
|
class NASPredictor(BasePredictor):
    """Predictor for YOLO-NAS style detection models.

    Only ``postprocess`` is overridden here; all other predictor behavior is
    inherited from ``BasePredictor``.
    """

    def postprocess(self, preds_in, img, orig_imgs):
        """Postprocesses predictions and returns a list of Results objects.

        Args:
            preds_in: Raw model output. ``preds_in[0][0]`` holds boxes in xyxy
                format and ``preds_in[0][1]`` holds per-box class scores.
                Assumed shapes are ``(batch, num_boxes, 4)`` and
                ``(batch, num_boxes, num_classes)``; TODO confirm. The
                ``permute(0, 2, 1)`` below requires 3-D tensors.
            img: Preprocessed input batch. ``img.shape[2:]`` is used as the
                source size when rescaling boxes.
            orig_imgs: Original image(s). This is either a list indexed per
                prediction, a single image, or a ``torch.Tensor``.

        Returns:
            list[Results]: One ``Results`` object per entry returned by NMS.
        """
        # Convert boxes to xywh (presumably the format ops.non_max_suppression
        # expects), append the class scores per box, then move that combined
        # feature axis to dim 1.
        boxes = xyxy2xywh(preds_in[0][0])
        preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)

        preds = ops.non_max_suppression(preds,
                                        self.args.conf,
                                        self.args.iou,
                                        agnostic=self.args.agnostic_nms,
                                        max_det=self.args.max_det,
                                        classes=self.args.classes)

        # These do not depend on the image index, so evaluate them once
        # instead of on every loop iteration.
        imgs_is_list = isinstance(orig_imgs, list)
        # Boxes are only mapped back to the original size when the originals
        # are not a tensor. NOTE(review): presumably tensor inputs are already
        # at the network input size; confirm.
        needs_rescale = not isinstance(orig_imgs, torch.Tensor)
        paths = self.batch[0]
        paths_is_list = isinstance(paths, list)

        results = []
        for i, pred in enumerate(preds):
            # NOTE(review): when orig_imgs is a batched torch.Tensor, every
            # Results receives the whole tensor rather than orig_imgs[i].
            orig_img = orig_imgs[i] if imgs_is_list else orig_imgs
            if needs_rescale:
                # Map xyxy boxes (first 4 columns) from the network input size
                # back to the original image size, in place.
                pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
            img_path = paths[i] if paths_is_list else paths
            results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred))
        return results
| |
|