| | |
| |
|
| | import ast |
| | import contextlib |
| | import json |
| | import platform |
| | import zipfile |
| | from collections import OrderedDict, namedtuple |
| | from pathlib import Path |
| | from urllib.parse import urlparse |
| |
|
| | import cv2 |
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | from PIL import Image |
| |
|
| | from ultralytics.yolo.utils import LINUX, LOGGER, ROOT, yaml_load |
| | from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_version, check_yaml |
| | from ultralytics.yolo.utils.downloads import attempt_download_asset, is_url |
| | from ultralytics.yolo.utils.ops import xywh2xyxy |
| |
|
| |
|
def check_class_names(names):
    """
    Check class names and normalize them to a ``{int: str}`` dict.

    Lists are converted to dicts keyed by position. Dict keys are cast to ``int`` and values to ``str``. Names that
    start with 'n0' are treated as ImageNet class codes and mapped to human-readable names using the ``map`` section
    of ``datasets/ImageNet.yaml``. Inputs that are neither a list nor a dict are returned unchanged.

    Args:
        names (list | dict): Class names, either as a list or as a dict mapping class index to class name.

    Returns:
        (dict): Mapping of class indices 0..n-1 to class names (an empty dict for empty input), or the input unchanged
            if it is neither a list nor a dict.

    Raises:
        KeyError: If the class indices are not exactly 0..n-1 (negative or too-large indices).
    """
    if isinstance(names, list):
        names = dict(enumerate(names))
    if isinstance(names, dict):
        # Convert string keys like '0' to int, and non-string values (e.g. True) to str
        names = {int(k): str(v) for k, v in names.items()}
        n = len(names)
        if n == 0:
            return names  # nothing to validate or map (max()/min() would fail on an empty dict)
        lo, hi = min(names), max(names)
        # Keys are n unique ints, so 0 <= lo and hi < n together imply they are exactly 0..n-1
        if lo < 0 or hi >= n:
            raise KeyError(f'{n}-class dataset requires class indices 0-{n - 1}, but you have invalid class indices '
                           f'{lo}-{hi} defined in your dataset YAML.')
        if names[0].startswith('n0'):  # ImageNet class codes
            imagenet_map = yaml_load(ROOT / 'datasets/ImageNet.yaml')['map']
            names = {k: imagenet_map[v] for k, v in names.items()}
    return names
| |
|
| |
|
class AutoBackend(nn.Module):
    """Run YOLOv8 inference on PyTorch models or on any supported exported format, chosen from the weights path."""

    def __init__(self,
                 weights='yolov8n.pt',
                 device=torch.device('cpu'),
                 dnn=False,
                 data=None,
                 fp16=False,
                 fuse=True,
                 verbose=True):
        """
        MultiBackend class for python inference on various platforms using Ultralytics YOLO.

        The backend is selected from the weights path by `_model_type()`. At the end of __init__, every local
        variable created while loading (e.g. `session`, `bindings`, `interpreter`, `names`, `stride`) is copied onto
        the instance, and `forward()` reads them back as attributes.

        Args:
            weights (str | list | torch.nn.Module): The path to the weights file, a list of paths (the first one is
                used to detect the format), or an in-memory PyTorch model. Default: 'yolov8n.pt'
            device (torch.device): The device to run the model on.
            dnn (bool): Use OpenCV DNN module for ONNX inference if True, defaults to False.
            data (str | Path | optional): Additional data.yaml file for class names, used when the model has none.
            fp16 (bool): If True, use half precision (only kept for backends that support it). Default: False
            fuse (bool): Whether to fuse the model or not. Default: True
            verbose (bool): Whether to run in verbose mode or not. Default: True

        Raises:
            NotImplementedError: For TF.js models.
            TypeError: If the weights are not in a supported model format.

        Supported formats and their naming conventions:
            | Format                | Suffix           |
            |-----------------------|------------------|
            | PyTorch               | *.pt             |
            | TorchScript           | *.torchscript    |
            | ONNX Runtime          | *.onnx           |
            | ONNX OpenCV DNN       | *.onnx dnn=True  |
            | OpenVINO              | *.xml            |
            | CoreML                | *.mlmodel        |
            | TensorRT              | *.engine         |
            | TensorFlow SavedModel | *_saved_model    |
            | TensorFlow GraphDef   | *.pb             |
            | TensorFlow Lite       | *.tflite         |
            | TensorFlow Edge TPU   | *_edgetpu.tflite |
            | PaddlePaddle          | *_paddle_model   |
        """
        super().__init__()
        w = str(weights[0] if isinstance(weights, list) else weights)
        nn_module = isinstance(weights, torch.nn.Module)
        pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, triton = self._model_type(w)
        fp16 &= pt or jit or onnx or engine or nn_module or triton  # half precision only where supported
        nhwc = coreml or saved_model or pb or tflite or edgetpu  # these backends take channels-last input
        stride = 32  # default stride
        model, metadata = None, None
        cuda = torch.cuda.is_available() and device.type != 'cpu'
        if not (pt or triton or nn_module):
            w = attempt_download_asset(w)  # fetch the file if it is not available locally

        # Load model
        if nn_module:  # in-memory PyTorch model
            model = weights.to(device)
            model = model.fuse(verbose=verbose) if fuse else model
            if hasattr(model, 'kpt_shape'):
                kpt_shape = model.kpt_shape  # pose models only
            stride = max(int(model.stride.max()), 32)
            names = model.module.names if hasattr(model, 'module') else model.names
            model.half() if fp16 else model.float()
            self.model = model  # explicit assignment registers the submodule for to(), cpu(), cuda(), half()
            pt = True
        elif pt:  # PyTorch
            from ultralytics.nn.tasks import attempt_load_weights
            model = attempt_load_weights(weights if isinstance(weights, list) else w,
                                         device=device,
                                         inplace=True,
                                         fuse=fuse)
            if hasattr(model, 'kpt_shape'):
                kpt_shape = model.kpt_shape  # pose models only
            stride = max(int(model.stride.max()), 32)
            names = model.module.names if hasattr(model, 'module') else model.names
            model.half() if fp16 else model.float()
            self.model = model  # explicit assignment registers the submodule for to(), cpu(), cuda(), half()
        elif jit:  # TorchScript
            LOGGER.info(f'Loading {w} for TorchScript inference...')
            extra_files = {'config.txt': ''}  # filled in by torch.jit.load with embedded metadata, if any
            model = torch.jit.load(w, _extra_files=extra_files, map_location=device)
            model.half() if fp16 else model.float()
            if extra_files['config.txt']:
                metadata = json.loads(extra_files['config.txt'], object_hook=lambda x: dict(x.items()))
        elif dnn:  # ONNX OpenCV DNN
            LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...')
            check_requirements('opencv-python>=4.5.4')
            net = cv2.dnn.readNetFromONNX(w)
        elif onnx:  # ONNX Runtime
            LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
            check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
            import onnxruntime
            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
            session = onnxruntime.InferenceSession(w, providers=providers)
            output_names = [x.name for x in session.get_outputs()]
            metadata = session.get_modelmeta().custom_metadata_map
        elif xml:  # OpenVINO
            LOGGER.info(f'Loading {w} for OpenVINO inference...')
            check_requirements('openvino')
            from openvino.runtime import Core, Layout, get_batch
            ie = Core()
            w = Path(w)
            if not w.is_file():  # a directory was given: use the first *.xml inside it
                w = next(w.glob('*.xml'))
            network = ie.read_model(model=str(w), weights=w.with_suffix('.bin'))
            if network.get_parameters()[0].get_layout().empty:
                network.get_parameters()[0].set_layout(Layout('NCHW'))
            batch_dim = get_batch(network)
            if batch_dim.is_static:
                batch_size = batch_dim.get_length()
            executable_network = ie.compile_model(network, device_name='CPU')
            metadata = w.parent / 'metadata.yaml'
        elif engine:  # TensorRT
            LOGGER.info(f'Loading {w} for TensorRT inference...')
            try:
                import tensorrt as trt
            except ImportError:
                if LINUX:
                    check_requirements('nvidia-tensorrt', cmds='-U --index-url https://pypi.ngc.nvidia.com')
                import tensorrt as trt
            check_version(trt.__version__, '7.0.0', hard=True)  # require tensorrt>=7.0.0
            if device.type == 'cpu':
                device = torch.device('cuda:0')
            Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
            logger = trt.Logger(trt.Logger.INFO)
            # File layout: 4-byte little-endian metadata length, JSON metadata, then the serialized engine
            with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
                meta_len = int.from_bytes(f.read(4), byteorder='little')
                metadata = json.loads(f.read(meta_len).decode('utf-8'))
                model = runtime.deserialize_cuda_engine(f.read())
            context = model.create_execution_context()
            bindings = OrderedDict()
            output_names = []
            fp16 = False  # re-derived from the engine's input dtype below
            dynamic = False
            for i in range(model.num_bindings):
                name = model.get_binding_name(i)
                dtype = trt.nptype(model.get_binding_dtype(i))
                if model.binding_is_input(i):
                    if -1 in tuple(model.get_binding_shape(i)):  # dynamic input shape
                        dynamic = True
                        context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[2]))
                    if dtype == np.float16:
                        fp16 = True
                else:  # output binding
                    output_names.append(name)
                shape = tuple(context.get_binding_shape(i))
                # Pre-allocate a device buffer per binding; its pointer is passed to execute_v2() in forward()
                im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
                bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
            binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
            batch_size = bindings['images'].shape[0]  # for dynamic engines this is the max batch size of the profile
        elif coreml:  # CoreML
            LOGGER.info(f'Loading {w} for CoreML inference...')
            import coremltools as ct
            model = ct.models.MLModel(w)
            metadata = dict(model.user_defined_metadata)
        elif saved_model:  # TensorFlow SavedModel
            LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
            import tensorflow as tf
            keras = False  # assume TF2 saved_model export
            model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
            metadata = Path(w) / 'metadata.yaml'
        elif pb:  # TensorFlow GraphDef
            LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
            import tensorflow as tf

            from ultralytics.yolo.engine.exporter import gd_outputs

            def wrap_frozen_graph(gd, inputs, outputs):
                """Wrap frozen graphs for deployment."""
                x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=''), [])
                ge = x.graph.as_graph_element
                return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))

            gd = tf.Graph().as_graph_def()
            with open(w, 'rb') as f:
                gd.ParseFromString(f.read())
            frozen_func = wrap_frozen_graph(gd, inputs='x:0', outputs=gd_outputs(gd))
        elif tflite or edgetpu:  # TensorFlow Lite / Edge TPU
            try:  # prefer the lightweight tflite_runtime package, fall back to full TensorFlow
                from tflite_runtime.interpreter import Interpreter, load_delegate
            except ImportError:
                import tensorflow as tf
                Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate
            if edgetpu:
                LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
                delegate = {
                    'Linux': 'libedgetpu.so.1',
                    'Darwin': 'libedgetpu.1.dylib',
                    'Windows': 'edgetpu.dll'}[platform.system()]
                interpreter = Interpreter(model_path=w, experimental_delegates=[load_delegate(delegate)])
            else:
                LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
                interpreter = Interpreter(model_path=w)
            interpreter.allocate_tensors()
            input_details = interpreter.get_input_details()
            output_details = interpreter.get_output_details()
            # Metadata, if present, is stored as a Python-literal file inside the .tflite (which is then a zip).
            # Use a dedicated name for the archive: binding it to `model` would leave a closed ZipFile on self.model.
            with contextlib.suppress(zipfile.BadZipFile):
                with zipfile.ZipFile(w, 'r') as zf:
                    meta_file = zf.namelist()[0]
                    metadata = ast.literal_eval(zf.read(meta_file).decode('utf-8'))
        elif tfjs:  # TF.js
            raise NotImplementedError('YOLOv8 TF.js inference is not supported')
        elif paddle:  # PaddlePaddle
            LOGGER.info(f'Loading {w} for PaddlePaddle inference...')
            check_requirements('paddlepaddle-gpu' if cuda else 'paddlepaddle')
            import paddle.inference as pdi
            w = Path(w)
            if not w.is_file():  # a directory was given: search it for the *.pdmodel file
                w = next(w.rglob('*.pdmodel'))
            config = pdi.Config(str(w), str(w.with_suffix('.pdiparams')))
            if cuda:
                config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0)
            predictor = pdi.create_predictor(config)
            input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
            output_names = predictor.get_output_names()
            metadata = w.parents[1] / 'metadata.yaml'
        elif triton:  # NVIDIA Triton Inference Server URL
            LOGGER.info('Triton Inference Server not supported...')
            # TODO: add Triton support, roughly:
            #   check_requirements('tritonclient[all]')
            #   from utils.triton import TritonRemoteModel
            #   model = TritonRemoteModel(url=w)
            #   nhwc = model.runtime.startswith("tensorflow")
        else:
            from ultralytics.yolo.engine.exporter import export_formats
            raise TypeError(f"model='{w}' is not a supported model format. "
                            'See https://docs.ultralytics.com/modes/predict for help.'
                            f'\n\n{export_formats()}')

        # Load external metadata YAML (some backends above only set a path to it)
        if isinstance(metadata, (str, Path)) and Path(metadata).exists():
            metadata = yaml_load(metadata)
        if metadata:
            for k, v in metadata.items():
                if k in ('stride', 'batch'):
                    metadata[k] = int(v)
                elif k in ('imgsz', 'names', 'kpt_shape') and isinstance(v, str):
                    # Metadata comes from model files that may be untrusted. Parse Python literals with
                    # ast.literal_eval instead of eval() so a crafted file cannot execute arbitrary code.
                    metadata[k] = ast.literal_eval(v)
            stride = metadata['stride']
            task = metadata['task']
            batch = metadata['batch']
            imgsz = metadata['imgsz']
            names = metadata['names']
            kpt_shape = metadata.get('kpt_shape')
        elif not (pt or triton or nn_module):
            LOGGER.warning(f"WARNING ⚠️ Metadata not found for 'model={weights}'")

        # Check names
        if 'names' not in locals():  # no backend or metadata provided class names
            names = self._apply_default_class_names(data)
        names = check_class_names(names)

        # Copy every local (backend handles, flags, names, stride, ...) onto the instance; forward() relies on this
        self.__dict__.update(locals())

    def forward(self, im, augment=False, visualize=False):
        """
        Runs inference on the YOLOv8 MultiBackend model.

        Args:
            im (torch.Tensor): The image tensor to perform inference on, unpacked as (batch, channels, height, width).
                It is permuted to channels-last internally for backends that need it.
            augment (bool): whether to perform data augmentation during inference (PyTorch models only), defaults to
                False
            visualize (bool): whether to visualize the output predictions (PyTorch models only), defaults to False

        Returns:
            (torch.Tensor | list[torch.Tensor]): The model output. NumPy arrays are converted to tensors on
                `self.device` (other outputs are returned as-is), and a single-element list/tuple is unwrapped.
        """
        b, ch, h, w = im.shape  # unpacking also enforces a 4-D input
        if self.fp16 and im.dtype != torch.float16:
            im = im.half()
        if self.nhwc:
            im = im.permute(0, 2, 3, 1)  # BCHW -> BHWC

        if self.pt or self.nn_module:  # PyTorch
            y = self.model(im, augment=augment, visualize=visualize) if augment or visualize else self.model(im)
        elif self.jit:  # TorchScript
            y = self.model(im)
        elif self.dnn:  # ONNX OpenCV DNN
            im = im.cpu().numpy()
            self.net.setInput(im)
            y = self.net.forward()
        elif self.onnx:  # ONNX Runtime
            im = im.cpu().numpy()
            y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
        elif self.xml:  # OpenVINO
            im = im.cpu().numpy()
            y = list(self.executable_network([im]).values())
        elif self.engine:  # TensorRT
            if self.dynamic and im.shape != self.bindings['images'].shape:
                # Re-shape the input binding and resize the pre-allocated output buffers to match
                i = self.model.get_binding_index('images')
                self.context.set_binding_shape(i, im.shape)
                self.bindings['images'] = self.bindings['images']._replace(shape=im.shape)
                for name in self.output_names:
                    i = self.model.get_binding_index(name)
                    self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i)))
            s = self.bindings['images'].shape
            assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}"
            self.binding_addrs['images'] = int(im.data_ptr())  # point the input binding at this tensor
            self.context.execute_v2(list(self.binding_addrs.values()))
            y = [self.bindings[x].data for x in sorted(self.output_names)]
        elif self.coreml:  # CoreML
            im = im[0].cpu().numpy()  # first image only, channels-last
            im_pil = Image.fromarray((im * 255).astype('uint8'))
            y = self.model.predict({'image': im_pil})
            if 'confidence' in y:
                # NOTE(review): coordinates appear to be normalized xywh; they are scaled to pixels here
                box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]])
                # np.float was removed in NumPy 1.24; use an explicit float dtype
                conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float32)
                y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
            elif len(y) == 1:
                y = list(y.values())
            elif len(y) == 2:
                y = list(reversed(y.values()))
        elif self.paddle:  # PaddlePaddle
            im = im.cpu().numpy().astype(np.float32)
            self.input_handle.copy_from_cpu(im)
            self.predictor.run()
            y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names]
        elif self.triton:  # NVIDIA Triton Inference Server (not yet supported, see __init__)
            y = self.model(im)
        else:  # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
            im = im.cpu().numpy()
            if self.saved_model:
                y = self.model(im, training=False) if self.keras else self.model(im)
                if not isinstance(y, list):
                    y = [y]
            elif self.pb:
                y = self.frozen_func(x=self.tf.constant(im))
                if len(y) == 2 and len(self.names) == 999:  # two outputs and only the 999 placeholder names
                    ip, ib = (0, 1) if len(y[0].shape) == 4 else (1, 0)  # index of the 4-D output and the other one
                    # NOTE(review): assumes boxes carry 4 coords + nc classes + one channel per proto mask
                    nc = y[ib].shape[1] - y[ip].shape[3] - 4
                    self.names = {i: f'class{i}' for i in range(nc)}
            else:  # TFLite / Edge TPU
                inp = self.input_details[0]
                int8 = inp['dtype'] == np.int8  # quantized int8 model
                if int8:
                    scale, zero_point = inp['quantization']
                    im = (im / scale + zero_point).astype(np.int8)  # quantize the input
                self.interpreter.set_tensor(inp['index'], im)
                self.interpreter.invoke()
                y = []
                for output in self.output_details:
                    x = self.interpreter.get_tensor(output['index'])
                    if int8:
                        scale, zero_point = output['quantization']
                        x = (x.astype(np.float32) - zero_point) * scale  # de-quantize the output
                    y.append(x)
            # Two-output (segmentation) TF models: put the 4-D output second and move it to channels-first
            if len(y) == 2:
                if len(y[1].shape) != 4:
                    y = list(reversed(y))
                y[1] = np.transpose(y[1], (0, 3, 1, 2))  # BHWC -> BCHW
            y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y]

        if isinstance(y, (list, tuple)):
            return self.from_numpy(y[0]) if len(y) == 1 else [self.from_numpy(x) for x in y]
        else:
            return self.from_numpy(y)

    def from_numpy(self, x):
        """
        Convert a numpy array to a tensor.

        Args:
            x (np.ndarray): The array to be converted. Any other type is returned unchanged.

        Returns:
            (torch.Tensor): The converted tensor on `self.device` (or `x` itself if it is not a numpy array).
        """
        return torch.tensor(x).to(self.device) if isinstance(x, np.ndarray) else x

    def warmup(self, imgsz=(1, 3, 640, 640)):
        """
        Warm up the model by running forward passes with a dummy input.

        Warmup only happens for PyTorch, TorchScript, ONNX, TensorRT, SavedModel, GraphDef and Triton backends, and
        only when running on a non-CPU device (or for Triton). TorchScript gets two passes, the others one.

        Args:
            imgsz (tuple): The shape of the dummy input tensor in the format (batch_size, channels, height, width)

        Returns:
            (None): This method runs the forward pass and doesn't return any value
        """
        warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton, self.nn_module
        if any(warmup_types) and (self.device.type != 'cpu' or self.triton):
            im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device)
            for _ in range(2 if self.jit else 1):
                self.forward(im)

    @staticmethod
    def _apply_default_class_names(data):
        """
        Return class names from the 'names' key of the `data` YAML file.

        Loading is best-effort: on any error, placeholder names {0: 'class0', ..., 998: 'class998'} are returned.
        """
        with contextlib.suppress(Exception):
            return yaml_load(check_yaml(data))['names']
        return {i: f'class{i}' for i in range(999)}

    @staticmethod
    def _model_type(p='path/to/model.pt'):
        """
        This function takes a path to a model file and returns the model type.

        Args:
            p: path to the model file. Defaults to path/to/model.pt

        Returns:
            (list[bool]): One flag per export suffix, in `export_formats().Suffix` order as unpacked by __init__
                (pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle), followed by a
                `triton` flag. `triton` is True for an http/grpc URL that matches no known suffix.
        """
        from ultralytics.yolo.engine.exporter import export_formats
        sf = list(export_formats().Suffix)
        if not is_url(p, check=False) and not isinstance(p, str):
            check_suffix(p, sf)
        url = urlparse(p)
        types = [s in Path(p).name for s in sf]
        types[8] &= not types[9]  # an *_edgetpu.tflite file is Edge TPU, not plain TFLite
        triton = not any(types) and all([any(s in url.scheme for s in ['http', 'grpc']), url.netloc])
        return types + [triton]
| |
|